Skip to content

Commit

Permalink
[Enhancement] Support overflow mode for decimal type (#30419)
Browse files Browse the repository at this point in the history
Signed-off-by: liuyehcf <[email protected]>
(cherry picked from commit 228c120)

# Conflicts:
#	gensrc/thrift/InternalService.thrift
  • Loading branch information
liuyehcf authored and mergify[bot] committed Sep 11, 2023
1 parent 158f7bc commit 7923802
Show file tree
Hide file tree
Showing 18 changed files with 1,014 additions and 564 deletions.
75 changes: 50 additions & 25 deletions be/src/exprs/arithmetic_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "exprs/binary_function.h"
#include "exprs/decimal_binary_function.h"
#include "exprs/decimal_cast_expr.h"
#include "exprs/overflow.h"
#include "exprs/unary_function.h"
#include "runtime/decimalv3.h"
#include "util/pred_guard.h"
Expand Down Expand Up @@ -67,25 +68,25 @@ class VectorizedArithmeticExpr final : public Expr {
if (lhs_pt == TYPE_DECIMAL64 && rhs_pt == TYPE_DECIMAL64 && Type == TYPE_DECIMAL128) {
ASSIGN_OR_RETURN(auto l, _children[0]->get_child(0)->evaluate_checked(context, chunk));
ASSIGN_OR_RETURN(auto r, _children[1]->get_child(0)->evaluate_checked(context, chunk));
return VectorizedStrictDecimalBinaryFunction<MulOp64x64_128, false>::template evaluate<
return VectorizedStrictDecimalBinaryFunction<MulOp64x64_128, OverflowMode::IGNORE>::template evaluate<
TYPE_DECIMAL64, TYPE_DECIMAL64, Type>(l, r);
}
if (lhs_pt == TYPE_DECIMAL32 && rhs_pt == TYPE_DECIMAL64 && Type == TYPE_DECIMAL128) {
ASSIGN_OR_RETURN(auto l, _children[0]->get_child(0)->evaluate_checked(context, chunk));
ASSIGN_OR_RETURN(auto r, _children[1]->get_child(0)->evaluate_checked(context, chunk));
return VectorizedStrictDecimalBinaryFunction<MulOp32x64_128, false>::template evaluate<
return VectorizedStrictDecimalBinaryFunction<MulOp32x64_128, OverflowMode::IGNORE>::template evaluate<
TYPE_DECIMAL32, TYPE_DECIMAL64, Type>(l, r);
}
if (lhs_pt == TYPE_DECIMAL64 && rhs_pt == TYPE_DECIMAL32 && Type == TYPE_DECIMAL128) {
ASSIGN_OR_RETURN(auto l, _children[0]->get_child(0)->evaluate_checked(context, chunk));
ASSIGN_OR_RETURN(auto r, _children[1]->get_child(0)->evaluate_checked(context, chunk));
return VectorizedStrictDecimalBinaryFunction<MulOp32x64_128, false>::template evaluate<
return VectorizedStrictDecimalBinaryFunction<MulOp32x64_128, OverflowMode::IGNORE>::template evaluate<
TYPE_DECIMAL32, TYPE_DECIMAL64, Type>(r, l);
}
if (lhs_pt == TYPE_DECIMAL32 && rhs_pt == TYPE_DECIMAL32 && Type == TYPE_DECIMAL128) {
ASSIGN_OR_RETURN(auto l, _children[0]->get_child(0)->evaluate_checked(context, chunk));
ASSIGN_OR_RETURN(auto r, _children[1]->get_child(0)->evaluate_checked(context, chunk));
return VectorizedStrictDecimalBinaryFunction<MulOp32x32_128, false>::template evaluate<
return VectorizedStrictDecimalBinaryFunction<MulOp32x32_128, OverflowMode::IGNORE>::template evaluate<
TYPE_DECIMAL32, TYPE_DECIMAL32, Type>(r, l);
}
}
Expand All @@ -105,7 +106,13 @@ class VectorizedArithmeticExpr final : public Expr {
ASSIGN_OR_RETURN(auto r, _children[1]->evaluate_checked(context, ptr));
if constexpr (lt_is_decimal<Type>) {
// Enable overflow checking in decimal arithmetic
return VectorizedStrictDecimalBinaryFunction<OP, true>::template evaluate<Type>(l, r);
if (context != nullptr && context->error_if_overflow()) {
return VectorizedStrictDecimalBinaryFunction<OP, OverflowMode::REPORT_ERROR>::template evaluate<Type>(
l, r);
} else {
return VectorizedStrictDecimalBinaryFunction<OP, OverflowMode::OUTPUT_NULL>::template evaluate<Type>(l,
r);
}
} else {
using ArithmeticOp = ArithmeticBinaryOperator<OP, Type>;
return VectorizedStrictBinaryFunction<ArithmeticOp>::template evaluate<Type>(l, r);
Expand All @@ -128,23 +135,31 @@ class VectorizedDivArithmeticExpr final : public Expr {
DEFINE_CLASS_CONSTRUCTOR(VectorizedDivArithmeticExpr);
StatusOr<ColumnPtr> evaluate_checked(ExprContext* context, Chunk* ptr) override {
if constexpr (is_intdiv_op<Op> && lt_is_bigint<Type>) {
using CastFunction = VectorizedUnaryFunction<DecimalTo<true>>;
switch (_children[0]->type().type) {
case TYPE_DECIMAL32: {
ASSIGN_OR_RETURN(auto column, evaluate_internal<TYPE_DECIMAL32>(context, ptr));
return CastFunction::evaluate<TYPE_DECIMAL32, LogicalType::TYPE_BIGINT>(column);
}
case TYPE_DECIMAL64: {
ASSIGN_OR_RETURN(auto column, evaluate_internal<TYPE_DECIMAL64>(context, ptr));
return CastFunction::evaluate<TYPE_DECIMAL64, LogicalType::TYPE_BIGINT>(column);
}
case TYPE_DECIMAL128: {
ASSIGN_OR_RETURN(auto column, evaluate_internal<TYPE_DECIMAL128>(context, ptr));
return CastFunction::evaluate<TYPE_DECIMAL128, LogicalType::TYPE_BIGINT>(column);
}
default:
return evaluate_internal<Type>(context, ptr);
#define EVALUATE_CHECKED_OVERFLOW(Mode) \
using CastFunction = VectorizedUnaryFunction<DecimalTo<Mode>>; \
switch (_children[0]->type().type) { \
case TYPE_DECIMAL32: { \
ASSIGN_OR_RETURN(auto column, evaluate_internal<TYPE_DECIMAL32>(context, ptr)); \
return CastFunction::evaluate<TYPE_DECIMAL32, LogicalType::TYPE_BIGINT>(column); \
} \
case TYPE_DECIMAL64: { \
ASSIGN_OR_RETURN(auto column, evaluate_internal<TYPE_DECIMAL64>(context, ptr)); \
return CastFunction::evaluate<TYPE_DECIMAL64, LogicalType::TYPE_BIGINT>(column); \
} \
case TYPE_DECIMAL128: { \
ASSIGN_OR_RETURN(auto column, evaluate_internal<TYPE_DECIMAL128>(context, ptr)); \
return CastFunction::evaluate<TYPE_DECIMAL128, LogicalType::TYPE_BIGINT>(column); \
} \
default: \
return evaluate_internal<Type>(context, ptr); \
}

if (context != nullptr && context->error_if_overflow()) {
EVALUATE_CHECKED_OVERFLOW(OverflowMode::REPORT_ERROR);
} else {
EVALUATE_CHECKED_OVERFLOW(OverflowMode::OUTPUT_NULL);
}
#undef EVALUATE_CHECKED_OVERFLOW
} else {
return evaluate_internal<Type>(context, ptr);
}
Expand All @@ -156,8 +171,13 @@ class VectorizedDivArithmeticExpr final : public Expr {
ASSIGN_OR_RETURN(auto l, _children[0]->evaluate_checked(context, ptr));
ASSIGN_OR_RETURN(auto r, _children[1]->evaluate_checked(context, ptr));
if constexpr (lt_is_decimal<LType>) {
using VectorizedDiv = VectorizedUnstrictDecimalBinaryFunction<LType, DivOp, true>;
return VectorizedDiv::template evaluate<LType>(l, r);
if (context != nullptr && context->error_if_overflow()) {
using VectorizedDiv = VectorizedUnstrictDecimalBinaryFunction<LType, DivOp, OverflowMode::REPORT_ERROR>;
return VectorizedDiv::template evaluate<LType>(l, r);
} else {
using VectorizedDiv = VectorizedUnstrictDecimalBinaryFunction<LType, DivOp, OverflowMode::OUTPUT_NULL>;
return VectorizedDiv::template evaluate<LType>(l, r);
}
} else {
using RightZeroCheck = ArithmeticRightZeroCheck<LType>;
using ArithmeticDiv = ArithmeticBinaryOperator<DivOp, LType>;
Expand All @@ -176,8 +196,13 @@ class VectorizedModArithmeticExpr final : public Expr {
ASSIGN_OR_RETURN(auto r, _children[1]->evaluate_checked(context, ptr));

if constexpr (lt_is_decimal<Type>) {
using VectorizedDiv = VectorizedUnstrictDecimalBinaryFunction<Type, ModOp, true>;
return VectorizedDiv::template evaluate<Type>(l, r);
if (context != nullptr && context->error_if_overflow()) {
using VectorizedDiv = VectorizedUnstrictDecimalBinaryFunction<Type, ModOp, OverflowMode::REPORT_ERROR>;
return VectorizedDiv::template evaluate<Type>(l, r);
} else {
using VectorizedDiv = VectorizedUnstrictDecimalBinaryFunction<Type, ModOp, OverflowMode::OUTPUT_NULL>;
return VectorizedDiv::template evaluate<Type>(l, r);
}
} else {
using RightZeroCheck = ArithmeticRightZeroCheck<Type>;
using ArithmeticMod = ArithmeticBinaryOperator<ModOp, Type>;
Expand Down
37 changes: 37 additions & 0 deletions be/src/exprs/arithmetic_operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,43 @@ bool check_fpe_of_min_div_by_minus_one(LType lhs, RType rhs) {
}
}

template <typename Op>
std::string get_op_name() {
if constexpr (is_add_op<Op>) {
return "add";
} else if constexpr (is_sub_op<Op>) {
return "sub";
} else if constexpr (is_reverse_sub_op<Op>) {
return "reverse_sub";
} else if constexpr (is_reverse_mod_op<Op>) {
return "reverse_mod";
} else if constexpr (is_mul_op<Op>) {
return "mul";
} else if constexpr (is_div_op<Op>) {
return "div";
} else if constexpr (is_intdiv_op<Op>) {
return "intdiv";
} else if constexpr (is_mod_op<Op>) {
return "mod";
} else if constexpr (is_bitand_op<Op>) {
return "bitand";
} else if constexpr (is_bitor_op<Op>) {
return "bitor";
} else if constexpr (is_bitxor_op<Op>) {
return "bitxor";
} else if constexpr (is_bitnot_op<Op>) {
return "bitnot";
} else if constexpr (is_bit_shift_left_op<Op>) {
return "bit_shift_left";
} else if constexpr (is_bit_shift_right_op<Op>) {
return "bit_shift_right";
} else if constexpr (is_bit_shift_right_logical_op<Op>) {
return "bit_shift_right_logical";
} else {
return "unknown";
}
}

template <typename Op, LogicalType Type, typename = guard::Guard, typename = guard::Guard>
struct ArithmeticBinaryOperator {
template <typename LType, typename RType, typename ResultType>
Expand Down
45 changes: 37 additions & 8 deletions be/src/exprs/cast_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1083,20 +1083,43 @@ class VectorizedCastExpr final : public Expr {
// to double at first, then convert double to JSON
if constexpr (FromType == TYPE_JSON || ToType == TYPE_JSON) {
if constexpr (lt_is_decimal<FromType>) {
ColumnPtr double_column =
VectorizedUnaryFunction<DecimalTo<true>>::evaluate<FromType, TYPE_DOUBLE>(column);
ColumnPtr double_column;
if (context != nullptr && context->error_if_overflow()) {
double_column = VectorizedUnaryFunction<DecimalTo<OverflowMode::REPORT_ERROR>>::evaluate<
FromType, TYPE_DOUBLE>(column);
} else {
double_column = VectorizedUnaryFunction<DecimalTo<OverflowMode::OUTPUT_NULL>>::evaluate<
FromType, TYPE_DOUBLE>(column);
}
result_column = CastFn<TYPE_DOUBLE, TYPE_JSON, AllowThrowException>::cast_fn(double_column);
} else {
result_column = CastFn<FromType, ToType, AllowThrowException>::cast_fn(column);
}
} else if constexpr (lt_is_decimal<FromType> && lt_is_decimal<ToType>) {
return VectorizedUnaryFunction<DecimalToDecimal<true>>::evaluate<FromType, ToType>(
column, to_type.precision, to_type.scale);
if (context != nullptr && context->error_if_overflow()) {
return VectorizedUnaryFunction<DecimalToDecimal<OverflowMode::REPORT_ERROR>>::evaluate<FromType,
ToType>(
column, to_type.precision, to_type.scale);
} else {
return VectorizedUnaryFunction<DecimalToDecimal<OverflowMode::OUTPUT_NULL>>::evaluate<FromType, ToType>(
column, to_type.precision, to_type.scale);
}
} else if constexpr (lt_is_decimal<FromType>) {
return VectorizedUnaryFunction<DecimalTo<true>>::evaluate<FromType, ToType>(column);
if (context != nullptr && context->error_if_overflow()) {
return VectorizedUnaryFunction<DecimalTo<OverflowMode::REPORT_ERROR>>::evaluate<FromType, ToType>(
column);
} else {
return VectorizedUnaryFunction<DecimalTo<OverflowMode::OUTPUT_NULL>>::evaluate<FromType, ToType>(
column);
}
} else if constexpr (lt_is_decimal<ToType>) {
return VectorizedUnaryFunction<DecimalFrom<true>>::evaluate<FromType, ToType>(column, to_type.precision,
to_type.scale);
if (context != nullptr && context->error_if_overflow()) {
return VectorizedUnaryFunction<DecimalFrom<OverflowMode::REPORT_ERROR>>::evaluate<FromType, ToType>(
column, to_type.precision, to_type.scale);
} else {
return VectorizedUnaryFunction<DecimalFrom<OverflowMode::OUTPUT_NULL>>::evaluate<FromType, ToType>(
column, to_type.precision, to_type.scale);
}
} else {
result_column = CastFn<FromType, ToType, AllowThrowException>::cast_fn(column);
}
Expand Down Expand Up @@ -1257,7 +1280,13 @@ class VectorizedCastToStringExpr final : public Expr {
}

if constexpr (lt_is_decimal<Type>) {
return VectorizedUnaryFunction<DecimalTo<true>>::evaluate<Type, TYPE_VARCHAR>(column);
if (context != nullptr && context->error_if_overflow()) {
return VectorizedUnaryFunction<DecimalTo<OverflowMode::REPORT_ERROR>>::evaluate<Type, TYPE_VARCHAR>(
column);
} else {
return VectorizedUnaryFunction<DecimalTo<OverflowMode::OUTPUT_NULL>>::evaluate<Type, TYPE_VARCHAR>(
column);
}
}

// must be: TYPE_FLOAT, TYPE_DOUBLE, TYPE_CHAR, TYPE_VARCHAR...
Expand Down
Loading

0 comments on commit 7923802

Please sign in to comment.