Skip to content

Commit 4dda5d8

Browse files
committed
fix: Used NumCast rather than manual
1 parent e3807b4 commit 4dda5d8

File tree

1 file changed

+107
-70
lines changed

1 file changed

+107
-70
lines changed

datafusion/functions/src/math/power.rs

Lines changed: 107 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
2424
use arrow::array::{Array, ArrayRef};
2525
use arrow::datatypes::{
2626
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
27-
Decimal256Type, Float64Type, Int64Type,
27+
Decimal256Type, Float64Type, Int64Type, i256,
2828
};
2929
use arrow::error::ArrowError;
3030
use datafusion_common::types::{NativeType, logical_float64, logical_int64};
@@ -115,7 +115,7 @@ impl PowerFunc {
115115
/// Scale it back to 1: 390625 / 10^4 = 39
116116
fn pow_decimal_int<T>(base: T, scale: i8, exp: i64) -> Result<T, ArrowError>
117117
where
118-
T: From<i32> + ArrowNativeTypeOp + ToPrimitive,
118+
T: From<i32> + ArrowNativeTypeOp + ToPrimitive + num_traits::NumCast,
119119
{
120120
if exp < 0 {
121121
return pow_decimal_float(base, scale, exp as f64);
@@ -128,7 +128,7 @@ where
128128
})?;
129129
if exp == 0 {
130130
// Edge case to provide 1 as result (10^s with scale)
131-
let result: T = T::from(10).pow_checked(scale).map_err(|_| {
131+
let result: T = <T as From<i32>>::from(10).pow_checked(scale).map_err(|_| {
132132
ArrowError::ArithmeticOverflow(format!(
133133
"Cannot make unscale factor for {scale} and {exp}"
134134
))
@@ -141,18 +141,20 @@ where
141141
let powered: T = base.pow_checked(exp).map_err(|_| {
142142
ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}"))
143143
})?;
144-
let unscale_factor: T = T::from(10).pow_checked(scale * (exp - 1)).map_err(|_| {
145-
ArrowError::ArithmeticOverflow(format!(
146-
"Cannot make unscale factor for {scale} and {exp}"
147-
))
148-
})?;
144+
let unscale_factor: T = <T as From<i32>>::from(10)
145+
.pow_checked(scale * (exp - 1))
146+
.map_err(|_| {
147+
ArrowError::ArithmeticOverflow(format!(
148+
"Cannot make unscale factor for {scale} and {exp}"
149+
))
150+
})?;
149151

150152
powered.div_checked(unscale_factor)
151153
}
152154

153155
fn pow_decimal_float<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
154156
where
155-
T: From<i32> + ArrowNativeTypeOp + ToPrimitive,
157+
T: From<i32> + ArrowNativeTypeOp + ToPrimitive + num_traits::NumCast,
156158
{
157159
if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 {
158160
return pow_decimal_int(base, scale, exp as i64);
@@ -167,23 +169,65 @@ where
167169
pow_decimal_float_fallback(base, scale, exp)
168170
}
169171

170-
/// Fallback implementation using f64 for negative or non-integer exponents.
171-
/// This handles cases that cannot be computed using integer arithmetic.
172-
fn pow_decimal_float_fallback<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
173-
where
174-
T: From<i32> + ArrowNativeTypeOp + ToPrimitive,
175-
{
176-
if scale < 0 {
177-
return Err(ArrowError::NotYetImplemented(format!(
178-
"Negative scale is not yet supported: {scale}"
172+
/// Decimal256 specialized float exponent version.
173+
fn pow_decimal256_float(base: i256, scale: i8, exp: f64) -> Result<i256, ArrowError> {
174+
if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 {
175+
return pow_decimal256_int(base, scale, exp as i64);
176+
}
177+
178+
if !exp.is_finite() {
179+
return Err(ArrowError::ComputeError(format!(
180+
"Cannot use non-finite exp: {exp}"
179181
)));
180182
}
181183

182-
let scale_factor = 10f64.powi(scale as i32);
183-
let base_f64 = base.to_f64().ok_or_else(|| {
184-
ArrowError::ComputeError("Cannot convert base to f64".to_string())
185-
})? / scale_factor;
184+
pow_decimal256_float_fallback(base, scale, exp)
185+
}
186+
187+
/// Decimal256 specialized integer exponent version.
188+
fn pow_decimal256_int(base: i256, scale: i8, exp: i64) -> Result<i256, ArrowError> {
189+
if exp < 0 {
190+
return pow_decimal256_float(base, scale, exp as f64);
191+
}
186192

193+
let scale: u32 = scale.try_into().map_err(|_| {
194+
ArrowError::NotYetImplemented(format!(
195+
"Negative scale is not yet supported value: {scale}"
196+
))
197+
})?;
198+
if exp == 0 {
199+
let result: i256 = i256::from_i128(10).pow_checked(scale).map_err(|_| {
200+
ArrowError::ArithmeticOverflow(format!(
201+
"Cannot make unscale factor for {scale} and {exp}"
202+
))
203+
})?;
204+
return Ok(result);
205+
}
206+
let exp: u32 = exp.try_into().map_err(|_| {
207+
ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
208+
})?;
209+
let powered: i256 = base.pow_checked(exp).map_err(|_| {
210+
ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}"))
211+
})?;
212+
let unscale_factor: i256 = i256::from_i128(10)
213+
.pow_checked(scale * (exp - 1))
214+
.map_err(|_| {
215+
ArrowError::ArithmeticOverflow(format!(
216+
"Cannot make unscale factor for {scale} and {exp}"
217+
))
218+
})?;
219+
220+
powered.div_checked(unscale_factor)
221+
}
222+
223+
/// Compute the f64 power result and scale it back.
224+
/// Returns the rounded i128 result for conversion to target type.
225+
#[inline]
226+
fn compute_pow_f64_result(
227+
base_f64: f64,
228+
scale: i8,
229+
exp: f64,
230+
) -> Result<i128, ArrowError> {
187231
let result_f64 = base_f64.powf(exp);
188232

189233
if !result_f64.is_finite() {
@@ -192,6 +236,7 @@ where
192236
)));
193237
}
194238

239+
let scale_factor = 10f64.powi(scale as i32);
195240
let result_scaled = result_f64 * scale_factor;
196241
let result_rounded = result_scaled.round();
197242

@@ -201,64 +246,56 @@ where
201246
)));
202247
}
203248

204-
decimal_from_i128::<T>(result_rounded as i128)
249+
Ok(result_rounded as i128)
205250
}
206251

207-
fn decimal_from_i128<T>(value: i128) -> Result<T, ArrowError>
252+
/// Fallback implementation using f64 for negative or non-integer exponents.
253+
/// This handles cases that cannot be computed using integer arithmetic.
254+
fn pow_decimal_float_fallback<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
208255
where
209-
T: From<i32> + ArrowNativeTypeOp,
256+
T: ToPrimitive + num_traits::NumCast + Copy,
210257
{
211-
if value == 0 {
212-
return Ok(T::from(0));
213-
}
214-
215-
if value >= i32::MIN as i128 && value <= i32::MAX as i128 {
216-
return Ok(T::from(value as i32));
258+
if scale < 0 {
259+
return Err(ArrowError::NotYetImplemented(format!(
260+
"Negative scale is not yet supported: {scale}"
261+
)));
217262
}
218263

219-
let is_negative = value < 0;
220-
let abs_value = value.unsigned_abs();
221-
222-
let billion = 1_000_000_000u128;
223-
let mut result = T::from(0);
224-
let mut multiplier = T::from(1);
225-
let billion_t = T::from(1_000_000_000);
264+
let scale_factor = 10f64.powi(scale as i32);
265+
let base_f64 = base.to_f64().ok_or_else(|| {
266+
ArrowError::ComputeError("Cannot convert base to f64".to_string())
267+
})? / scale_factor;
226268

227-
let mut remaining = abs_value;
228-
while remaining > 0 {
229-
let chunk = (remaining % billion) as i32;
230-
remaining /= billion;
269+
let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
231270

232-
let chunk_value = T::from(chunk).mul_checked(multiplier).map_err(|_| {
233-
ArrowError::ArithmeticOverflow(format!(
234-
"Overflow while converting {value} to decimal type"
235-
))
236-
})?;
237-
238-
result = result.add_checked(chunk_value).map_err(|_| {
239-
ArrowError::ArithmeticOverflow(format!(
240-
"Overflow while converting {value} to decimal type"
241-
))
242-
})?;
271+
num_traits::NumCast::from(result_i128).ok_or_else(|| {
272+
ArrowError::ArithmeticOverflow(format!(
273+
"Value {result_i128} is too large for the target decimal type"
274+
))
275+
})
276+
}
243277

244-
if remaining > 0 {
245-
multiplier = multiplier.mul_checked(billion_t).map_err(|_| {
246-
ArrowError::ArithmeticOverflow(format!(
247-
"Overflow while converting {value} to decimal type"
248-
))
249-
})?;
250-
}
278+
/// Fallback implementation for Decimal256.
279+
fn pow_decimal256_float_fallback(
280+
base: i256,
281+
scale: i8,
282+
exp: f64,
283+
) -> Result<i256, ArrowError> {
284+
if scale < 0 {
285+
return Err(ArrowError::NotYetImplemented(format!(
286+
"Negative scale is not yet supported: {scale}"
287+
)));
251288
}
252289

253-
if is_negative {
254-
result = T::from(0).sub_checked(result).map_err(|_| {
255-
ArrowError::ArithmeticOverflow(format!(
256-
"Overflow while negating {value} in decimal type"
257-
))
258-
})?;
259-
}
290+
let scale_factor = 10f64.powi(scale as i32);
291+
let base_f64 = base.to_f64().ok_or_else(|| {
292+
ArrowError::ComputeError("Cannot convert base to f64".to_string())
293+
})? / scale_factor;
294+
295+
let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
260296

261-
Ok(result)
297+
// i256 can be constructed from i128 directly
298+
Ok(i256::from_i128(result_i128))
262299
}
263300

264301
impl ScalarUDFImpl for PowerFunc {
@@ -381,7 +418,7 @@ impl ScalarUDFImpl for PowerFunc {
381418
>(
382419
&base,
383420
exponent,
384-
|b, e| pow_decimal_int(b, *scale, e),
421+
|b, e| pow_decimal256_int(b, *scale, e),
385422
*precision,
386423
*scale,
387424
)?
@@ -395,7 +432,7 @@ impl ScalarUDFImpl for PowerFunc {
395432
>(
396433
&base,
397434
exponent,
398-
|b, e| pow_decimal_float(b, *scale, e),
435+
|b, e| pow_decimal256_float(b, *scale, e),
399436
*precision,
400437
*scale,
401438
)?

0 commit comments

Comments
 (0)