Skip to content

Commit 098f1e4

Browse files
committed
Reuse marcos from DF's abs function
1 parent 4434f97 commit 098f1e4

File tree

2 files changed

+51
-159
lines changed
  • datafusion

2 files changed

+51
-159
lines changed

datafusion/functions/src/math/abs.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ use num_traits::sign::Signed;
3939

4040
type MathArrayFunction = fn(&ArrayRef) -> Result<ArrayRef>;
4141

42+
#[macro_export]
4243
macro_rules! make_abs_function {
4344
($ARRAY_TYPE:ident) => {{
4445
|input: &ArrayRef| {
@@ -49,14 +50,15 @@ macro_rules! make_abs_function {
4950
}};
5051
}
5152

53+
#[macro_export]
5254
macro_rules! make_try_abs_function {
5355
($ARRAY_TYPE:ident) => {{
5456
|input: &ArrayRef| {
5557
let array = downcast_named_arg!(&input, "abs arg", $ARRAY_TYPE);
5658
let res: $ARRAY_TYPE = array.try_unary(|x| {
5759
x.checked_abs().ok_or_else(|| {
5860
ArrowError::ComputeError(format!(
59-
"{} overflow on abs({})",
61+
"{} overflow on abs({:?})",
6062
stringify!($ARRAY_TYPE),
6163
x
6264
))
@@ -67,6 +69,7 @@ macro_rules! make_try_abs_function {
6769
}};
6870
}
6971

72+
#[macro_export]
7073
macro_rules! make_decimal_abs_function {
7174
($ARRAY_TYPE:ident) => {{
7275
|input: &ArrayRef| {

datafusion/spark/src/function/math/abs.rs

Lines changed: 47 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,15 @@ use crate::function::error_utils::{
2121
use arrow::array::*;
2222
use arrow::datatypes::DataType;
2323
use arrow::datatypes::*;
24+
use arrow::error::ArrowError;
2425
use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
2526
use datafusion_expr::{
2627
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
2728
};
29+
use datafusion_functions::{
30+
downcast_named_arg, make_abs_function, make_decimal_abs_function,
31+
make_try_abs_function,
32+
};
2833
use std::any::Any;
2934
use std::sync::Arc;
3035

@@ -113,36 +118,8 @@ impl ScalarUDFImpl for SparkAbs {
113118
}
114119
}
115120

116-
macro_rules! legacy_compute_op {
117-
($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{
118-
let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap();
119-
let res: $RESULT = arrow::compute::kernels::arity::unary(array, |x| x.$FUNC());
120-
res
121-
}};
122-
}
123-
124-
macro_rules! ansi_compute_op {
125-
($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident, $MIN:expr, $FROM_TYPE:expr) => {{
126-
let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap();
127-
match arrow::compute::kernels::arity::try_unary(array, |x| {
128-
if x == $MIN {
129-
Err(arrow::error::ArrowError::ArithmeticOverflow(
130-
$FROM_TYPE.to_string(),
131-
))
132-
} else {
133-
Ok(x.$FUNC())
134-
}
135-
}) {
136-
Ok(res) => Ok(ColumnarValue::Array(Arc::<PrimitiveArray<$RESULT>>::new(
137-
res,
138-
))),
139-
Err(_) => Err(arithmetic_overflow_error($FROM_TYPE)),
140-
}
141-
}};
142-
}
143-
144121
fn arithmetic_overflow_error(from_type: &str) -> DataFusionError {
145-
DataFusionError::Execution(format!("arithmetic overflow from {from_type}"))
122+
DataFusionError::Execution(format!("overflow on abs {from_type}"))
146123
}
147124

148125
pub fn spark_abs(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
@@ -175,171 +152,83 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro
175152
| DataType::UInt64 => Ok(args[0].clone()),
176153
DataType::Int8 => {
177154
if !fail_on_error {
178-
let result =
179-
legacy_compute_op!(array, wrapping_abs, Int8Array, Int8Array);
180-
Ok(ColumnarValue::Array(Arc::new(result)))
155+
let abs_fun = make_decimal_abs_function!(Int8Array);
156+
abs_fun(array).map(ColumnarValue::Array)
181157
} else {
182-
ansi_compute_op!(array, abs, Int8Array, Int8Type, i8::MIN, "Int8")
158+
let abs_fun = make_try_abs_function!(Int8Array);
159+
abs_fun(array).map(ColumnarValue::Array)
183160
}
184161
}
185162
DataType::Int16 => {
186163
if !fail_on_error {
187-
let result =
188-
legacy_compute_op!(array, wrapping_abs, Int16Array, Int16Array);
189-
Ok(ColumnarValue::Array(Arc::new(result)))
164+
let abs_fun = make_decimal_abs_function!(Int16Array);
165+
abs_fun(array).map(ColumnarValue::Array)
190166
} else {
191-
ansi_compute_op!(array, abs, Int16Array, Int16Type, i16::MIN, "Int16")
167+
let abs_fun = make_try_abs_function!(Int16Array);
168+
abs_fun(array).map(ColumnarValue::Array)
192169
}
193170
}
194171
DataType::Int32 => {
195172
if !fail_on_error {
196-
let result =
197-
legacy_compute_op!(array, wrapping_abs, Int32Array, Int32Array);
198-
Ok(ColumnarValue::Array(Arc::new(result)))
173+
let abs_fun = make_decimal_abs_function!(Int32Array);
174+
abs_fun(array).map(ColumnarValue::Array)
199175
} else {
200-
ansi_compute_op!(array, abs, Int32Array, Int32Type, i32::MIN, "Int32")
176+
let abs_fun = make_try_abs_function!(Int32Array);
177+
abs_fun(array).map(ColumnarValue::Array)
201178
}
202179
}
203180
DataType::Int64 => {
204181
if !fail_on_error {
205-
let result =
206-
legacy_compute_op!(array, wrapping_abs, Int64Array, Int64Array);
207-
Ok(ColumnarValue::Array(Arc::new(result)))
182+
let abs_fun = make_decimal_abs_function!(Int64Array);
183+
abs_fun(array).map(ColumnarValue::Array)
208184
} else {
209-
ansi_compute_op!(array, abs, Int64Array, Int64Type, i64::MIN, "Int64")
185+
let abs_fun = make_try_abs_function!(Int64Array);
186+
abs_fun(array).map(ColumnarValue::Array)
210187
}
211188
}
212189
DataType::Float32 => {
213-
let result = legacy_compute_op!(array, abs, Float32Array, Float32Array);
214-
Ok(ColumnarValue::Array(Arc::new(result)))
190+
let abs_fun = make_abs_function!(Float32Array);
191+
abs_fun(array).map(ColumnarValue::Array)
215192
}
216193
DataType::Float64 => {
217-
let result = legacy_compute_op!(array, abs, Float64Array, Float64Array);
218-
Ok(ColumnarValue::Array(Arc::new(result)))
194+
let abs_fun = make_abs_function!(Float64Array);
195+
abs_fun(array).map(ColumnarValue::Array)
219196
}
220-
DataType::Decimal128(precision, scale) => {
197+
DataType::Decimal128(_, _) => {
221198
if !fail_on_error {
222-
let result = legacy_compute_op!(
223-
array,
224-
wrapping_abs,
225-
Decimal128Array,
226-
Decimal128Array
227-
);
228-
let result =
229-
result.with_data_type(DataType::Decimal128(*precision, *scale));
230-
Ok(ColumnarValue::Array(Arc::new(result)))
199+
let abs_fun = make_decimal_abs_function!(Decimal128Array);
200+
abs_fun(array).map(ColumnarValue::Array)
231201
} else {
232-
// Need to pass precision and scale from input, so not using ansi_compute_op
233-
let input = array.as_any().downcast_ref::<Decimal128Array>();
234-
match input {
235-
Some(i) => {
236-
match arrow::compute::kernels::arity::try_unary(i, |x| {
237-
if x == i128::MIN {
238-
Err(arrow::error::ArrowError::ArithmeticOverflow(
239-
"Decimal128".to_string(),
240-
))
241-
} else {
242-
Ok(x.abs())
243-
}
244-
}) {
245-
Ok(res) => Ok(ColumnarValue::Array(Arc::<
246-
PrimitiveArray<Decimal128Type>,
247-
>::new(
248-
res.with_data_type(DataType::Decimal128(
249-
*precision, *scale,
250-
)),
251-
))),
252-
Err(_) => Err(arithmetic_overflow_error("Decimal128")),
253-
}
254-
}
255-
_ => Err(DataFusionError::Internal(
256-
"Invalid data type".to_string(),
257-
)),
258-
}
202+
let abs_fun = make_try_abs_function!(Decimal128Array);
203+
abs_fun(array).map(ColumnarValue::Array)
259204
}
260205
}
261-
DataType::Decimal256(precision, scale) => {
206+
DataType::Decimal256(_, _) => {
262207
if !fail_on_error {
263-
let result = legacy_compute_op!(
264-
array,
265-
wrapping_abs,
266-
Decimal256Array,
267-
Decimal256Array
268-
);
269-
let result =
270-
result.with_data_type(DataType::Decimal256(*precision, *scale));
271-
Ok(ColumnarValue::Array(Arc::new(result)))
208+
let abs_fun = make_decimal_abs_function!(Decimal256Array);
209+
abs_fun(array).map(ColumnarValue::Array)
272210
} else {
273-
// Need to pass precision and scale from input, so not using ansi_compute_op
274-
let input = array.as_any().downcast_ref::<Decimal256Array>();
275-
match input {
276-
Some(i) => {
277-
match arrow::compute::kernels::arity::try_unary(i, |x| {
278-
if x == i256::MIN {
279-
Err(arrow::error::ArrowError::ArithmeticOverflow(
280-
"Decimal256".to_string(),
281-
))
282-
} else {
283-
Ok(x.wrapping_abs()) // i256 doesn't define abs() method
284-
}
285-
}) {
286-
Ok(res) => Ok(ColumnarValue::Array(Arc::<
287-
PrimitiveArray<Decimal256Type>,
288-
>::new(
289-
res.with_data_type(DataType::Decimal256(
290-
*precision, *scale,
291-
)),
292-
))),
293-
Err(_) => Err(arithmetic_overflow_error("Decimal256")),
294-
}
295-
}
296-
_ => Err(DataFusionError::Internal(
297-
"Invalid data type".to_string(),
298-
)),
299-
}
211+
let abs_fun = make_try_abs_function!(Decimal256Array);
212+
abs_fun(array).map(ColumnarValue::Array)
300213
}
301214
}
302215
DataType::Interval(unit) => match unit {
303216
IntervalUnit::YearMonth => {
304217
if !fail_on_error {
305-
let result = legacy_compute_op!(
306-
array,
307-
wrapping_abs,
308-
IntervalYearMonthArray,
309-
IntervalYearMonthArray
310-
);
311-
let result = result.with_data_type(DataType::Interval(*unit));
312-
Ok(ColumnarValue::Array(Arc::new(result)))
218+
let abs_fun = make_decimal_abs_function!(IntervalYearMonthArray);
219+
abs_fun(array).map(ColumnarValue::Array)
313220
} else {
314-
ansi_compute_op!(
315-
array,
316-
abs,
317-
IntervalYearMonthArray,
318-
IntervalYearMonthType,
319-
i32::MIN,
320-
"IntervalYearMonth"
321-
)
221+
let abs_fun = make_try_abs_function!(IntervalYearMonthArray);
222+
abs_fun(array).map(ColumnarValue::Array)
322223
}
323224
}
324225
IntervalUnit::DayTime => {
325226
if !fail_on_error {
326-
let result = legacy_compute_op!(
327-
array,
328-
wrapping_abs,
329-
IntervalDayTimeArray,
330-
IntervalDayTimeArray
331-
);
332-
let result = result.with_data_type(DataType::Interval(*unit));
333-
Ok(ColumnarValue::Array(Arc::new(result)))
227+
let abs_fun = make_decimal_abs_function!(IntervalDayTimeArray);
228+
abs_fun(array).map(ColumnarValue::Array)
334229
} else {
335-
ansi_compute_op!(
336-
array,
337-
wrapping_abs,
338-
IntervalDayTimeArray,
339-
IntervalDayTimeType,
340-
IntervalDayTime::MIN,
341-
"IntervalDayTime"
342-
)
230+
let abs_fun = make_try_abs_function!(IntervalDayTimeArray);
231+
abs_fun(array).map(ColumnarValue::Array)
343232
}
344233
}
345234
IntervalUnit::MonthDayNano => internal_err!(
@@ -630,7 +519,7 @@ mod tests {
630519
match spark_abs(&[args, fail_on_error]) {
631520
Err(e) => {
632521
assert!(
633-
e.to_string().contains("arithmetic overflow"),
522+
e.to_string().contains("overflow on abs"),
634523
"Error message did not match. Actual message: {e}"
635524
);
636525
}
@@ -654,7 +543,7 @@ mod tests {
654543
match spark_abs(&[args, fail_on_error]) {
655544
Err(e) => {
656545
assert!(
657-
e.to_string().contains("arithmetic overflow"),
546+
e.to_string().contains("overflow on abs"),
658547
"Error message did not match. Actual message: {e}"
659548
);
660549
}
@@ -858,7 +747,7 @@ mod tests {
858747
match spark_abs(&[args, fail_on_error]) {
859748
Err(e) => {
860749
assert!(
861-
e.to_string().contains("arithmetic overflow"),
750+
e.to_string().contains("overflow on abs"),
862751
"Error message did not match. Actual message: {e}"
863752
);
864753
}

0 commit comments

Comments
 (0)