diff --git a/datafusion/functions/src/math/floor.rs b/datafusion/functions/src/math/floor.rs index 7c7604b7fd88..f4c85032ca51 100644 --- a/datafusion/functions/src/math/floor.rs +++ b/datafusion/functions/src/math/floor.rs @@ -19,9 +19,10 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray}; +use arrow::compute::{DecimalCast, rescale_decimal}; use arrow::datatypes::{ - DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float32Type, - Float64Type, + ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, + Decimal256Type, DecimalType, Float32Type, Float64Type, }; use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::interval_arithmetic::Interval; @@ -230,8 +231,6 @@ impl ScalarUDFImpl for FloorFunc { // Compute lower bound (N) and upper bound (N + 1) using helper functions let Some((lower, upper)) = (match lit_value { - // Decimal types should be supported and tracked in - // https://github.com/apache/datafusion/issues/20080 // Floating-point types ScalarValue::Float64(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| { ( @@ -260,6 +259,48 @@ impl ScalarUDFImpl for FloorFunc { (ScalarValue::Int64(Some(lo)), ScalarValue::Int64(Some(hi))) }), + // Decimal types + ScalarValue::Decimal32(Some(n), precision, scale) => { + decimal_preimage_bounds::(*n, *precision, *scale).map( + |(lo, hi)| { + ( + ScalarValue::Decimal32(Some(lo), *precision, *scale), + ScalarValue::Decimal32(Some(hi), *precision, *scale), + ) + }, + ) + } + ScalarValue::Decimal64(Some(n), precision, scale) => { + decimal_preimage_bounds::(*n, *precision, *scale).map( + |(lo, hi)| { + ( + ScalarValue::Decimal64(Some(lo), *precision, *scale), + ScalarValue::Decimal64(Some(hi), *precision, *scale), + ) + }, + ) + } + ScalarValue::Decimal128(Some(n), precision, scale) => { + decimal_preimage_bounds::(*n, *precision, *scale).map( + |(lo, hi)| { + ( + ScalarValue::Decimal128(Some(lo), *precision, *scale), + ScalarValue::Decimal128(Some(hi), *precision, *scale), + ) + }, + ) + } + ScalarValue::Decimal256(Some(n), precision, scale) => { + decimal_preimage_bounds::(*n, *precision, *scale).map( + |(lo, hi)| { + ( + ScalarValue::Decimal256(Some(lo), *precision, *scale), + ScalarValue::Decimal256(Some(hi), *precision, *scale), + ) + }, + ) + } + // Unsupported types _ => None, }) else { @@ -310,9 +351,45 @@ fn int_preimage_bounds(n: I) -> Option<(I, I)> { Some((n, upper)) } +/// Compute preimage bounds for floor function on decimal types. +/// For floor(x) = n, the preimage is [n, n+1). +/// Returns None if: +/// - The value has a fractional part (floor always returns integers) +/// - Adding 1 would overflow +fn decimal_preimage_bounds( + value: D::Native, + precision: u8, + scale: i8, +) -> Option<(D::Native, D::Native)> +where + D::Native: DecimalCast + ArrowNativeTypeOp + std::ops::Rem, +{ + // Use rescale_decimal to compute "1" at target scale (avoids manual pow) + // Convert integer 1 (scale=0) to the target scale + let one_scaled: D::Native = rescale_decimal::( + D::Native::ONE, // value = 1 + 1, // input_precision = 1 + 0, // input_scale = 0 (integer) + precision, // output_precision + scale, // output_scale + )?; + + // floor always returns an integer, so if value has a fractional part, there's no solution + // Check: value % one_scaled != 0 means fractional part exists + if scale > 0 && value % one_scaled != D::Native::ZERO { + return None; + } + + // Compute upper bound using checked addition + let upper = value.add_checked(one_scaled).ok()?; + + Some((value, upper)) +} + #[cfg(test)] mod tests { use super::*; + use arrow_buffer::i256; use datafusion_expr::col; /// Helper to test valid preimage cases that should return a Range @@ -463,4 +540,240 @@ mod tests { "Expected None for zero args" ); } + + // ============ Decimal32 Tests (mirrors float/int tests) ============ + + #[test] + fn test_floor_preimage_decimal_valid_cases() { + // ===== Decimal32 ===== + // Positive integer decimal: 100.00 (scale=2, so raw=10000) + // floor(x) = 100.00 -> x in [100.00, 101.00) + assert_preimage_range( + ScalarValue::Decimal32(Some(10000), 9, 2), + ScalarValue::Decimal32(Some(10000), 9, 2), // 100.00 + ScalarValue::Decimal32(Some(10100), 9, 2), // 101.00 + ); + + // Smaller positive: 50.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(5000), 9, 2), + ScalarValue::Decimal32(Some(5000), 9, 2), // 50.00 + ScalarValue::Decimal32(Some(5100), 9, 2), // 51.00 + ); + + // Negative integer decimal: -5.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(-500), 9, 2), + ScalarValue::Decimal32(Some(-500), 9, 2), // -5.00 + ScalarValue::Decimal32(Some(-400), 9, 2), // -4.00 + ); + + // Zero: 0.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(0), 9, 2), + ScalarValue::Decimal32(Some(0), 9, 2), // 0.00 + ScalarValue::Decimal32(Some(100), 9, 2), // 1.00 + ); + + // Scale 0 (pure integer): 42 + assert_preimage_range( + ScalarValue::Decimal32(Some(42), 9, 0), + ScalarValue::Decimal32(Some(42), 9, 0), + ScalarValue::Decimal32(Some(43), 9, 0), + ); + + // ===== Decimal64 ===== + assert_preimage_range( + ScalarValue::Decimal64(Some(10000), 18, 2), + ScalarValue::Decimal64(Some(10000), 18, 2), // 100.00 + ScalarValue::Decimal64(Some(10100), 18, 2), // 101.00 + ); + + // Negative + assert_preimage_range( + ScalarValue::Decimal64(Some(-500), 18, 2), + ScalarValue::Decimal64(Some(-500), 18, 2), // -5.00 + ScalarValue::Decimal64(Some(-400), 18, 2), // -4.00 + ); + + // Zero + assert_preimage_range( + ScalarValue::Decimal64(Some(0), 18, 2), + ScalarValue::Decimal64(Some(0), 18, 2), + ScalarValue::Decimal64(Some(100), 18, 2), + ); + + // ===== Decimal128 ===== + assert_preimage_range( + ScalarValue::Decimal128(Some(10000), 38, 2), + ScalarValue::Decimal128(Some(10000), 38, 2), // 100.00 + ScalarValue::Decimal128(Some(10100), 38, 2), // 101.00 + ); + + // Negative + assert_preimage_range( + ScalarValue::Decimal128(Some(-500), 38, 2), + ScalarValue::Decimal128(Some(-500), 38, 2), // -5.00 + ScalarValue::Decimal128(Some(-400), 38, 2), // -4.00 + ); + + // Zero + assert_preimage_range( + ScalarValue::Decimal128(Some(0), 38, 2), + ScalarValue::Decimal128(Some(0), 38, 2), + ScalarValue::Decimal128(Some(100), 38, 2), + ); + + // ===== Decimal256 ===== + assert_preimage_range( + ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2), + ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2), // 100.00 + ScalarValue::Decimal256(Some(i256::from(10100)), 76, 2), // 101.00 + ); + + // Negative + assert_preimage_range( + ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2), + ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2), // -5.00 + ScalarValue::Decimal256(Some(i256::from(-400)), 76, 2), // -4.00 + ); + + // Zero + assert_preimage_range( + ScalarValue::Decimal256(Some(i256::ZERO), 76, 2), + ScalarValue::Decimal256(Some(i256::ZERO), 76, 2), + ScalarValue::Decimal256(Some(i256::from(100)), 76, 2), + ); + } + + #[test] + fn test_floor_preimage_decimal_non_integer() { + // floor(x) = 1.30 has NO SOLUTION because floor always returns an integer + // Therefore preimage should return None for non-integer decimals + + // Decimal32 + assert_preimage_none(ScalarValue::Decimal32(Some(130), 9, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal32(Some(-250), 9, 2)); // -2.50 + assert_preimage_none(ScalarValue::Decimal32(Some(370), 9, 2)); // 3.70 + assert_preimage_none(ScalarValue::Decimal32(Some(1), 9, 2)); // 0.01 + + // Decimal64 + assert_preimage_none(ScalarValue::Decimal64(Some(130), 18, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal64(Some(-250), 18, 2)); // -2.50 + + // Decimal128 + assert_preimage_none(ScalarValue::Decimal128(Some(130), 38, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal128(Some(-250), 38, 2)); // -2.50 + + // Decimal256 + assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(130)), 76, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(-250)), 76, 2)); // -2.50 + } + + #[test] + fn test_floor_preimage_decimal_overflow() { + // Test near MAX where adding scale_factor would overflow + + // Decimal32: i32::MAX + // For scale=2, we add 100, so i32::MAX - 50 would overflow + assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX - 50), 9, 2)); + // For scale=0, we add 1, so i32::MAX would overflow + assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX), 9, 0)); + + // Decimal64: i64::MAX + assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX - 50), 18, 2)); + assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX), 18, 0)); + + // Decimal128: i128::MAX + assert_preimage_none(ScalarValue::Decimal128(Some(i128::MAX - 50), 38, 2)); + assert_preimage_none(ScalarValue::Decimal128(Some(i128::MAX), 38, 0)); + + // Decimal256: i256::MAX + assert_preimage_none(ScalarValue::Decimal256( + Some(i256::MAX.wrapping_sub(i256::from(50))), + 76, + 2, + )); + assert_preimage_none(ScalarValue::Decimal256(Some(i256::MAX), 76, 0)); + } + + #[test] + fn test_floor_preimage_decimal_edge_cases() { + // ===== Decimal32 ===== + // Large value that doesn't overflow + // i32::MAX = 2147483647, with scale=2, max safe is around i32::MAX - 100 + let safe_max_32 = i32::MAX - 100; + // Make it divisible by 100 for scale=2 + let safe_max_aligned_32 = (safe_max_32 / 100) * 100; + assert_preimage_range( + ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(safe_max_aligned_32 + 100), 9, 2), + ); + + // Negative edge: i32::MIN should work since we're adding (not subtracting) + let min_aligned_32 = (i32::MIN / 100) * 100; + assert_preimage_range( + ScalarValue::Decimal32(Some(min_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(min_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(min_aligned_32 + 100), 9, 2), + ); + + // ===== Decimal64 ===== + let safe_max_64 = i64::MAX - 100; + let safe_max_aligned_64 = (safe_max_64 / 100) * 100; + assert_preimage_range( + ScalarValue::Decimal64(Some(safe_max_aligned_64), 18, 2), + ScalarValue::Decimal64(Some(safe_max_aligned_64), 18, 2), + ScalarValue::Decimal64(Some(safe_max_aligned_64 + 100), 18, 2), + ); + + let min_aligned_64 = (i64::MIN / 100) * 100; + assert_preimage_range( + ScalarValue::Decimal64(Some(min_aligned_64), 18, 2), + ScalarValue::Decimal64(Some(min_aligned_64), 18, 2), + ScalarValue::Decimal64(Some(min_aligned_64 + 100), 18, 2), + ); + + // ===== Decimal128 ===== + let safe_max_128 = i128::MAX - 100; + let safe_max_aligned_128 = (safe_max_128 / 100) * 100; + assert_preimage_range( + ScalarValue::Decimal128(Some(safe_max_aligned_128), 38, 2), + ScalarValue::Decimal128(Some(safe_max_aligned_128), 38, 2), + ScalarValue::Decimal128(Some(safe_max_aligned_128 + 100), 38, 2), + ); + + let min_aligned_128 = (i128::MIN / 100) * 100; + assert_preimage_range( + ScalarValue::Decimal128(Some(min_aligned_128), 38, 2), + ScalarValue::Decimal128(Some(min_aligned_128), 38, 2), + ScalarValue::Decimal128(Some(min_aligned_128 + 100), 38, 2), + ); + + // ===== Decimal256 ===== + // For i256, we use smaller values since MAX is huge + let large_256 = i256::from(1_000_000_000_000i64); + assert_preimage_range( + ScalarValue::Decimal256(Some(large_256), 76, 2), + ScalarValue::Decimal256(Some(large_256), 76, 2), + ScalarValue::Decimal256(Some(large_256.wrapping_add(i256::from(100))), 76, 2), + ); + + // Negative i256 + let neg_256 = i256::from(-1_000_000_000_000i64); + assert_preimage_range( + ScalarValue::Decimal256(Some(neg_256), 76, 2), + ScalarValue::Decimal256(Some(neg_256), 76, 2), + ScalarValue::Decimal256(Some(neg_256.wrapping_add(i256::from(100))), 76, 2), + ); + } + + #[test] + fn test_floor_preimage_decimal_null() { + assert_preimage_none(ScalarValue::Decimal32(None, 9, 2)); + assert_preimage_none(ScalarValue::Decimal64(None, 18, 2)); + assert_preimage_none(ScalarValue::Decimal128(None, 38, 2)); + assert_preimage_none(ScalarValue::Decimal256(None, 76, 2)); + } } diff --git a/datafusion/sqllogictest/test_files/floor_preimage.slt b/datafusion/sqllogictest/test_files/floor_preimage.slt new file mode 100644 index 000000000000..67840d362335 --- /dev/null +++ b/datafusion/sqllogictest/test_files/floor_preimage.slt @@ -0,0 +1,272 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Floor Preimage Tests +## +## Tests for floor function preimage optimization: +## floor(col) = N transforms to col >= N AND col < N + 1 +## +## Uses representative types only (Float64, Int32, Decimal128). +## Unit tests cover all type variants. +########## + +# Setup: Single table with representative types +statement ok +CREATE TABLE test_data ( + id INT, + float_val DOUBLE, + int_val INT, + decimal_val DECIMAL(10,2) +) AS VALUES + (1, 5.3, 100, 100.00), + (2, 5.7, 101, 100.50), + (3, 6.0, 102, 101.00), + (4, 6.5, -5, 101.99), + (5, 7.0, 0, 102.00), + (6, NULL, NULL, NULL); + +########## +## Data Correctness Tests +########## + +# Float64: floor(x) = 5 matches values in [5.0, 6.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) = arrow_cast(5, 'Float64'); +---- +1 +2 + +# Int32: floor(x) = 100 matches values in [100, 101) +query I rowsort +SELECT id FROM test_data WHERE floor(int_val) = 100; +---- +1 + +# Decimal128: floor(x) = 100 matches values in [100.00, 101.00) +query I rowsort +SELECT id FROM test_data WHERE floor(decimal_val) = arrow_cast(100, 'Decimal128(10,2)'); +---- +1 +2 + +# Negative value: floor(x) = -5 matches values in [-5, -4) +query I rowsort +SELECT id FROM test_data WHERE floor(int_val) = -5; +---- +4 + +# Zero value: floor(x) = 0 matches values in [0, 1) +query I rowsort +SELECT id FROM test_data WHERE floor(int_val) = 0; +---- +5 + +# Column on RHS (same result as LHS) +query I rowsort +SELECT id FROM test_data WHERE arrow_cast(5, 'Float64') = floor(float_val); +---- +1 +2 + +# IS NOT DISTINCT FROM (excludes NULLs) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) IS NOT DISTINCT FROM arrow_cast(5, 'Float64'); +---- +1 +2 + +# IS DISTINCT FROM (includes NULLs) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) IS DISTINCT FROM arrow_cast(5, 'Float64'); +---- +3 +4 +5 +6 + +# Non-integer literal (empty result - floor returns integers) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) = arrow_cast(5.5, 'Float64'); +---- + +########## +## EXPLAIN Tests - Plan Optimization +########## + +statement ok +set datafusion.explain.logical_plan_only = true; + +# 1. Basic: Float64 - floor(col) = N transforms to col >= N AND col < N+1 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) = arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(5) AND test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 2. Basic: Int32 - transformed (coerced to Float64) +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(int_val) = 100; +---- +logical_plan +01)Projection: test_data.id, test_data.float_val, test_data.int_val, test_data.decimal_val +02)--Filter: __common_expr_3 >= Float64(100) AND __common_expr_3 < Float64(101) +03)----Projection: CAST(test_data.int_val AS Float64) AS __common_expr_3, test_data.id, test_data.float_val, test_data.int_val, test_data.decimal_val +04)------TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 3. Basic: Decimal128 - same transformation +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(decimal_val) = arrow_cast(100, 'Decimal128(10,2)'); +---- +logical_plan +01)Filter: test_data.decimal_val >= Decimal128(Some(10000),10,2) AND test_data.decimal_val < Decimal128(Some(10100),10,2) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 4. Column on RHS - same transformation +query TT +EXPLAIN SELECT * FROM test_data WHERE arrow_cast(5, 'Float64') = floor(float_val); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(5) AND test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 5. IS NOT DISTINCT FROM - adds IS NOT NULL +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) IS NOT DISTINCT FROM arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val IS NOT NULL AND test_data.float_val >= Float64(5) AND test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 6. IS DISTINCT FROM - includes NULL check +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) IS DISTINCT FROM arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val < Float64(5) OR test_data.float_val >= Float64(6) OR test_data.float_val IS NULL +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 7. Non-optimizable: non-integer literal (original predicate preserved) +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) = arrow_cast(5.5, 'Float64'); +---- +logical_plan +01)Filter: floor(test_data.float_val) = Float64(5.5) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +########## +## Other Comparison Operators +## +## The preimage framework automatically handles all comparison operators: +## floor(x) <> N -> x < N OR x >= N+1 +## floor(x) > N -> x >= N+1 +## floor(x) < N -> x < N +## floor(x) >= N -> x >= N +## floor(x) <= N -> x < N+1 +########## + +# Data correctness tests for other operators + +# Not equals: floor(x) <> 5 matches values outside [5.0, 6.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) <> arrow_cast(5, 'Float64'); +---- +3 +4 +5 + +# Greater than: floor(x) > 5 matches values in [6.0, inf) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) > arrow_cast(5, 'Float64'); +---- +3 +4 +5 + +# Less than: floor(x) < 6 matches values in (-inf, 6.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) < arrow_cast(6, 'Float64'); +---- +1 +2 + +# Greater than or equal: floor(x) >= 5 matches values in [5.0, inf) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) >= arrow_cast(5, 'Float64'); +---- +1 +2 +3 +4 +5 + +# Less than or equal: floor(x) <= 5 matches values in (-inf, 6.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) <= arrow_cast(5, 'Float64'); +---- +1 +2 + +# EXPLAIN tests showing optimized transformations + +# Not equals: floor(x) <> 5 -> x < 5 OR x >= 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) <> arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val < Float64(5) OR test_data.float_val >= Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Greater than: floor(x) > 5 -> x >= 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) > arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Less than: floor(x) < 6 -> x < 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) < arrow_cast(6, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Greater than or equal: floor(x) >= 5 -> x >= 5 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) >= arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(5) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Less than or equal: floor(x) <= 5 -> x < 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) <= arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +########## +## Cleanup +########## + +statement ok +DROP TABLE test_data;