@@ -21,10 +21,15 @@ use crate::function::error_utils::{
2121use arrow:: array:: * ;
2222use arrow:: datatypes:: DataType ;
2323use arrow:: datatypes:: * ;
24+ use arrow:: error:: ArrowError ;
2425use datafusion_common:: { internal_err, DataFusionError , Result , ScalarValue } ;
2526use datafusion_expr:: {
2627 ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , Volatility ,
2728} ;
29+ use datafusion_functions:: {
30+ downcast_named_arg, make_abs_function, make_decimal_abs_function,
31+ make_try_abs_function,
32+ } ;
2833use std:: any:: Any ;
2934use std:: sync:: Arc ;
3035
@@ -113,36 +118,8 @@ impl ScalarUDFImpl for SparkAbs {
113118 }
114119}
115120
116- macro_rules! legacy_compute_op {
117- ( $ARRAY: expr, $FUNC: ident, $TYPE: ident, $RESULT: ident) => { {
118- let array = $ARRAY. as_any( ) . downcast_ref:: <$TYPE>( ) . unwrap( ) ;
119- let res: $RESULT = arrow:: compute:: kernels:: arity:: unary( array, |x| x. $FUNC( ) ) ;
120- res
121- } } ;
122- }
123-
124- macro_rules! ansi_compute_op {
125- ( $ARRAY: expr, $FUNC: ident, $TYPE: ident, $RESULT: ident, $MIN: expr, $FROM_TYPE: expr) => { {
126- let array = $ARRAY. as_any( ) . downcast_ref:: <$TYPE>( ) . unwrap( ) ;
127- match arrow:: compute:: kernels:: arity:: try_unary( array, |x| {
128- if x == $MIN {
129- Err ( arrow:: error:: ArrowError :: ArithmeticOverflow (
130- $FROM_TYPE. to_string( ) ,
131- ) )
132- } else {
133- Ok ( x. $FUNC( ) )
134- }
135- } ) {
136- Ok ( res) => Ok ( ColumnarValue :: Array ( Arc :: <PrimitiveArray <$RESULT>>:: new(
137- res,
138- ) ) ) ,
139- Err ( _) => Err ( arithmetic_overflow_error( $FROM_TYPE) ) ,
140- }
141- } } ;
142- }
143-
144121fn arithmetic_overflow_error ( from_type : & str ) -> DataFusionError {
145- DataFusionError :: Execution ( format ! ( "arithmetic overflow from {from_type}" ) )
122+ DataFusionError :: Execution ( format ! ( "overflow on abs {from_type}" ) )
146123}
147124
148125pub fn spark_abs ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue , DataFusionError > {
@@ -175,171 +152,83 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro
175152 | DataType :: UInt64 => Ok ( args[ 0 ] . clone ( ) ) ,
176153 DataType :: Int8 => {
177154 if !fail_on_error {
178- let result =
179- legacy_compute_op ! ( array, wrapping_abs, Int8Array , Int8Array ) ;
180- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
155+ let abs_fun = make_decimal_abs_function ! ( Int8Array ) ;
156+ abs_fun ( array) . map ( ColumnarValue :: Array )
181157 } else {
182- ansi_compute_op ! ( array, abs, Int8Array , Int8Type , i8 :: MIN , "Int8" )
158+ let abs_fun = make_try_abs_function ! ( Int8Array ) ;
159+ abs_fun ( array) . map ( ColumnarValue :: Array )
183160 }
184161 }
185162 DataType :: Int16 => {
186163 if !fail_on_error {
187- let result =
188- legacy_compute_op ! ( array, wrapping_abs, Int16Array , Int16Array ) ;
189- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
164+ let abs_fun = make_decimal_abs_function ! ( Int16Array ) ;
165+ abs_fun ( array) . map ( ColumnarValue :: Array )
190166 } else {
191- ansi_compute_op ! ( array, abs, Int16Array , Int16Type , i16 :: MIN , "Int16" )
167+ let abs_fun = make_try_abs_function ! ( Int16Array ) ;
168+ abs_fun ( array) . map ( ColumnarValue :: Array )
192169 }
193170 }
194171 DataType :: Int32 => {
195172 if !fail_on_error {
196- let result =
197- legacy_compute_op ! ( array, wrapping_abs, Int32Array , Int32Array ) ;
198- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
173+ let abs_fun = make_decimal_abs_function ! ( Int32Array ) ;
174+ abs_fun ( array) . map ( ColumnarValue :: Array )
199175 } else {
200- ansi_compute_op ! ( array, abs, Int32Array , Int32Type , i32 :: MIN , "Int32" )
176+ let abs_fun = make_try_abs_function ! ( Int32Array ) ;
177+ abs_fun ( array) . map ( ColumnarValue :: Array )
201178 }
202179 }
203180 DataType :: Int64 => {
204181 if !fail_on_error {
205- let result =
206- legacy_compute_op ! ( array, wrapping_abs, Int64Array , Int64Array ) ;
207- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
182+ let abs_fun = make_decimal_abs_function ! ( Int64Array ) ;
183+ abs_fun ( array) . map ( ColumnarValue :: Array )
208184 } else {
209- ansi_compute_op ! ( array, abs, Int64Array , Int64Type , i64 :: MIN , "Int64" )
185+ let abs_fun = make_try_abs_function ! ( Int64Array ) ;
186+ abs_fun ( array) . map ( ColumnarValue :: Array )
210187 }
211188 }
212189 DataType :: Float32 => {
213- let result = legacy_compute_op ! ( array , abs , Float32Array , Float32Array ) ;
214- Ok ( ColumnarValue :: Array ( Arc :: new ( result ) ) )
190+ let abs_fun = make_abs_function ! ( Float32Array ) ;
191+ abs_fun ( array ) . map ( ColumnarValue :: Array )
215192 }
216193 DataType :: Float64 => {
217- let result = legacy_compute_op ! ( array , abs , Float64Array , Float64Array ) ;
218- Ok ( ColumnarValue :: Array ( Arc :: new ( result ) ) )
194+ let abs_fun = make_abs_function ! ( Float64Array ) ;
195+ abs_fun ( array ) . map ( ColumnarValue :: Array )
219196 }
220- DataType :: Decimal128 ( precision , scale ) => {
197+ DataType :: Decimal128 ( _ , _ ) => {
221198 if !fail_on_error {
222- let result = legacy_compute_op ! (
223- array,
224- wrapping_abs,
225- Decimal128Array ,
226- Decimal128Array
227- ) ;
228- let result =
229- result. with_data_type ( DataType :: Decimal128 ( * precision, * scale) ) ;
230- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
199+ let abs_fun = make_decimal_abs_function ! ( Decimal128Array ) ;
200+ abs_fun ( array) . map ( ColumnarValue :: Array )
231201 } else {
232- // Need to pass precision and scale from input, so not using ansi_compute_op
233- let input = array. as_any ( ) . downcast_ref :: < Decimal128Array > ( ) ;
234- match input {
235- Some ( i) => {
236- match arrow:: compute:: kernels:: arity:: try_unary ( i, |x| {
237- if x == i128:: MIN {
238- Err ( arrow:: error:: ArrowError :: ArithmeticOverflow (
239- "Decimal128" . to_string ( ) ,
240- ) )
241- } else {
242- Ok ( x. abs ( ) )
243- }
244- } ) {
245- Ok ( res) => Ok ( ColumnarValue :: Array ( Arc :: <
246- PrimitiveArray < Decimal128Type > ,
247- > :: new (
248- res. with_data_type ( DataType :: Decimal128 (
249- * precision, * scale,
250- ) ) ,
251- ) ) ) ,
252- Err ( _) => Err ( arithmetic_overflow_error ( "Decimal128" ) ) ,
253- }
254- }
255- _ => Err ( DataFusionError :: Internal (
256- "Invalid data type" . to_string ( ) ,
257- ) ) ,
258- }
202+ let abs_fun = make_try_abs_function ! ( Decimal128Array ) ;
203+ abs_fun ( array) . map ( ColumnarValue :: Array )
259204 }
260205 }
261- DataType :: Decimal256 ( precision , scale ) => {
206+ DataType :: Decimal256 ( _ , _ ) => {
262207 if !fail_on_error {
263- let result = legacy_compute_op ! (
264- array,
265- wrapping_abs,
266- Decimal256Array ,
267- Decimal256Array
268- ) ;
269- let result =
270- result. with_data_type ( DataType :: Decimal256 ( * precision, * scale) ) ;
271- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
208+ let abs_fun = make_decimal_abs_function ! ( Decimal256Array ) ;
209+ abs_fun ( array) . map ( ColumnarValue :: Array )
272210 } else {
273- // Need to pass precision and scale from input, so not using ansi_compute_op
274- let input = array. as_any ( ) . downcast_ref :: < Decimal256Array > ( ) ;
275- match input {
276- Some ( i) => {
277- match arrow:: compute:: kernels:: arity:: try_unary ( i, |x| {
278- if x == i256:: MIN {
279- Err ( arrow:: error:: ArrowError :: ArithmeticOverflow (
280- "Decimal256" . to_string ( ) ,
281- ) )
282- } else {
283- Ok ( x. wrapping_abs ( ) ) // i256 doesn't define abs() method
284- }
285- } ) {
286- Ok ( res) => Ok ( ColumnarValue :: Array ( Arc :: <
287- PrimitiveArray < Decimal256Type > ,
288- > :: new (
289- res. with_data_type ( DataType :: Decimal256 (
290- * precision, * scale,
291- ) ) ,
292- ) ) ) ,
293- Err ( _) => Err ( arithmetic_overflow_error ( "Decimal256" ) ) ,
294- }
295- }
296- _ => Err ( DataFusionError :: Internal (
297- "Invalid data type" . to_string ( ) ,
298- ) ) ,
299- }
211+ let abs_fun = make_try_abs_function ! ( Decimal256Array ) ;
212+ abs_fun ( array) . map ( ColumnarValue :: Array )
300213 }
301214 }
302215 DataType :: Interval ( unit) => match unit {
303216 IntervalUnit :: YearMonth => {
304217 if !fail_on_error {
305- let result = legacy_compute_op ! (
306- array,
307- wrapping_abs,
308- IntervalYearMonthArray ,
309- IntervalYearMonthArray
310- ) ;
311- let result = result. with_data_type ( DataType :: Interval ( * unit) ) ;
312- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
218+ let abs_fun = make_decimal_abs_function ! ( IntervalYearMonthArray ) ;
219+ abs_fun ( array) . map ( ColumnarValue :: Array )
313220 } else {
314- ansi_compute_op ! (
315- array,
316- abs,
317- IntervalYearMonthArray ,
318- IntervalYearMonthType ,
319- i32 :: MIN ,
320- "IntervalYearMonth"
321- )
221+ let abs_fun = make_try_abs_function ! ( IntervalYearMonthArray ) ;
222+ abs_fun ( array) . map ( ColumnarValue :: Array )
322223 }
323224 }
324225 IntervalUnit :: DayTime => {
325226 if !fail_on_error {
326- let result = legacy_compute_op ! (
327- array,
328- wrapping_abs,
329- IntervalDayTimeArray ,
330- IntervalDayTimeArray
331- ) ;
332- let result = result. with_data_type ( DataType :: Interval ( * unit) ) ;
333- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
227+ let abs_fun = make_decimal_abs_function ! ( IntervalDayTimeArray ) ;
228+ abs_fun ( array) . map ( ColumnarValue :: Array )
334229 } else {
335- ansi_compute_op ! (
336- array,
337- wrapping_abs,
338- IntervalDayTimeArray ,
339- IntervalDayTimeType ,
340- IntervalDayTime :: MIN ,
341- "IntervalDayTime"
342- )
230+ let abs_fun = make_try_abs_function ! ( IntervalDayTimeArray ) ;
231+ abs_fun ( array) . map ( ColumnarValue :: Array )
343232 }
344233 }
345234 IntervalUnit :: MonthDayNano => internal_err ! (
@@ -630,7 +519,7 @@ mod tests {
630519 match spark_abs( & [ args, fail_on_error] ) {
631520 Err ( e) => {
632521 assert!(
633- e. to_string( ) . contains( "arithmetic overflow" ) ,
522+ e. to_string( ) . contains( "overflow on abs " ) ,
634523 "Error message did not match. Actual message: {e}"
635524 ) ;
636525 }
@@ -654,7 +543,7 @@ mod tests {
654543 match spark_abs( & [ args, fail_on_error] ) {
655544 Err ( e) => {
656545 assert!(
657- e. to_string( ) . contains( "arithmetic overflow" ) ,
546+ e. to_string( ) . contains( "overflow on abs " ) ,
658547 "Error message did not match. Actual message: {e}"
659548 ) ;
660549 }
@@ -858,7 +747,7 @@ mod tests {
858747 match spark_abs( & [ args, fail_on_error] ) {
859748 Err ( e) => {
860749 assert!(
861- e. to_string( ) . contains( "arithmetic overflow" ) ,
750+ e. to_string( ) . contains( "overflow on abs " ) ,
862751 "Error message did not match. Actual message: {e}"
863752 ) ;
864753 }
0 commit comments