/// Helpers for comparing Decimal values/columns with each other and with non-decimal numeric values.
///
/// NOTE(review): this copy of the file appears to have lost its line breaks and everything that was
/// written inside angle brackets (include paths, template parameter lists, template arguments of
/// static_cast / checkAndGetColumn / is_decimal / apply / getScales, ...). The notes below describe only
/// what the remaining tokens show; restore the stripped parts from the authoritative copy before building.
///
/// allowDecimalComparison(left_type, right_type)
///     Returns true when the left type isColumnedAsDecimal and the right type either isColumnedAsDecimal
///     or isNotDecimalButComparableToDecimal, or when the left type isNotDecimalButComparableToDecimal and
///     the right type isColumnedAsDecimal. Returns false in every other case.
///
/// ConstructDecInt
///     Maps a byte size to a signed integer type: 1, 2 and 4 -> Int32, 8 -> Int64, 16 -> Int128, 32 -> Int256.
///
/// DecCompareInt
///     Selects ConstructDecInt<...>::Type using sizeof(T) when `!is_decimal || sizeof(T) > sizeof(U)` holds
///     and sizeof(U) otherwise (the argument of is_decimal is stripped here -- TODO confirm which type it tests).
///     TypeA and TypeB are both aliases of that same Type.
///
/// DecimalComparison
///     CompareInt is DecCompareInt::Type; Op is the comparison Operation instantiated for it.
///
///     apply(col_left, col_right)
///         When _actual is true: obtains a Shift from getScales(col_left.type, col_right.type) and forwards the
///         two columns to applyWithScale. When _actual is false: returns nullptr.
///         NOTE(review): the local `ColumnPtr c_res` is declared but never used in this function.
///
///     compare(a, b, scale_a, scale_b)
///         Throws DECIMAL_OVERFLOW ("Bad scale of decimal field") if either scale is greater than
///         DecimalUtils::max_precision. If scale_a < scale_b, shift.a becomes scaleMultiplier(scale_b - scale_a);
///         if scale_a > scale_b, shift.b becomes scaleMultiplier(scale_a - scale_b). The result comes from
///         applyWithScale(a, b, shift).
///
///     Shift
///         Pair of multipliers (a for the left operand, b for the right), both defaulting to 1.
///         none(): both are 1; left(): a != 1; right(): b != 1.
///
///     applyWithScale(a, b, shift)
///         Calls apply with shift.a when shift.left(), otherwise with shift.b when shift.right(), otherwise
///         with 1. left() is tested first, so shift.b is not consulted when both differ from 1.
///         (The template arguments selecting the apply variant are stripped -- presumably the
///         scale_left / scale_right flags used by the scalar apply; TODO confirm.)
///
///     getScales (three overloads selected by requires-clauses on is_decimal)
///         - both decimal: looks up both DataTypeDecimalBase pointers with checkDecimalBase. If both are
///           non-null, takes DecimalUtils::binaryOpResult of the two and sets shift.a / shift.b from
///           result_type.scaleFactorFor(...) for each side. If only decimal0 is non-null, shift.b is decimal0's
///           scale multiplier; if only decimal1 is non-null, shift.a is decimal1's scale multiplier.
///         - left decimal, right not: shift.b is the left type's scale multiplier (when checkDecimalBase succeeds).
///         - right decimal, left not: shift.a is the right type's scale multiplier (when checkDecimalBase succeeds).
///
///     apply(c0, c1, scale)   -- column version
///         Creates a ColumnUInt8 result. When _actual is false it is returned empty. Otherwise:
///         - both columns const: compares the two extracted values once and returns
///           DataTypeUInt8().createColumnConst(c0->size(), toField(res));
///         - otherwise resizes the result to c0->size() and dispatches to constantVector (c0 const),
///           vectorConstant (c1 const) or vectorVector (neither const).
///         Throws LOGICAL_ERROR ("Wrong column in Decimal comparison") when checkAndGetColumn returns null.
///
///     apply(a, b, scale)   -- scalar version
///         Loads a and b into CompareInt x and y (via .value in the is_decimal branches). With _check_overflow:
///         accumulates an overflow flag from round-trip checks when sizeof(A) or sizeof(B) exceeds
///         sizeof(CompareInt), from `x < 0` / `y < 0` checks in the is_unsigned_v branches, and from
///         common::mulOverflow when scale_left / scale_right is set; throws DECIMAL_OVERFLOW
///         ("Can't compare decimal number due to overflow") if the flag is set. Without _check_overflow:
///         scales with common::mulIgnoreOverflow. Returns Op::apply(x, y).
///
///     vectorVector / vectorConstant / constantVector
///         Element-wise loops writing apply(...) into c. The iteration count is a.size() (b.size() for
///         constantVector); c is not resized here -- the column apply above resizes it before calling.
///         vectorVector reads b at the same positions as a -- assumes b holds at least a.size() elements;
///         TODO confirm callers guarantee equal sizes.
#pragma once #include #include #include #include #include #include #include #include #include #include /// TODO Core should not depend on Functions namespace DB { namespace ErrorCodes { extern const int LOGICAL_ERROR; extern const int DECIMAL_OVERFLOW; } inline bool allowDecimalComparison(const DataTypePtr & left_type, const DataTypePtr & right_type) { if (isColumnedAsDecimal(left_type)) { if (isColumnedAsDecimal(right_type) || isNotDecimalButComparableToDecimal(right_type)) return true; } else if (isNotDecimalButComparableToDecimal(left_type) && isColumnedAsDecimal(right_type)) { return true; } return false; } template struct ConstructDecInt; template <> struct ConstructDecInt<1> { using Type = Int32; }; template <> struct ConstructDecInt<2> { using Type = Int32; }; template <> struct ConstructDecInt<4> { using Type = Int32; }; template <> struct ConstructDecInt<8> { using Type = Int64; }; template <> struct ConstructDecInt<16> { using Type = Int128; }; template <> struct ConstructDecInt<32> { using Type = Int256; }; template struct DecCompareInt { using Type = typename ConstructDecInt<(!is_decimal || sizeof(T) > sizeof(U)) ? 
sizeof(T) : sizeof(U)>::Type; using TypeA = Type; using TypeB = Type; }; template typename Operation, bool _check_overflow = true, bool _actual = is_decimal || is_decimal> class DecimalComparison { public: using CompareInt = typename DecCompareInt::Type; using Op = Operation; using ColVecA = ColumnVectorOrDecimal; using ColVecB = ColumnVectorOrDecimal; using ArrayA = typename ColVecA::Container; using ArrayB = typename ColVecB::Container; static ColumnPtr apply(const ColumnWithTypeAndName & col_left, const ColumnWithTypeAndName & col_right) { if constexpr (_actual) { ColumnPtr c_res; Shift shift = getScales(col_left.type, col_right.type); return applyWithScale(col_left.column, col_right.column, shift); } else return nullptr; } static bool compare(A a, B b, UInt32 scale_a, UInt32 scale_b) { static const UInt32 max_scale = DecimalUtils::max_precision; if (scale_a > max_scale || scale_b > max_scale) throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Bad scale of decimal field"); Shift shift; if (scale_a < scale_b) shift.a = static_cast(DecimalUtils::scaleMultiplier(scale_b - scale_a)); if (scale_a > scale_b) shift.b = static_cast(DecimalUtils::scaleMultiplier(scale_a - scale_b)); return applyWithScale(a, b, shift); } private: struct Shift { CompareInt a = 1; CompareInt b = 1; bool none() const { return a == 1 && b == 1; } bool left() const { return a != 1; } bool right() const { return b != 1; } }; template static auto applyWithScale(T a, U b, const Shift & shift) { if (shift.left()) return apply(a, b, shift.a); if (shift.right()) return apply(a, b, shift.b); return apply(a, b, 1); } template requires is_decimal && is_decimal static Shift getScales(const DataTypePtr & left_type, const DataTypePtr & right_type) { const DataTypeDecimalBase * decimal0 = checkDecimalBase(*left_type); const DataTypeDecimalBase * decimal1 = checkDecimalBase(*right_type); Shift shift; if (decimal0 && decimal1) { auto result_type = DecimalUtils::binaryOpResult(*decimal0, *decimal1); shift.a = 
static_cast(result_type.scaleFactorFor(decimal0->getTrait(), false).value); shift.b = static_cast(result_type.scaleFactorFor(decimal1->getTrait(), false).value); } else if (decimal0) shift.b = static_cast(decimal0->getScaleMultiplier().value); else if (decimal1) shift.a = static_cast(decimal1->getScaleMultiplier().value); return shift; } template requires is_decimal && (!is_decimal) static Shift getScales(const DataTypePtr & left_type, const DataTypePtr &) { Shift shift; const DataTypeDecimalBase * decimal0 = checkDecimalBase(*left_type); if (decimal0) shift.b = static_cast(decimal0->getScaleMultiplier().value); return shift; } template requires (!is_decimal) && is_decimal static Shift getScales(const DataTypePtr &, const DataTypePtr & right_type) { Shift shift; const DataTypeDecimalBase * decimal1 = checkDecimalBase(*right_type); if (decimal1) shift.a = static_cast(decimal1->getScaleMultiplier().value); return shift; } template static ColumnPtr apply(const ColumnPtr & c0, const ColumnPtr & c1, CompareInt scale) { auto c_res = ColumnUInt8::create(); if constexpr (_actual) { bool c0_is_const = isColumnConst(*c0); bool c1_is_const = isColumnConst(*c1); if (c0_is_const && c1_is_const) { const ColumnConst & c0_const = checkAndGetColumnConst(*c0); const ColumnConst & c1_const = checkAndGetColumnConst(*c1); A a = c0_const.template getValue(); B b = c1_const.template getValue(); UInt8 res = apply(a, b, scale); return DataTypeUInt8().createColumnConst(c0->size(), toField(res)); } ColumnUInt8::Container & vec_res = c_res->getData(); vec_res.resize(c0->size()); if (c0_is_const) { const ColumnConst & c0_const = checkAndGetColumnConst(*c0); A a = c0_const.template getValue(); if (const ColVecB * c1_vec = checkAndGetColumn(c1.get())) constantVector(a, c1_vec->getData(), vec_res, scale); else throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison"); } else if (c1_is_const) { const ColumnConst & c1_const = checkAndGetColumnConst(*c1); B b = 
c1_const.template getValue(); if (const ColVecA * c0_vec = checkAndGetColumn(c0.get())) vectorConstant(c0_vec->getData(), b, vec_res, scale); else throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison"); } else { if (const ColVecA * c0_vec = checkAndGetColumn(c0.get())) { if (const ColVecB * c1_vec = checkAndGetColumn(c1.get())) vectorVector(c0_vec->getData(), c1_vec->getData(), vec_res, scale); else throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison"); } else throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison"); } } return c_res; } template static NO_INLINE UInt8 apply(A a, B b, CompareInt scale [[maybe_unused]]) { CompareInt x; if constexpr (is_decimal) x = a.value; else x = a; CompareInt y; if constexpr (is_decimal) y = b.value; else y = static_cast(b); if constexpr (_check_overflow) { bool overflow = false; if constexpr (sizeof(A) > sizeof(CompareInt)) overflow |= (static_cast(x) != a); if constexpr (sizeof(B) > sizeof(CompareInt)) overflow |= (static_cast(y) != b); if constexpr (is_unsigned_v) overflow |= (x < 0); if constexpr (is_unsigned_v) overflow |= (y < 0); if constexpr (scale_left) overflow |= common::mulOverflow(x, scale, x); if constexpr (scale_right) overflow |= common::mulOverflow(y, scale, y); if (overflow) throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Can't compare decimal number due to overflow"); } else { if constexpr (scale_left) x = common::mulIgnoreOverflow(x, scale); if constexpr (scale_right) y = common::mulIgnoreOverflow(y, scale); } return Op::apply(x, y); } template static void NO_INLINE vectorVector(const ArrayA & a, const ArrayB & b, PaddedPODArray & c, CompareInt scale) { size_t size = a.size(); const A * a_pos = a.data(); const B * b_pos = b.data(); UInt8 * c_pos = c.data(); const A * a_end = a_pos + size; while (a_pos < a_end) { *c_pos = apply(*a_pos, *b_pos, scale); ++a_pos; ++b_pos; ++c_pos; } } template static void NO_INLINE vectorConstant(const 
ArrayA & a, B b, PaddedPODArray & c, CompareInt scale) { size_t size = a.size(); const A * a_pos = a.data(); UInt8 * c_pos = c.data(); const A * a_end = a_pos + size; while (a_pos < a_end) { *c_pos = apply(*a_pos, b, scale); ++a_pos; ++c_pos; } } template static void NO_INLINE constantVector(A a, const ArrayB & b, PaddedPODArray & c, CompareInt scale) { size_t size = b.size(); const B * b_pos = b.data(); UInt8 * c_pos = c.data(); const B * b_end = b_pos + size; while (b_pos < b_end) { *c_pos = apply(a, *b_pos, scale); ++b_pos; ++c_pos; } } }; }