Skip to content

Commit d0193a1

Browse files
committed
feat: Allow pow with negative & non-integer exponent on decimals
1 parent 59dcc36 commit d0193a1

File tree

2 files changed

+160
-16
lines changed

2 files changed

+160
-16
lines changed

datafusion/functions/src/math/power.rs

Lines changed: 140 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,14 @@ impl PowerFunc {
112112
/// 2.5 is represented as 25 with scale 1
113113
/// The unscaled result is 25^4 = 390625
114114
/// Scale it back to 1: 390625 / 10^4 = 39
115-
///
116-
/// Returns error if base is invalid
117115
fn pow_decimal_int<T>(base: T, scale: i8, exp: i64) -> Result<T, ArrowError>
118116
where
119117
T: From<i32> + ArrowNativeTypeOp,
120118
{
119+
if exp < 0 {
120+
return pow_decimal_float(base, scale, exp as f64);
121+
}
122+
121123
let scale: u32 = scale.try_into().map_err(|_| {
122124
ArrowError::NotYetImplemented(format!(
123125
"Negative scale is not yet supported value: {scale}"
@@ -149,22 +151,118 @@ where
149151

150152
/// Binary function to calculate a math power to float exponent
151153
/// for scaled integer types.
152-
/// Returns error if exponent is negative or non-integer, or base invalid
153154
fn pow_decimal_float<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
154155
where
155156
T: From<i32> + ArrowNativeTypeOp,
156157
{
157-
if !exp.is_finite() || exp.trunc() != exp {
158+
if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 {
159+
return pow_decimal_int(base, scale, exp as i64);
160+
}
161+
162+
if !exp.is_finite() {
158163
return Err(ArrowError::ComputeError(format!(
159-
"Cannot use non-integer exp: {exp}"
164+
"Cannot use non-finite exp: {exp}"
165+
)));
166+
}
167+
168+
pow_decimal_float_fallback(base, scale, exp)
169+
}
170+
171+
/// Fallback implementation using f64 for negative or non-integer exponents.
172+
/// This handles cases that cannot be computed using integer arithmetic.
173+
fn pow_decimal_float_fallback<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
174+
where
175+
T: From<i32> + ArrowNativeTypeOp,
176+
{
177+
if scale < 0 {
178+
return Err(ArrowError::NotYetImplemented(format!(
179+
"Negative scale is not yet supported: {scale}"
180+
)));
181+
}
182+
183+
let scale_factor = 10f64.powi(scale as i32);
184+
let base_f64 = format!("{base:?}")
185+
.parse::<f64>()
186+
.map(|v| v / scale_factor)
187+
.map_err(|_| {
188+
ArrowError::ComputeError(format!("Cannot convert base {base:?} to f64"))
189+
})?;
190+
191+
let result_f64 = base_f64.powf(exp);
192+
193+
if !result_f64.is_finite() {
194+
return Err(ArrowError::ArithmeticOverflow(format!(
195+
"Result of {base_f64}^{exp} is not finite"
160196
)));
161197
}
162-
if exp < 0f64 || exp >= u32::MAX as f64 {
198+
199+
let result_scaled = result_f64 * scale_factor;
200+
let result_rounded = result_scaled.round();
201+
202+
if result_rounded.abs() > i128::MAX as f64 {
163203
return Err(ArrowError::ArithmeticOverflow(format!(
164-
"Unsupported exp value: {exp}"
204+
"Result {result_rounded} is too large for the target decimal type"
165205
)));
166206
}
167-
pow_decimal_int(base, scale, exp as i64)
207+
208+
decimal_from_i128::<T>(result_rounded as i128)
209+
}
210+
211+
fn decimal_from_i128<T>(value: i128) -> Result<T, ArrowError>
212+
where
213+
T: From<i32> + ArrowNativeTypeOp,
214+
{
215+
if value == 0 {
216+
return Ok(T::from(0));
217+
}
218+
219+
if value >= i32::MIN as i128 && value <= i32::MAX as i128 {
220+
return Ok(T::from(value as i32));
221+
}
222+
223+
let is_negative = value < 0;
224+
let abs_value = value.unsigned_abs();
225+
226+
let billion = 1_000_000_000u128;
227+
let mut result = T::from(0);
228+
let mut multiplier = T::from(1);
229+
let billion_t = T::from(1_000_000_000);
230+
231+
let mut remaining = abs_value;
232+
while remaining > 0 {
233+
let chunk = (remaining % billion) as i32;
234+
remaining /= billion;
235+
236+
let chunk_value = T::from(chunk).mul_checked(multiplier).map_err(|_| {
237+
ArrowError::ArithmeticOverflow(format!(
238+
"Overflow while converting {value} to decimal type"
239+
))
240+
})?;
241+
242+
result = result.add_checked(chunk_value).map_err(|_| {
243+
ArrowError::ArithmeticOverflow(format!(
244+
"Overflow while converting {value} to decimal type"
245+
))
246+
})?;
247+
248+
if remaining > 0 {
249+
multiplier = multiplier.mul_checked(billion_t).map_err(|_| {
250+
ArrowError::ArithmeticOverflow(format!(
251+
"Overflow while converting {value} to decimal type"
252+
))
253+
})?;
254+
}
255+
}
256+
257+
if is_negative {
258+
result = T::from(0).sub_checked(result).map_err(|_| {
259+
ArrowError::ArithmeticOverflow(format!(
260+
"Overflow while negating {value} in decimal type"
261+
))
262+
})?;
263+
}
264+
265+
Ok(result)
168266
}
169267

170268
impl ScalarUDFImpl for PowerFunc {
@@ -392,4 +490,38 @@ mod tests {
392490
"Not yet implemented: Negative scale is not yet supported value: -1"
393491
);
394492
}
493+
494+
#[test]
495+
fn test_pow_decimal_float_fallback() {
496+
// Test negative exponent: 4^(-1) = 0.25
497+
// 4 with scale 2 = 400, result should be 25 (0.25 with scale 2)
498+
let result: i128 = pow_decimal_float(400i128, 2, -1.0).unwrap();
499+
assert_eq!(result, 25);
500+
501+
// Test non-integer exponent: 4^0.5 = 2
502+
// 4 with scale 2 = 400, result should be 200 (2.0 with scale 2)
503+
let result: i128 = pow_decimal_float(400i128, 2, 0.5).unwrap();
504+
assert_eq!(result, 200);
505+
506+
// Test 8^(1/3) = 2 (cube root)
507+
// 8 with scale 1 = 80, result should be 20 (2.0 with scale 1)
508+
let result: i128 = pow_decimal_float(80i128, 1, 1.0 / 3.0).unwrap();
509+
assert_eq!(result, 20);
510+
511+
// Test negative base with integer exponent still works
512+
// (-2)^3 = -8
513+
// -2 with scale 1 = -20, result should be -80 (-8.0 with scale 1)
514+
let result: i128 = pow_decimal_float(-20i128, 1, 3.0).unwrap();
515+
assert_eq!(result, -80);
516+
517+
// Test positive integer exponent goes through fast path
518+
// 2.5^4 = 39.0625
519+
// 25 with scale 1, result should be 390 (39.0 with scale 1) - truncated
520+
let result: i128 = pow_decimal_float(25i128, 1, 4.0).unwrap();
521+
assert_eq!(result, 390); // Uses integer path
522+
523+
// Test non-finite exponent returns error
524+
assert!(pow_decimal_float(100i128, 2, f64::NAN).is_err());
525+
assert!(pow_decimal_float(100i128, 2, f64::INFINITY).is_err());
526+
}
395527
}

datafusion/sqllogictest/test_files/decimal.slt

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -954,8 +954,17 @@ SELECT power(2, 100000000000)
954954
----
955955
Infinity
956956

957-
query error Arrow error: Arithmetic overflow: Unsupported exp value
958-
SELECT power(2::decimal(38, 0), -5)
957+
# Negative exponent now works (fallback to f64)
958+
query RT
959+
SELECT power(2::decimal(38, 0), -5), arrow_typeof(power(2::decimal(38, 0), -5));
960+
----
961+
0 Decimal128(38, 0)
962+
963+
# Negative exponent with scale preserves decimal places
964+
query RT
965+
SELECT power(4::decimal(38, 5), -1), arrow_typeof(power(4::decimal(38, 5), -1));
966+
----
967+
0.25 Decimal128(38, 5)
959968

960969
# Expected to have `16 Decimal128(38, 0)`
961970
# Due to type coericion, it becomes Float -> Float -> Float
@@ -975,20 +984,23 @@ SELECT power(2.5, 4.0), arrow_typeof(power(2.5, 4.0));
975984
----
976985
39 Decimal128(2, 1)
977986

978-
query error Compute error: Cannot use non-integer exp
987+
# Non-integer exponent now works (fallback to f64)
988+
query RT
979989
SELECT power(2.5, 4.2), arrow_typeof(power(2.5, 4.2));
990+
----
991+
46.9 Decimal128(2, 1)
980992

981-
query error Compute error: Cannot use non-integer exp: NaN
993+
query error Compute error: Cannot use non-finite exp: NaN
982994
SELECT power(2::decimal(38, 0), arrow_cast('NaN','Float64'))
983995

984-
query error Compute error: Cannot use non-integer exp: inf
996+
query error Compute error: Cannot use non-finite exp: inf
985997
SELECT power(2::decimal(38, 0), arrow_cast('INF','Float64'))
986998

987-
# Floating above u32::max
988-
query error Compute error: Cannot use non-integer exp
999+
# Floating above u32::max now works (fallback to f64, returns infinity which is an error)
1000+
query error Arrow error: Arithmetic overflow: Result of 2\^5000000000.1 is not finite
9891001
SELECT power(2::decimal(38, 0), 5000000000.1)
9901002

991-
# Integer Above u32::max
1003+
# Integer Above u32::max - still goes through integer path which fails
9921004
query error Arrow error: Arithmetic overflow: Unsupported exp value
9931005
SELECT power(2::decimal(38, 0), 5000000000)
9941006

0 commit comments

Comments
 (0)