-
Notifications
You must be signed in to change notification settings - Fork 0
20099: Add Decimal support for floor preimage #224
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -19,9 +19,10 @@ use std::any::Any; | |||||
| use std::sync::Arc; | ||||||
|
|
||||||
| use arrow::array::{ArrayRef, AsArray}; | ||||||
| use arrow::compute::{DecimalCast, rescale_decimal}; | ||||||
| use arrow::datatypes::{ | ||||||
| DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float32Type, | ||||||
| Float64Type, | ||||||
| ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, | ||||||
| Decimal256Type, DecimalType, Float32Type, Float64Type, | ||||||
| }; | ||||||
| use datafusion_common::{Result, ScalarValue, exec_err}; | ||||||
| use datafusion_expr::interval_arithmetic::Interval; | ||||||
|
|
@@ -230,8 +231,6 @@ impl ScalarUDFImpl for FloorFunc { | |||||
|
|
||||||
| // Compute lower bound (N) and upper bound (N + 1) using helper functions | ||||||
| let Some((lower, upper)) = (match lit_value { | ||||||
| // Decimal types should be supported and tracked in | ||||||
| // https://github.com/apache/datafusion/issues/20080 | ||||||
| // Floating-point types | ||||||
| ScalarValue::Float64(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| { | ||||||
| ( | ||||||
|
|
@@ -260,6 +259,48 @@ impl ScalarUDFImpl for FloorFunc { | |||||
| (ScalarValue::Int64(Some(lo)), ScalarValue::Int64(Some(hi))) | ||||||
| }), | ||||||
|
|
||||||
| // Decimal types | ||||||
| ScalarValue::Decimal32(Some(n), precision, scale) => { | ||||||
| decimal_preimage_bounds::<Decimal32Type>(*n, *precision, *scale).map( | ||||||
| |(lo, hi)| { | ||||||
| ( | ||||||
| ScalarValue::Decimal32(Some(lo), *precision, *scale), | ||||||
| ScalarValue::Decimal32(Some(hi), *precision, *scale), | ||||||
| ) | ||||||
| }, | ||||||
| ) | ||||||
| } | ||||||
| ScalarValue::Decimal64(Some(n), precision, scale) => { | ||||||
| decimal_preimage_bounds::<Decimal64Type>(*n, *precision, *scale).map( | ||||||
| |(lo, hi)| { | ||||||
| ( | ||||||
| ScalarValue::Decimal64(Some(lo), *precision, *scale), | ||||||
| ScalarValue::Decimal64(Some(hi), *precision, *scale), | ||||||
| ) | ||||||
| }, | ||||||
| ) | ||||||
| } | ||||||
| ScalarValue::Decimal128(Some(n), precision, scale) => { | ||||||
| decimal_preimage_bounds::<Decimal128Type>(*n, *precision, *scale).map( | ||||||
| |(lo, hi)| { | ||||||
| ( | ||||||
| ScalarValue::Decimal128(Some(lo), *precision, *scale), | ||||||
| ScalarValue::Decimal128(Some(hi), *precision, *scale), | ||||||
| ) | ||||||
| }, | ||||||
| ) | ||||||
| } | ||||||
| ScalarValue::Decimal256(Some(n), precision, scale) => { | ||||||
| decimal_preimage_bounds::<Decimal256Type>(*n, *precision, *scale).map( | ||||||
| |(lo, hi)| { | ||||||
| ( | ||||||
| ScalarValue::Decimal256(Some(lo), *precision, *scale), | ||||||
| ScalarValue::Decimal256(Some(hi), *precision, *scale), | ||||||
| ) | ||||||
| }, | ||||||
| ) | ||||||
| } | ||||||
|
|
||||||
| // Unsupported types | ||||||
| _ => None, | ||||||
| }) else { | ||||||
|
|
@@ -310,9 +351,45 @@ fn int_preimage_bounds<I: CheckedAdd + One + Copy>(n: I) -> Option<(I, I)> { | |||||
| Some((n, upper)) | ||||||
| } | ||||||
|
|
||||||
| /// Compute preimage bounds for floor function on decimal types. | ||||||
| /// For floor(x) = n, the preimage is [n, n+1). | ||||||
| /// Returns None if: | ||||||
| /// - The value has a fractional part (floor always returns integers) | ||||||
| /// - Adding 1 would overflow | ||||||
| fn decimal_preimage_bounds<D: DecimalType>( | ||||||
| value: D::Native, | ||||||
| precision: u8, | ||||||
| scale: i8, | ||||||
| ) -> Option<(D::Native, D::Native)> | ||||||
| where | ||||||
| D::Native: DecimalCast + ArrowNativeTypeOp + std::ops::Rem<Output = D::Native>, | ||||||
| { | ||||||
| // Use rescale_decimal to compute "1" at target scale (avoids manual pow) | ||||||
| // Convert integer 1 (scale=0) to the target scale | ||||||
| let one_scaled: D::Native = rescale_decimal::<D, D>( | ||||||
| D::Native::ONE, // value = 1 | ||||||
| 1, // input_precision = 1 | ||||||
| 0, // input_scale = 0 (integer) | ||||||
| precision, // output_precision | ||||||
| scale, // output_scale | ||||||
| )?; | ||||||
|
|
||||||
| // floor always returns an integer, so if value has a fractional part, there's no solution | ||||||
| // Check: value % one_scaled != 0 means fractional part exists | ||||||
| if scale > 0 && value % one_scaled != D::Native::ZERO { | ||||||
| return None; | ||||||
| } | ||||||
|
|
||||||
| // Compute upper bound using checked addition | ||||||
| let upper = value.add_checked(one_scaled).ok()?; | ||||||
|
|
||||||
| Some((value, upper)) | ||||||
| } | ||||||
|
|
||||||
| #[cfg(test)] | ||||||
| mod tests { | ||||||
| use super::*; | ||||||
| use arrow_buffer::i256; | ||||||
| use datafusion_expr::col; | ||||||
|
|
||||||
| /// Helper to test valid preimage cases that should return a Range | ||||||
|
|
@@ -463,4 +540,240 @@ mod tests { | |||||
| "Expected None for zero args" | ||||||
| ); | ||||||
| } | ||||||
|
|
||||||
| // ============ Decimal Tests (mirror the float/int tests) ============ | ||||||
|
|
||||||
#[test]
fn test_floor_preimage_decimal_valid_cases() {
    // Every case is an integer-valued decimal literal `n`. The three values
    // handed to `assert_preimage_range` are the literal itself, the expected
    // lower bound (`n`) and the expected upper bound (`n + 1`), all as raw
    // values at the same precision and scale.

    // ===== Decimal32 (precision 9) =====
    // (scale, raw literal, raw upper bound)
    let decimal32_cases: [(i8, i32, i32); 5] = [
        (2, 10000, 10100), // 100.00 -> [100.00, 101.00)
        (2, 5000, 5100),   // 50.00 -> [50.00, 51.00)
        (2, -500, -400),   // -5.00 -> [-5.00, -4.00)
        (2, 0, 100),       // 0.00 -> [0.00, 1.00)
        (0, 42, 43),       // scale 0 (pure integer): 42 -> [42, 43)
    ];
    for (scale, raw, upper) in decimal32_cases {
        assert_preimage_range(
            ScalarValue::Decimal32(Some(raw), 9, scale),
            ScalarValue::Decimal32(Some(raw), 9, scale),
            ScalarValue::Decimal32(Some(upper), 9, scale),
        );
    }

    // The wider types all use scale 2 and the same three raw literals:
    // 100.00 (positive), -5.00 (negative) and 0.00 (zero).
    let scale_two_cases: [(i64, i64); 3] = [(10000, 10100), (-500, -400), (0, 100)];

    // ===== Decimal64 (precision 18) =====
    for (raw, upper) in scale_two_cases {
        assert_preimage_range(
            ScalarValue::Decimal64(Some(raw), 18, 2),
            ScalarValue::Decimal64(Some(raw), 18, 2),
            ScalarValue::Decimal64(Some(upper), 18, 2),
        );
    }

    // ===== Decimal128 (precision 38) =====
    for (raw, upper) in scale_two_cases {
        let (raw, upper) = (i128::from(raw), i128::from(upper));
        assert_preimage_range(
            ScalarValue::Decimal128(Some(raw), 38, 2),
            ScalarValue::Decimal128(Some(raw), 38, 2),
            ScalarValue::Decimal128(Some(upper), 38, 2),
        );
    }

    // ===== Decimal256 (precision 76) =====
    for (raw, upper) in scale_two_cases {
        let (raw, upper) = (i256::from(raw), i256::from(upper));
        assert_preimage_range(
            ScalarValue::Decimal256(Some(raw), 76, 2),
            ScalarValue::Decimal256(Some(raw), 76, 2),
            ScalarValue::Decimal256(Some(upper), 76, 2),
        );
    }
}
|
|
||||||
#[test]
fn test_floor_preimage_decimal_non_integer() {
    // floor always yields an integer, so an equation such as floor(x) = 1.30
    // has no solution. Every literal below carries a non-zero fractional part
    // and is handed to `assert_preimage_none`.
    let fractional_literals = [
        // Decimal32
        ScalarValue::Decimal32(Some(130), 9, 2),  // 1.30
        ScalarValue::Decimal32(Some(-250), 9, 2), // -2.50
        ScalarValue::Decimal32(Some(370), 9, 2),  // 3.70
        ScalarValue::Decimal32(Some(1), 9, 2),    // 0.01
        // Decimal64
        ScalarValue::Decimal64(Some(130), 18, 2),  // 1.30
        ScalarValue::Decimal64(Some(-250), 18, 2), // -2.50
        // Decimal128
        ScalarValue::Decimal128(Some(130), 38, 2),  // 1.30
        ScalarValue::Decimal128(Some(-250), 38, 2), // -2.50
        // Decimal256
        ScalarValue::Decimal256(Some(i256::from(130)), 76, 2),  // 1.30
        ScalarValue::Decimal256(Some(i256::from(-250)), 76, 2), // -2.50
    ];

    for literal in fractional_literals {
        assert_preimage_none(literal);
    }
}
|
|
||||||
| #[test] | ||||||
| fn test_floor_preimage_decimal_overflow() { | ||||||
| // Test near MAX where adding scale_factor would overflow | ||||||
|
|
||||||
| // Decimal32: i32::MAX | ||||||
| // For scale=2, we add 100, so i32::MAX - 50 would overflow | ||||||
| assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX - 50), 9, 2)); | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These scale=2 “overflow” cases (e.g. Other Locations
🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage.
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. value:useful; category:bug; feedback: The Augment AI reviewer is correct! i32::MAX - 50 is 2147483597, which has 10 digits and does not fit into precision=9 at all. So the decimal value is already out of range by construction, and floor() is never called at all. |
||||||
| // For scale=0, we add 1, so i32::MAX would overflow | ||||||
| assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX), 9, 0)); | ||||||
|
|
||||||
| // Decimal64: i64::MAX | ||||||
| assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX - 50), 18, 2)); | ||||||
| assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX), 18, 0)); | ||||||
|
|
||||||
| // Decimal128: i128::MAX | ||||||
| assert_preimage_none(ScalarValue::Decimal128(Some(i128::MAX - 50), 38, 2)); | ||||||
| assert_preimage_none(ScalarValue::Decimal128(Some(i128::MAX), 38, 0)); | ||||||
|
|
||||||
| // Decimal256: i256::MAX | ||||||
| assert_preimage_none(ScalarValue::Decimal256( | ||||||
| Some(i256::MAX.wrapping_sub(i256::from(50))), | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value:annoying; category:bug; feedback: The Gemini AI reviewer is not correct! There is no implementation of std::ops::Sub for i256 in Apache Arrow |
||||||
| 76, | ||||||
| 2, | ||||||
| )); | ||||||
| assert_preimage_none(ScalarValue::Decimal256(Some(i256::MAX), 76, 0)); | ||||||
| } | ||||||
|
|
||||||
| #[test] | ||||||
| fn test_floor_preimage_decimal_edge_cases() { | ||||||
| // ===== Decimal32 ===== | ||||||
| // Large value that doesn't overflow | ||||||
| // i32::MAX = 2147483647, with scale=2, max safe is around i32::MAX - 100 | ||||||
| let safe_max_32 = i32::MAX - 100; | ||||||
| // Make it divisible by 100 for scale=2 | ||||||
| let safe_max_aligned_32 = (safe_max_32 / 100) * 100; | ||||||
| assert_preimage_range( | ||||||
| ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2), | ||||||
| ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2), | ||||||
| ScalarValue::Decimal32(Some(safe_max_aligned_32 + 100), 9, 2), | ||||||
| ); | ||||||
|
|
||||||
| // Negative edge: i32::MIN should work since we're adding (not subtracting) | ||||||
| let min_aligned_32 = (i32::MIN / 100) * 100; | ||||||
| assert_preimage_range( | ||||||
| ScalarValue::Decimal32(Some(min_aligned_32), 9, 2), | ||||||
| ScalarValue::Decimal32(Some(min_aligned_32), 9, 2), | ||||||
| ScalarValue::Decimal32(Some(min_aligned_32 + 100), 9, 2), | ||||||
| ); | ||||||
|
|
||||||
| // ===== Decimal64 ===== | ||||||
| let safe_max_64 = i64::MAX - 100; | ||||||
| let safe_max_aligned_64 = (safe_max_64 / 100) * 100; | ||||||
| assert_preimage_range( | ||||||
| ScalarValue::Decimal64(Some(safe_max_aligned_64), 18, 2), | ||||||
| ScalarValue::Decimal64(Some(safe_max_aligned_64), 18, 2), | ||||||
| ScalarValue::Decimal64(Some(safe_max_aligned_64 + 100), 18, 2), | ||||||
| ); | ||||||
|
|
||||||
| let min_aligned_64 = (i64::MIN / 100) * 100; | ||||||
| assert_preimage_range( | ||||||
| ScalarValue::Decimal64(Some(min_aligned_64), 18, 2), | ||||||
| ScalarValue::Decimal64(Some(min_aligned_64), 18, 2), | ||||||
| ScalarValue::Decimal64(Some(min_aligned_64 + 100), 18, 2), | ||||||
| ); | ||||||
|
|
||||||
| // ===== Decimal128 ===== | ||||||
| let safe_max_128 = i128::MAX - 100; | ||||||
| let safe_max_aligned_128 = (safe_max_128 / 100) * 100; | ||||||
| assert_preimage_range( | ||||||
| ScalarValue::Decimal128(Some(safe_max_aligned_128), 38, 2), | ||||||
| ScalarValue::Decimal128(Some(safe_max_aligned_128), 38, 2), | ||||||
| ScalarValue::Decimal128(Some(safe_max_aligned_128 + 100), 38, 2), | ||||||
| ); | ||||||
|
|
||||||
| let min_aligned_128 = (i128::MIN / 100) * 100; | ||||||
| assert_preimage_range( | ||||||
| ScalarValue::Decimal128(Some(min_aligned_128), 38, 2), | ||||||
| ScalarValue::Decimal128(Some(min_aligned_128), 38, 2), | ||||||
| ScalarValue::Decimal128(Some(min_aligned_128 + 100), 38, 2), | ||||||
| ); | ||||||
|
|
||||||
| // ===== Decimal256 ===== | ||||||
| // For i256, we use smaller values since MAX is huge | ||||||
| let large_256 = i256::from(1_000_000_000_000i64); | ||||||
| assert_preimage_range( | ||||||
| ScalarValue::Decimal256(Some(large_256), 76, 2), | ||||||
| ScalarValue::Decimal256(Some(large_256), 76, 2), | ||||||
| ScalarValue::Decimal256(Some(large_256.wrapping_add(i256::from(100))), 76, 2), | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For clarity and consistency with
Suggested change
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value:annoying; category:bug; feedback: The Gemini AI reviewer is not correct! There is no implementation of std::ops::Add for i256 in Apache Arrow |
||||||
| ); | ||||||
|
|
||||||
| // Negative i256 | ||||||
| let neg_256 = i256::from(-1_000_000_000_000i64); | ||||||
| assert_preimage_range( | ||||||
| ScalarValue::Decimal256(Some(neg_256), 76, 2), | ||||||
| ScalarValue::Decimal256(Some(neg_256), 76, 2), | ||||||
| ScalarValue::Decimal256(Some(neg_256.wrapping_add(i256::from(100))), 76, 2), | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value:annoying; category:bug; feedback: The Gemini AI reviewer is not correct! There is no implementation of std::ops::Add for i256 in Apache Arrow |
||||||
| ); | ||||||
| } | ||||||
|
|
||||||
#[test]
fn test_floor_preimage_decimal_null() {
    // A NULL literal of each decimal width is handed to `assert_preimage_none`.
    let null_literals = [
        ScalarValue::Decimal32(None, 9, 2),
        ScalarValue::Decimal64(None, 18, 2),
        ScalarValue::Decimal128(None, 38, 2),
        ScalarValue::Decimal256(None, 76, 2),
    ];

    for literal in null_literals {
        assert_preimage_none(literal);
    }
}
| } | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These match arms for `Decimal32`, `Decimal64`, `Decimal128`, and `Decimal256` are very repetitive. This code duplication could be reduced to improve maintainability.
Consider defining a local macro within the `preimage` function to handle the repetitive logic. For example:
`macro_rules! handle_decimal { ($ty:ty, $n:expr, $precision:expr, $scale:expr, $variant:path) => { decimal_preimage_bounds::<$ty>(*$n, *$precision, *$scale).map(|(lo, hi)| ($variant(Some(lo), *$precision, *$scale), $variant(Some(hi), *$precision, *$scale))) }; }`
Then each match arm could be simplified to a single line, e.g.:
`ScalarValue::Decimal32(Some(n), precision, scale) => handle_decimal!(Decimal32Type, n, precision, scale, ScalarValue::Decimal32)`
Since adding the macro definition is outside the changed lines, I'm not providing a direct code suggestion, but this refactoring would make the code more concise.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
value:good-to-have; category:bug; feedback: The Gemini AI reviewer is correct! The match arms are identical and differ only in the type of the handled Decimal type. Prevents code duplication and double maintenance.