diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index b43eca542f3..95cf8c47d6d 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -799,8 +799,10 @@ Result RoundToMultiple(const Datum& arg, RoundToMultipleOptions options, SCALAR_ARITHMETIC_BINARY(Add, "add", "add_checked") SCALAR_ARITHMETIC_BINARY(Divide, "divide", "divide_checked") SCALAR_ARITHMETIC_BINARY(Logb, "logb", "logb_checked") +SCALAR_ARITHMETIC_BINARY(Mod, "mod", "mod_checked") SCALAR_ARITHMETIC_BINARY(Multiply, "multiply", "multiply_checked") SCALAR_ARITHMETIC_BINARY(Power, "power", "power_checked") +SCALAR_ARITHMETIC_BINARY(Remainder, "remainder", "remainder_checked") SCALAR_ARITHMETIC_BINARY(ShiftLeft, "shift_left", "shift_left_checked") SCALAR_ARITHMETIC_BINARY(ShiftRight, "shift_right", "shift_right_checked") SCALAR_ARITHMETIC_BINARY(Subtract, "subtract", "subtract_checked") diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 8b341e865a1..8fafec4ce01 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -671,6 +671,40 @@ Result Divide(const Datum& left, const Datum& right, ArithmeticOptions options = ArithmeticOptions(), ExecContext* ctx = NULLPTR); +/// \brief Compute the remainder (truncated division) of two values. +/// Array values must be the same length. If either argument is null the result +/// will be null. For integer and decimal types, a zero divisor raises an error; +/// floating-point types return NaN unless overflow checking is enabled. +/// +/// The result has the same sign as the dividend (C/C++ semantics). 
+/// +/// \param[in] left the dividend +/// \param[in] right the divisor +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise remainder +ARROW_EXPORT +Result Remainder(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the modulo (floored division) of two values. +/// Array values must be the same length. If either argument is null the result +/// will be null. For integer and decimal types, a zero divisor raises an error; +/// floating-point types return NaN unless overflow checking is enabled. +/// +/// The result has the same sign as the divisor (Python semantics). +/// +/// \param[in] left the dividend +/// \param[in] right the divisor +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise modulo +ARROW_EXPORT +Result Mod(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + /// \brief Negate values. /// /// If argument is null the result will be null. 
diff --git a/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h b/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h index b4840061ae7..55d41de3af8 100644 --- a/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h +++ b/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h @@ -34,6 +34,7 @@ namespace arrow { using internal::AddWithOverflow; using internal::DivideWithOverflow; +using internal::ModuloWithOverflow; using internal::MultiplyWithOverflow; using internal::NegateWithOverflow; using internal::SubtractWithOverflow; @@ -468,6 +469,172 @@ struct FloatingDivideChecked { // TODO: Add decimal }; +// Remainder (truncated): result has same sign as dividend (C/C++ semantics) +struct Remainder { + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + return std::fmod(left, right); + } + + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + T result; + if (ARROW_PREDICT_FALSE(ModuloWithOverflow(left, right, &result))) { + if (right == 0) { + *st = Status::Invalid("divide by zero"); + } else { + // INT_MIN % -1 overflow case, result is 0 + result = 0; + } + } + return result; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + if (right == Arg1()) { + *st = Status::Invalid("divide by zero"); + return T(); + } + return left % right; + } +}; + +struct RemainderChecked { + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + if (ARROW_PREDICT_FALSE(right == 0)) { + *st = Status::Invalid("divide by zero"); + return 0; + } + return std::fmod(left, right); + } + + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + T result; + if 
(ARROW_PREDICT_FALSE(ModuloWithOverflow(left, right, &result))) { + if (right == 0) { + *st = Status::Invalid("divide by zero"); + } else { + *st = Status::Invalid("overflow"); + } + } + return result; + } + + template + static enable_if_decimal_value Call(KernelContext* ctx, Arg0 left, Arg1 right, + Status* st) { + return Remainder::Call(ctx, left, right, st); + } +}; + +// Helper: Convert truncated remainder to floored modulo for signed types. +// Floored modulo has the same sign as the divisor (Python semantics). +template +T AdjustRemainderToFloored(T rem, T right) { + if constexpr (std::is_signed_v) { + if ((rem > 0 && right < 0) || (rem < 0 && right > 0)) { + rem += right; + } + } + return rem; +} + +// Mod (floored): result has same sign as divisor (Python semantics) +struct Mod { + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + T rem = std::fmod(left, right); + if (rem == 0) { + // Preserve the sign based on divisor for zero results + return std::copysign(rem, right); + } + return AdjustRemainderToFloored(rem, right); + } + + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + T result; + if (ARROW_PREDICT_FALSE(ModuloWithOverflow(left, right, &result))) { + if (right == 0) { + *st = Status::Invalid("divide by zero"); + } else { + // INT_MIN % -1 overflow case, result is 0 + result = 0; + } + return result; + } + return AdjustRemainderToFloored(result, right); + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static const T kZero{}; + if (right == kZero) { + *st = Status::Invalid("divide by zero"); + return T(); + } + T rem = left % right; + // Convert truncated to floored: adjust if signs differ + if ((rem > kZero && right < kZero) || (rem < kZero && right > kZero)) { + rem = rem + right; + } + return rem; + } +}; + +struct ModChecked { + template + static enable_if_floating_value 
Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + if (ARROW_PREDICT_FALSE(right == 0)) { + *st = Status::Invalid("divide by zero"); + return 0; + } + T rem = std::fmod(left, right); + if (rem == 0) { + // Preserve the sign based on divisor for zero results + return std::copysign(rem, right); + } + return AdjustRemainderToFloored(rem, right); + } + + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + T result; + if (ARROW_PREDICT_FALSE(ModuloWithOverflow(left, right, &result))) { + if (right == 0) { + *st = Status::Invalid("divide by zero"); + } else { + *st = Status::Invalid("overflow"); + } + return result; + } + return AdjustRemainderToFloored(result, right); + } + + template + static enable_if_decimal_value Call(KernelContext* ctx, Arg0 left, Arg1 right, + Status* st) { + return Mod::Call(ctx, left, right, st); + } +}; + struct Negate { template static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, Status*) { diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 03c9422809b..0eec8258478 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -663,7 +663,7 @@ void AddDecimalBinaryKernels(const std::string& name, ScalarFunction* func) { OutputType out_type(null()); std::shared_ptr constraint = nullptr; const std::string op = name.substr(0, name.find("_")); - if (op == "add" || op == "subtract") { + if (op == "add" || op == "subtract" || op == "remainder" || op == "mod") { out_type = OutputType(ResolveDecimalAdditionOrSubtractionOutput); constraint = DecimalsHaveSameScale(); } else if (op == "multiply") { @@ -776,7 +776,7 @@ struct ArithmeticFunction : ScalarFunction { // "add_checked" -> "add" const auto func_name = 
name(); const std::string op = func_name.substr(0, func_name.find("_")); - if (op == "add" || op == "subtract") { + if (op == "add" || op == "subtract" || op == "remainder" || op == "mod") { return CastBinaryDecimalArgs(DecimalPromotion::kAdd, types); } else if (op == "multiply") { return CastBinaryDecimalArgs(DecimalPromotion::kMultiply, types); @@ -1165,6 +1165,40 @@ const FunctionDoc div_checked_doc{ "integer overflow is encountered."), {"dividend", "divisor"}}; +const FunctionDoc remainder_doc{ + "Compute the remainder after integer division (truncated)", + ("Returns the remainder after dividing the dividend by the divisor.\n" + "The result has the same sign as the dividend (truncated division).\n" + "This is equivalent to the C/C++ '%' operator.\n" + "Integer division by zero returns an error."), + {"dividend", "divisor"}}; + +const FunctionDoc remainder_checked_doc{ + "Compute the remainder after integer division (truncated)", + ("Returns the remainder after dividing the dividend by the divisor.\n" + "The result has the same sign as the dividend (truncated division).\n" + "This is equivalent to the C/C++ '%' operator.\n" + "An error is returned when trying to divide by zero, or when\n" + "integer overflow is encountered."), + {"dividend", "divisor"}}; + +const FunctionDoc mod_doc{ + "Compute the modulo (floored)", + ("Returns the modulo after floored division.\n" + "The result has the same sign as the divisor (floored division).\n" + "This is equivalent to Python's '%' operator.\n" + "Integer division by zero returns an error."), + {"dividend", "divisor"}}; + +const FunctionDoc mod_checked_doc{ + "Compute the modulo (floored)", + ("Returns the modulo after floored division.\n" + "The result has the same sign as the divisor (floored division).\n" + "This is equivalent to Python's '%' operator.\n" + "An error is returned when trying to divide by zero, or when\n" + "integer overflow is encountered."), + {"dividend", "divisor"}}; + const FunctionDoc 
negate_doc{"Negate the argument element-wise", ("Results will wrap around on integer overflow.\n" "Use function \"negate_checked\" if you want overflow\n" @@ -1708,6 +1742,28 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::move(divide_checked))); + // ---------------------------------------------------------------------- + auto remainder = MakeArithmeticFunctionNotNull("remainder", remainder_doc); + AddDecimalBinaryKernels("remainder", remainder.get()); + DCHECK_OK(registry->AddFunction(std::move(remainder))); + + // ---------------------------------------------------------------------- + auto remainder_checked = MakeArithmeticFunctionNotNull( + "remainder_checked", remainder_checked_doc); + AddDecimalBinaryKernels("remainder_checked", remainder_checked.get()); + DCHECK_OK(registry->AddFunction(std::move(remainder_checked))); + + // ---------------------------------------------------------------------- + auto mod = MakeArithmeticFunctionNotNull("mod", mod_doc); + AddDecimalBinaryKernels("mod", mod.get()); + DCHECK_OK(registry->AddFunction(std::move(mod))); + + // ---------------------------------------------------------------------- + auto mod_checked = + MakeArithmeticFunctionNotNull("mod_checked", mod_checked_doc); + AddDecimalBinaryKernels("mod_checked", mod_checked.get()); + DCHECK_OK(registry->AddFunction(std::move(mod_checked))); + // ---------------------------------------------------------------------- auto negate = MakeUnaryArithmeticFunction("negate", negate_doc); AddDecimalUnaryKernels(negate.get()); diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index 9367ad2c89d..f38b86a0e2b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -938,6 +938,173 @@ TYPED_TEST(TestBinaryArithmeticSigned, DivideOverflowRaises) { this->AssertBinop(Divide, 
MakeArray(min), MakeArray(-1), "[0]"); } +// ============== REMAINDER (Truncated) Tests ============== + +TYPED_TEST(TestBinaryArithmeticIntegral, Remainder) { + for (auto check_overflow : {false, true}) { + this->SetOverflowCheck(check_overflow); + // Empty arrays + this->AssertBinop(Remainder, "[]", "[]", "[]"); + // Basic positive cases + this->AssertBinop(Remainder, "[7, 10, 20]", "[3, 4, 7]", "[1, 2, 6]"); + // Array with nulls + this->AssertBinop(Remainder, "[null, 10, 30, null, 20]", "[1, 4, 2, 5, 10]", + "[null, 2, 0, null, 0]"); + // Scalar % Array + this->AssertBinop(Remainder, 33, "[null, 1, 3, null, 2]", "[null, 0, 0, null, 1]"); + // Array % Scalar + this->AssertBinop(Remainder, "[null, 10, 30, null, 2]", 3, "[null, 1, 0, null, 2]"); + // Scalar % Scalar + this->AssertBinop(Remainder, 16, 7, 2); + } +} + +TYPED_TEST(TestBinaryArithmeticSigned, Remainder) { + // Truncated semantics: sign follows dividend + this->AssertBinop(Remainder, "[7]", "[3]", "[1]"); + this->AssertBinop(Remainder, "[-7]", "[3]", "[-1]"); + this->AssertBinop(Remainder, "[7]", "[-3]", "[1]"); + this->AssertBinop(Remainder, "[-7]", "[-3]", "[-1]"); + // Mixed array + this->AssertBinop(Remainder, "[-3, 2, -7, 10]", "[1, 1, 2, 3]", "[0, 0, -1, 1]"); +} + +TYPED_TEST(TestBinaryArithmeticUnsigned, Remainder) { + this->AssertBinop(Remainder, "[7, 100, 255]", "[3, 30, 16]", "[1, 10, 15]"); +} + +TYPED_TEST(TestBinaryArithmeticSigned, RemainderOverflow) { + using CType = typename TestFixture::CType; + auto min = std::numeric_limits::lowest(); + + // Unchecked: returns 0 (the mathematically correct result) + this->SetOverflowCheck(false); + this->AssertBinop(Remainder, MakeArray(min), MakeArray(CType(-1)), "[0]"); + + // Checked: raises overflow error + this->SetOverflowCheck(true); + this->AssertBinopRaises(Remainder, MakeArray(min), MakeArray(CType(-1)), "overflow"); +} + +TYPED_TEST(TestBinaryArithmeticIntegral, RemainderByZero) { + for (auto check_overflow : {false, true}) { + 
this->SetOverflowCheck(check_overflow); + this->AssertBinopRaises(Remainder, "[3, 2, 6]", "[1, 1, 0]", "divide by zero"); + } +} + +TYPED_TEST(TestBinaryArithmeticFloating, Remainder) { + SKIP_IF_HALF_FLOAT(); + + this->SetNansEqual(true); + + // Basic cases + this->AssertBinop(Remainder, "[7.5, 10.0]", "[2.5, 3.0]", "[0.0, 1.0]"); + // Negative numbers - truncated semantics: sign follows dividend + this->AssertBinop(Remainder, "[-7.5]", "[2.5]", "[-0.0]"); + this->AssertBinop(Remainder, "[7.5]", "[-2.5]", "[0.0]"); + this->AssertBinop(Remainder, "[-7.5]", "[-2.5]", "[-0.0]"); + + // Division by zero returns NaN (unchecked) + this->SetOverflowCheck(false); + this->AssertBinop(Remainder, "[1.0]", "[0.0]", "[NaN]"); + + // Division by zero raises error (checked) + this->SetOverflowCheck(true); + this->AssertBinopRaises(Remainder, "[1.0]", "[0.0]", "divide by zero"); + + // Infinity edge cases (unchecked) + this->SetOverflowCheck(false); + this->AssertBinop(Remainder, "[Inf]", "[2.0]", "[NaN]"); + this->AssertBinop(Remainder, "[-Inf]", "[2.0]", "[NaN]"); + this->AssertBinop(Remainder, "[2.0]", "[Inf]", "[2.0]"); + this->AssertBinop(Remainder, "[2.0]", "[-Inf]", "[2.0]"); + this->AssertBinop(Remainder, "[Inf]", "[Inf]", "[NaN]"); +} + +// ============== MOD (Floored) Tests ============== + +TYPED_TEST(TestBinaryArithmeticIntegral, Mod) { + for (auto check_overflow : {false, true}) { + this->SetOverflowCheck(check_overflow); + // Empty arrays + this->AssertBinop(Mod, "[]", "[]", "[]"); + // Basic positive cases (same as remainder for positive numbers) + this->AssertBinop(Mod, "[7, 10, 20]", "[3, 4, 7]", "[1, 2, 6]"); + // Array with nulls + this->AssertBinop(Mod, "[null, 10, 30, null, 20]", "[1, 4, 2, 5, 10]", + "[null, 2, 0, null, 0]"); + // Scalar % Array + this->AssertBinop(Mod, 33, "[null, 1, 3, null, 2]", "[null, 0, 0, null, 1]"); + // Array % Scalar + this->AssertBinop(Mod, "[null, 10, 30, null, 2]", 3, "[null, 1, 0, null, 2]"); + } +} + 
+TYPED_TEST(TestBinaryArithmeticSigned, Mod) { + // Floored semantics: sign follows divisor + this->AssertBinop(Mod, "[7]", "[3]", "[1]"); + this->AssertBinop(Mod, "[-7]", "[3]", "[2]"); + this->AssertBinop(Mod, "[7]", "[-3]", "[-2]"); + this->AssertBinop(Mod, "[-7]", "[-3]", "[-1]"); + // Edge case: -1 mod positive + this->AssertBinop(Mod, "[-1]", "[3]", "[2]"); +} + +TYPED_TEST(TestBinaryArithmeticUnsigned, Mod) { + // Same as remainder for unsigned (no negative numbers) + this->AssertBinop(Mod, "[7, 100, 255]", "[3, 30, 16]", "[1, 10, 15]"); +} + +TYPED_TEST(TestBinaryArithmeticSigned, ModOverflow) { + using CType = typename TestFixture::CType; + auto min = std::numeric_limits::lowest(); + + // Unchecked: returns 0 + this->SetOverflowCheck(false); + this->AssertBinop(Mod, MakeArray(min), MakeArray(CType(-1)), "[0]"); + + // Checked: raises overflow error + this->SetOverflowCheck(true); + this->AssertBinopRaises(Mod, MakeArray(min), MakeArray(CType(-1)), "overflow"); +} + +TYPED_TEST(TestBinaryArithmeticIntegral, ModByZero) { + for (auto check_overflow : {false, true}) { + this->SetOverflowCheck(check_overflow); + this->AssertBinopRaises(Mod, "[3, 2, 6]", "[1, 1, 0]", "divide by zero"); + } +} + +TYPED_TEST(TestBinaryArithmeticFloating, Mod) { + SKIP_IF_HALF_FLOAT(); + + this->SetNansEqual(true); + + // Basic cases + this->AssertBinop(Mod, "[7.5, 10.0]", "[2.5, 3.0]", "[0.0, 1.0]"); + // Negative numbers - floored semantics: sign follows divisor + this->AssertBinop(Mod, "[-7.5]", "[2.5]", "[0.0]"); + this->AssertBinop(Mod, "[7.5]", "[-2.5]", "[-0.0]"); + this->AssertBinop(Mod, "[-7.5]", "[-2.5]", "[-0.0]"); + + // Division by zero returns NaN (unchecked) + this->SetOverflowCheck(false); + this->AssertBinop(Mod, "[1.0]", "[0.0]", "[NaN]"); + + // Division by zero raises error (checked) + this->SetOverflowCheck(true); + this->AssertBinopRaises(Mod, "[1.0]", "[0.0]", "divide by zero"); + + // Infinity edge cases (unchecked) + this->SetOverflowCheck(false); + 
this->AssertBinop(Mod, "[Inf]", "[2.0]", "[NaN]"); + this->AssertBinop(Mod, "[-Inf]", "[2.0]", "[NaN]"); + this->AssertBinop(Mod, "[2.0]", "[Inf]", "[2.0]"); + this->AssertBinop(Mod, "[2.0]", "[-Inf]", "[-Inf]"); // floored: 2.0 + (-Inf) = -Inf + this->AssertBinop(Mod, "[Inf]", "[Inf]", "[NaN]"); +} + TYPED_TEST(TestBinaryArithmeticFloating, Power) { SKIP_IF_HALF_FLOAT(); @@ -2404,6 +2571,110 @@ TEST_F(TestBinaryArithmeticDecimal, Divide) { } } +TEST_F(TestBinaryArithmeticDecimal, Remainder) { + // Truncated semantics: sign follows dividend + + // array array, decimal128 + { + auto left = ArrayFromJSON(decimal128(5, 2), R"(["7.00", "-7.00", "7.00", "-7.00"])"); + auto right = ArrayFromJSON(decimal128(5, 2), R"(["3.00", "3.00", "-3.00", "-3.00"])"); + auto expected = + ArrayFromJSON(decimal128(6, 2), R"(["1.00", "-1.00", "1.00", "-1.00"])"); + CheckScalarBinary("remainder", left, right, expected); + } + + // array array, decimal256 + { + auto left = ArrayFromJSON(decimal256(5, 2), R"(["7.00", "-7.00"])"); + auto right = ArrayFromJSON(decimal256(5, 2), R"(["3.00", "3.00"])"); + auto expected = ArrayFromJSON(decimal256(6, 2), R"(["1.00", "-1.00"])"); + CheckScalarBinary("remainder", left, right, expected); + } + + // scalar scalar + { + auto left = ScalarFromJSON(decimal128(5, 2), R"("17.50")"); + auto right = ScalarFromJSON(decimal128(5, 2), R"("5.00")"); + auto expected = ScalarFromJSON(decimal128(6, 2), R"("2.50")"); + CheckScalarBinary("remainder", left, right, expected); + } + + // failed case: divide by 0 + { + auto left = ScalarFromJSON(decimal256(1, 0), R"("7")"); + auto right = ScalarFromJSON(decimal256(1, 0), R"("0")"); + ASSERT_RAISES(Invalid, CallFunction("remainder", {left, right})); + } + + // mixed precision: different precisions, same scale + { + auto left = ArrayFromJSON(decimal128(5, 2), R"(["17.00", "-17.00"])"); + auto right = ArrayFromJSON(decimal128(3, 2), R"(["5.00", "5.00"])"); + auto expected = ArrayFromJSON(decimal128(6, 2), R"(["2.00", 
"-2.00"])"); + CheckScalarBinary("remainder", left, right, expected); + } + + // mixed types: decimal128 and decimal256 + { + auto left = ArrayFromJSON(decimal128(5, 2), R"(["17.00"])"); + auto right = ArrayFromJSON(decimal256(5, 2), R"(["5.00"])"); + auto expected = ArrayFromJSON(decimal256(6, 2), R"(["2.00"])"); + CheckScalarBinary("remainder", left, right, expected); + } +} + +TEST_F(TestBinaryArithmeticDecimal, Mod) { + // Floored semantics: sign follows divisor + + // array array, decimal128 + { + auto left = ArrayFromJSON(decimal128(5, 2), R"(["7.00", "-7.00", "7.00", "-7.00"])"); + auto right = ArrayFromJSON(decimal128(5, 2), R"(["3.00", "3.00", "-3.00", "-3.00"])"); + auto expected = + ArrayFromJSON(decimal128(6, 2), R"(["1.00", "2.00", "-2.00", "-1.00"])"); + CheckScalarBinary("mod", left, right, expected); + } + + // array array, decimal256 + { + auto left = ArrayFromJSON(decimal256(5, 2), R"(["7.00", "-7.00"])"); + auto right = ArrayFromJSON(decimal256(5, 2), R"(["3.00", "3.00"])"); + auto expected = ArrayFromJSON(decimal256(6, 2), R"(["1.00", "2.00"])"); + CheckScalarBinary("mod", left, right, expected); + } + + // scalar scalar + { + auto left = ScalarFromJSON(decimal128(5, 2), R"("-17.50")"); + auto right = ScalarFromJSON(decimal128(5, 2), R"("5.00")"); + auto expected = ScalarFromJSON(decimal128(6, 2), R"("2.50")"); + CheckScalarBinary("mod", left, right, expected); + } + + // failed case: divide by 0 + { + auto left = ScalarFromJSON(decimal256(1, 0), R"("7")"); + auto right = ScalarFromJSON(decimal256(1, 0), R"("0")"); + ASSERT_RAISES(Invalid, CallFunction("mod", {left, right})); + } + + // mixed precision: different precisions, same scale + { + auto left = ArrayFromJSON(decimal128(5, 2), R"(["-17.00"])"); + auto right = ArrayFromJSON(decimal128(3, 2), R"(["5.00"])"); + auto expected = ArrayFromJSON(decimal128(6, 2), R"(["3.00"])"); + CheckScalarBinary("mod", left, right, expected); + } + + // mixed types: decimal128 and decimal256 + { + auto left = 
ArrayFromJSON(decimal128(5, 2), R"(["-17.00"])"); + auto right = ArrayFromJSON(decimal256(5, 2), R"(["5.00"])"); + auto expected = ArrayFromJSON(decimal256(6, 2), R"(["3.00"])"); + CheckScalarBinary("mod", left, right, expected); + } +} + TEST_F(TestBinaryArithmeticDecimal, Atan2) { // Decimal arguments promoted to double, sanity check here const auto func = "atan2"; diff --git a/cpp/src/arrow/util/basic_decimal.cc b/cpp/src/arrow/util/basic_decimal.cc index eddb1aae7b2..9e3de37f3f9 100644 --- a/cpp/src/arrow/util/basic_decimal.cc +++ b/cpp/src/arrow/util/basic_decimal.cc @@ -1399,6 +1399,14 @@ BasicDecimal256 operator/(const BasicDecimal256& left, const BasicDecimal256& ri return result; } +BasicDecimal256 operator%(const BasicDecimal256& left, const BasicDecimal256& right) { + BasicDecimal256 remainder; + BasicDecimal256 result; + auto s = left.Divide(right, &result, &remainder); + DCHECK_EQ(s, DecimalStatus::kSuccess); + return remainder; +} + // Explicitly instantiate template base class, for DLL linking on Windows template class GenericBasicDecimal; template class GenericBasicDecimal; diff --git a/cpp/src/arrow/util/basic_decimal.h b/cpp/src/arrow/util/basic_decimal.h index 638c4870f1d..f35696f4ee5 100644 --- a/cpp/src/arrow/util/basic_decimal.h +++ b/cpp/src/arrow/util/basic_decimal.h @@ -883,5 +883,7 @@ ARROW_EXPORT BasicDecimal256 operator*(const BasicDecimal256& left, const BasicDecimal256& right); ARROW_EXPORT BasicDecimal256 operator/(const BasicDecimal256& left, const BasicDecimal256& right); +ARROW_EXPORT BasicDecimal256 operator%(const BasicDecimal256& left, + const BasicDecimal256& right); } // namespace arrow diff --git a/cpp/src/arrow/util/int_util_overflow.h b/cpp/src/arrow/util/int_util_overflow.h index 69714a935a4..c76d009607f 100644 --- a/cpp/src/arrow/util/int_util_overflow.h +++ b/cpp/src/arrow/util/int_util_overflow.h @@ -137,6 +137,24 @@ template return false; } +template +[[nodiscard]] bool ModuloWithOverflowGeneric(Int u, Int v, Int* out) 
{ + if (v == 0) { + *out = Int{}; + return true; + } + // INT_MIN % -1 causes a hardware trap on x86, but mathematically equals 0 + if constexpr (std::is_signed_v) { + constexpr auto kMin = std::numeric_limits::min(); + if (u == kMin && v == -1) { + *out = 0; + return true; + } + } + *out = u % v; + return false; +} + // Define non-generic versions of the above so as to benefit from automatic // integer conversion, to allow for mixed-type calls such as // AddWithOverflow(int32_t, int64_t, int64_t*). @@ -160,6 +178,7 @@ NON_GENERIC_OPS_WITH_OVERFLOW(AddWithOverflow) NON_GENERIC_OPS_WITH_OVERFLOW(SubtractWithOverflow) NON_GENERIC_OPS_WITH_OVERFLOW(MultiplyWithOverflow) NON_GENERIC_OPS_WITH_OVERFLOW(DivideWithOverflow) +NON_GENERIC_OPS_WITH_OVERFLOW(ModuloWithOverflow) #undef NON_GENERIC_OPS_WITH_OVERFLOW #undef NON_GENERIC_OP_WITH_OVERFLOW diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index e4092af70cd..be2f73e3db7 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -487,7 +487,7 @@ overflow-checking variant, suffixed ``_checked``, which returns an ``Invalid`` :class:`Status` when overflow is detected. For functions which support decimal inputs (currently ``add``, ``subtract``, -``multiply``, and ``divide`` and their checked variants), decimals of different +``multiply``, ``divide``, ``mod``, and ``remainder`` and their checked variants), decimals of different precisions/scales will be promoted appropriately. Mixed decimal and floating-point arguments will cast all arguments to floating-point, while mixed decimal and integer arguments will cast all arguments to decimals. @@ -516,6 +516,10 @@ Mixed time resolution temporal inputs will be cast to finest input resolution. 
+------------------+--------+-------------------------+-------------------------------+-------+ | multiply_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | +------------------+--------+-------------------------+-------------------------------+-------+ +| mod | Binary | Numeric | Numeric | \(3) | ++------------------+--------+-------------------------+-------------------------------+-------+ +| mod_checked | Binary | Numeric | Numeric | \(3) | ++------------------+--------+-------------------------+-------------------------------+-------+ | negate | Unary | Numeric/Duration | Numeric/Duration | | +------------------+--------+-------------------------+-------------------------------+-------+ | negate_checked | Unary | Signed Numeric/Duration | Signed Numeric/Duration | | @@ -524,6 +528,10 @@ Mixed time resolution temporal inputs will be cast to finest input resolution. +------------------+--------+-------------------------+-------------------------------+-------+ | power_checked | Binary | Numeric | Numeric | | +------------------+--------+-------------------------+-------------------------------+-------+ +| remainder | Binary | Numeric | Numeric | \(4) | ++------------------+--------+-------------------------+-------------------------------+-------+ +| remainder_checked| Binary | Numeric | Numeric | \(4) | ++------------------+--------+-------------------------+-------------------------------+-------+ | sign | Unary | Numeric/Duration | Int8/Float16/Float32/Float64 | \(2) | +------------------+--------+-------------------------+-------------------------------+-------+ | sqrt | Unary | Numeric | Numeric | | @@ -560,6 +568,14 @@ Mixed time resolution temporal inputs will be cast to finest input resolution. values return NaN. Integral and decimal values return signedness as Int8 and floating-point values return it with the same type as the input values. +* \(3) Computes the floored modulo operation where the result has the same sign + as the divisor. 
This is equivalent to Python's ``%`` operator. For decimals, + the result uses the same precision/scale promotion as ``add``. + +* \(4) Computes the truncated remainder where the result has the same sign + as the dividend. This is equivalent to C/C++'s ``%`` operator. For decimals, + the result uses the same precision/scale promotion as ``add``. + Bit-wise functions ~~~~~~~~~~~~~~~~~~