diff --git a/Cargo.lock b/Cargo.lock
index 1675f26e8a0f0..c4a742020a0ae 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2685,6 +2685,7 @@ dependencies = [
  "datafusion-execution",
  "datafusion-expr",
  "datafusion-functions",
+ "datafusion-functions-aggregate",
  "datafusion-functions-nested",
  "log",
  "percent-encoding",
diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs
index 198ba54adfa2a..9e4e140b8788b 100644
--- a/datafusion/functions-aggregate/src/sum.rs
+++ b/datafusion/functions-aggregate/src/sum.rs
@@ -19,19 +19,21 @@
 use ahash::RandomState;
 use arrow::array::{Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, AsArray};
-use arrow::datatypes::Field;
 use arrow::datatypes::{
     ArrowNativeType, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
     DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Decimal32Type,
-    Decimal64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType,
-    DurationMillisecondType, DurationNanosecondType, DurationSecondType, FieldRef,
+    Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, DurationMicrosecondType,
+    DurationMillisecondType, DurationNanosecondType, DurationSecondType, Field, FieldRef,
     Float64Type, Int64Type, TimeUnit, UInt64Type,
 };
+use arrow::error::ArrowError;
 use datafusion_common::types::{
     NativeType, logical_float64, logical_int8, logical_int16, logical_int32,
     logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
 };
-use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
+use datafusion_common::{
+    HashMap, Result, ScalarValue, exec_err, internal_err, not_impl_err,
+};
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
 use datafusion_expr::{
@@ -41,6 +43,7 @@ use datafusion_expr::{
 };
 use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
 use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
+use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator;
 use datafusion_macros::user_doc;
 use std::any::Any;
 use std::mem::size_of_val;
@@ -143,50 +146,64 @@ macro_rules! downcast_sum {
 #[derive(Debug, PartialEq, Eq, Hash)]
 pub struct Sum {
     signature: Signature,
+    // If true then returns null on overflows
+    try_sum_mode: bool,
 }
 
 impl Sum {
+    fn signature() -> Signature {
+        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
+        // smallint, int, bigint, real, double precision, decimal, or interval.
+        Signature::one_of(
+            vec![
+                TypeSignature::Coercible(vec![Coercion::new_exact(
+                    TypeSignatureClass::Decimal,
+                )]),
+                // Unsigned to u64
+                TypeSignature::Coercible(vec![Coercion::new_implicit(
+                    TypeSignatureClass::Native(logical_uint64()),
+                    vec![
+                        TypeSignatureClass::Native(logical_uint8()),
+                        TypeSignatureClass::Native(logical_uint16()),
+                        TypeSignatureClass::Native(logical_uint32()),
+                    ],
+                    NativeType::UInt64,
+                )]),
+                // Signed to i64
+                TypeSignature::Coercible(vec![Coercion::new_implicit(
+                    TypeSignatureClass::Native(logical_int64()),
+                    vec![
+                        TypeSignatureClass::Native(logical_int8()),
+                        TypeSignatureClass::Native(logical_int16()),
+                        TypeSignatureClass::Native(logical_int32()),
+                    ],
+                    NativeType::Int64,
+                )]),
+                // Floats to f64
+                TypeSignature::Coercible(vec![Coercion::new_implicit(
+                    TypeSignatureClass::Native(logical_float64()),
+                    vec![TypeSignatureClass::Float],
+                    NativeType::Float64,
+                )]),
+                TypeSignature::Coercible(vec![Coercion::new_exact(
+                    TypeSignatureClass::Duration,
+                )]),
+            ],
+            Volatility::Immutable,
+        )
+    }
+
     pub fn new() -> Self {
         Self {
-            // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
-            // smallint, int, bigint, real, double precision, decimal, or interval.
-            signature: Signature::one_of(
-                vec![
-                    TypeSignature::Coercible(vec![Coercion::new_exact(
-                        TypeSignatureClass::Decimal,
-                    )]),
-                    // Unsigned to u64
-                    TypeSignature::Coercible(vec![Coercion::new_implicit(
-                        TypeSignatureClass::Native(logical_uint64()),
-                        vec![
-                            TypeSignatureClass::Native(logical_uint8()),
-                            TypeSignatureClass::Native(logical_uint16()),
-                            TypeSignatureClass::Native(logical_uint32()),
-                        ],
-                        NativeType::UInt64,
-                    )]),
-                    // Signed to i64
-                    TypeSignature::Coercible(vec![Coercion::new_implicit(
-                        TypeSignatureClass::Native(logical_int64()),
-                        vec![
-                            TypeSignatureClass::Native(logical_int8()),
-                            TypeSignatureClass::Native(logical_int16()),
-                            TypeSignatureClass::Native(logical_int32()),
-                        ],
-                        NativeType::Int64,
-                    )]),
-                    // Floats to f64
-                    TypeSignature::Coercible(vec![Coercion::new_implicit(
-                        TypeSignatureClass::Native(logical_float64()),
-                        vec![TypeSignatureClass::Float],
-                        NativeType::Float64,
-                    )]),
-                    TypeSignature::Coercible(vec![Coercion::new_exact(
-                        TypeSignatureClass::Duration,
-                    )]),
-                ],
-                Volatility::Immutable,
-            ),
+            signature: Self::signature(),
+            try_sum_mode: false,
+        }
+    }
+
+    pub fn try_sum() -> Self {
+        Self {
+            signature: Self::signature(),
+            try_sum_mode: true,
         }
     }
 }
@@ -212,9 +229,7 @@ impl AggregateUDFImpl for Sum {
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
         match &arg_types[0] {
-            DataType::Int64 => Ok(DataType::Int64),
-            DataType::UInt64 => Ok(DataType::UInt64),
-            DataType::Float64 => Ok(DataType::Float64),
+            DataType::Null => Ok(DataType::Float64),
             // In the spark, the result type is DECIMAL(min(38,precision+10), s)
             // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
             DataType::Decimal32(precision, scale) => {
@@ -233,56 +248,126 @@ impl AggregateUDFImpl for Sum {
                 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
                 Ok(DataType::Decimal256(new_precision, *scale))
             }
-            DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
-            other => {
-                exec_err!("[return_type] SUM not supported for {}", other)
-            }
+            dt => Ok(dt.clone()),
         }
     }
 
     fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
-        if args.is_distinct {
-            macro_rules! helper {
-                ($t:ty, $dt:expr) => {
-                    Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
-                };
+        if args.expr_fields[0].data_type() == &DataType::Null {
+            return Ok(Box::new(NoopAccumulator::new(ScalarValue::Float64(None))));
+        }
+        match (args.is_distinct, self.try_sum_mode) {
+            (true, false) => {
+                macro_rules! helper {
+                    ($t:ty, $dt:expr) => {
+                        Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
+                    };
+                }
+                downcast_sum!(args, helper)
             }
-            downcast_sum!(args, helper)
-        } else {
-            macro_rules! helper {
-                ($t:ty, $dt:expr) => {
-                    Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
-                };
+            (false, false) => {
+                macro_rules! helper {
+                    ($t:ty, $dt:expr) => {
+                        Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
+                    };
+                }
+                downcast_sum!(args, helper)
+            }
+            (false, true) => {
+                match args.return_type() {
+                    DataType::UInt64 => Ok(Box::new(
+                        TrySumAccumulator::<UInt64Type>::new(DataType::UInt64),
+                    )),
+                    DataType::Int64 => Ok(Box::new(TrySumAccumulator::<Int64Type>::new(
+                        DataType::Int64,
+                    ))),
+                    DataType::Float64 => Ok(Box::new(
+                        TrySumAccumulator::<Float64Type>::new(DataType::Float64),
+                    )),
+                    DataType::Duration(TimeUnit::Second) => {
+                        Ok(Box::new(TrySumAccumulator::<DurationSecondType>::new(
+                            DataType::Duration(TimeUnit::Second),
+                        )))
+                    }
+                    DataType::Duration(TimeUnit::Millisecond) => {
+                        Ok(Box::new(TrySumAccumulator::<DurationMillisecondType>::new(
+                            DataType::Duration(TimeUnit::Millisecond),
+                        )))
+                    }
+                    DataType::Duration(TimeUnit::Microsecond) => {
+                        Ok(Box::new(TrySumAccumulator::<DurationMicrosecondType>::new(
+                            DataType::Duration(TimeUnit::Microsecond),
+                        )))
+                    }
+                    DataType::Duration(TimeUnit::Nanosecond) => {
+                        Ok(Box::new(TrySumAccumulator::<DurationNanosecondType>::new(
+                            DataType::Duration(TimeUnit::Nanosecond),
+                        )))
+                    }
+                    dt @ DataType::Decimal32(..) => Ok(Box::new(
+                        TrySumDecimalAccumulator::<Decimal32Type>::new(dt.clone()),
+                    )),
+                    dt @ DataType::Decimal64(..) => Ok(Box::new(
+                        TrySumDecimalAccumulator::<Decimal64Type>::new(dt.clone()),
+                    )),
+                    dt @ DataType::Decimal128(..) => Ok(Box::new(
+                        TrySumDecimalAccumulator::<Decimal128Type>::new(dt.clone()),
+                    )),
+                    dt @ DataType::Decimal256(..) => Ok(Box::new(
+                        TrySumDecimalAccumulator::<Decimal256Type>::new(dt.clone()),
+                    )),
+                    dt => internal_err!("Unsupported datatype for sum: {dt}"),
+                }
+            }
+            (true, true) => {
+                not_impl_err!("Try sum mode not supported for distinct sum accumulators")
             }
-            downcast_sum!(args, helper)
         }
     }
 
     fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
-        if args.is_distinct {
-            Ok(vec![
-                Field::new_list(
-                    format_state_name(args.name, "sum distinct"),
-                    // See COMMENTS.md to understand why nullable is set to true
-                    Field::new_list_field(args.return_type().clone(), true),
-                    false,
+        match (args.is_distinct, self.try_sum_mode) {
+            (true, false) => {
+                Ok(vec![
+                    Field::new_list(
+                        format_state_name(args.name, "sum distinct"),
+                        // See COMMENTS.md to understand why nullable is set to true
+                        Field::new_list_field(args.return_type().clone(), true),
+                        false,
+                    )
+                    .into(),
+                ])
+            }
+            (false, false) => Ok(vec![
+                Field::new(
+                    format_state_name(args.name, "sum"),
+                    args.return_type().clone(),
+                    true,
                 )
                 .into(),
-            ])
-        } else {
-            Ok(vec![
+            ]),
+            (false, true) => Ok(vec![
                 Field::new(
                     format_state_name(args.name, "sum"),
                     args.return_type().clone(),
                     true,
                 )
                 .into(),
-            ])
+                Field::new(
+                    format_state_name(args.name, "sum failed"),
+                    DataType::Boolean,
+                    false,
+                )
+                .into(),
+            ]),
+            (true, true) => {
+                not_impl_err!("Try sum mode not supported for distinct sum accumulators")
+            }
         }
     }
 
     fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
-        !args.is_distinct
+        !args.is_distinct && !self.try_sum_mode
     }
 
     fn create_groups_accumulator(
@@ -304,8 +389,12 @@ impl AggregateUDFImpl for Sum {
         &self,
         args: AccumulatorArgs,
     ) -> Result<Box<dyn Accumulator>> {
+        if self.try_sum_mode {
+            return not_impl_err!(
+                "Try sum mode not supported for sum sliding accumulators"
+            );
+        }
         if args.is_distinct {
-            // distinct path: use our sliding‐window distinct‐sum
             macro_rules! helper_distinct {
                 ($t:ty, $dt:expr) => {
                     Ok(Box::new(SlidingDistinctSumAccumulator::try_new(&$dt)?))
                 };
@@ -313,7 +402,6 @@ impl AggregateUDFImpl for Sum {
             }
             downcast_sum!(args, helper_distinct)
         } else {
-            // non‐distinct path: existing sliding sum
             macro_rules! helper {
                 ($t:ty, $dt:expr) => {
                     Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
@@ -336,6 +424,10 @@ impl AggregateUDFImpl for Sum {
     }
 
     fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
+        // Overflow turns the result into null, so the try-sum result is not monotonic
+        if self.try_sum_mode {
+            return SetMonotonicity::NotMonotonic;
+        }
         // `SUM` is only monotonically increasing when its input is unsigned.
         // TODO: Expand these utilizing statistics.
         match data_type {
@@ -396,6 +488,161 @@ impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
     }
 }
 
+#[derive(Debug, Eq, PartialEq)]
+enum TrySumState<T> {
+    Initial,
+    ValidSum(T),
+    Overflow,
+}
+
+/// Will return `NULL` if at any point the sum overflows.
+#[derive(Debug)]
+struct TrySumAccumulator<T: ArrowNumericType> {
+    state: TrySumState<T::Native>,
+    data_type: DataType,
+}
+
+impl<T: ArrowNumericType> TrySumAccumulator<T> {
+    fn new(data_type: DataType) -> Self {
+        Self {
+            state: TrySumState::Initial,
+            data_type,
+        }
+    }
+}
+
+impl<T: ArrowNumericType> Accumulator for TrySumAccumulator<T> {
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        match self.state {
+            TrySumState::Initial => Ok(vec![
+                ScalarValue::try_new_null(&self.data_type)?,
+                ScalarValue::from(false),
+            ]),
+            TrySumState::ValidSum(sum) => Ok(vec![
+                ScalarValue::new_primitive::<T>(Some(sum), &self.data_type)?,
+                ScalarValue::from(false),
+            ]),
+            TrySumState::Overflow => Ok(vec![
+                ScalarValue::try_new_null(&self.data_type)?,
+                ScalarValue::from(true),
+            ]),
+        }
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let v = match self.state {
+            TrySumState::Initial => T::Native::ZERO,
+            TrySumState::ValidSum(sum) => sum,
+            TrySumState::Overflow => return Ok(()),
+        };
+        let values = values[0].as_primitive::<T>();
+        match arrow::compute::sum_checked(values) {
+            Ok(Some(x)) => match v.add_checked(x) {
+                Ok(sum) => {
+                    self.state = TrySumState::ValidSum(sum);
+                }
+                Err(ArrowError::ArithmeticOverflow(_)) => {
+                    self.state = TrySumState::Overflow;
+                }
+                Err(e) => {
+                    return Err(e.into());
+                }
+            },
+            Ok(None) => (),
+            Err(ArrowError::ArithmeticOverflow(_)) => {
+                self.state = TrySumState::Overflow;
+            }
+            Err(e) => {
+                return Err(e.into());
+            }
+        }
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        // A state batch may contain rows from several partial accumulators;
+        // overflow in any of them poisons the merged result.
+        let other_batch_failed = states[1].as_boolean().true_count() > 0;
+        if other_batch_failed {
+            self.state = TrySumState::Overflow;
+            return Ok(());
+        }
+        self.update_batch(states)
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        match self.state {
+            TrySumState::Initial | TrySumState::Overflow => {
+                ScalarValue::try_new_null(&self.data_type)
+            }
+            TrySumState::ValidSum(sum) => {
+                ScalarValue::new_primitive::<T>(Some(sum), &self.data_type)
+            }
+        }
+    }
+
+    fn size(&self) -> usize {
+        size_of_val(self)
+    }
+}
+
+// Only difference from TrySumAccumulator is that it verifies the resulting sum
+// can fit within the decimal's precision; if Rust had specialization we could unify
+// the two types (╥﹏╥)
+#[derive(Debug)]
+struct TrySumDecimalAccumulator<T: DecimalType + ArrowNumericType> {
+    inner: TrySumAccumulator<T>,
+}
+
+impl<T: DecimalType + ArrowNumericType> TrySumDecimalAccumulator<T> {
+    fn new(data_type: DataType) -> Self {
+        Self {
+            inner: TrySumAccumulator::new(data_type),
+        }
+    }
+
+    fn validate_decimal(&mut self) {
+        // Check decimal precision overflow
+        let precision = match self.inner.data_type {
+            DataType::Decimal32(precision, _)
+            | DataType::Decimal64(precision, _)
+            | DataType::Decimal128(precision, _)
+            | DataType::Decimal256(precision, _) => precision,
+            _ => unreachable!(),
+        };
+        if let TrySumState::ValidSum(sum) = self.inner.state
+            && !T::is_valid_decimal_precision(sum, precision)
+        {
+            self.inner.state = TrySumState::Overflow;
+        }
+    }
+}
+
+impl<T: DecimalType + ArrowNumericType> Accumulator for TrySumDecimalAccumulator<T> {
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        self.inner.state()
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.inner.update_batch(values)?;
+        self.validate_decimal();
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.inner.merge_batch(states)?;
+        self.validate_decimal();
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        self.inner.evaluate()
+    }
+
+    fn size(&self) -> usize {
+        size_of_val(self)
+    }
+}
+
 /// This accumulator incrementally computes sums over a sliding window
 ///
 /// This is separate from [`SumAccumulator`] as requires additional state
diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml
index 09959db41fe60..e53ee354fc598 100644
--- a/datafusion/spark/Cargo.toml
+++ b/datafusion/spark/Cargo.toml
@@ -48,6 +48,7 @@ datafusion-common = { workspace = true }
 datafusion-execution = { workspace = true }
 datafusion-expr = { workspace = true }
 datafusion-functions = { workspace = true, features = ["crypto_expressions"] }
+datafusion-functions-aggregate = { workspace = true }
 datafusion-functions-nested = { workspace = true }
 log = { workspace = true }
 percent-encoding = "2.3.2"
diff --git a/datafusion/spark/src/function/aggregate/try_sum.rs b/datafusion/spark/src/function/aggregate/try_sum.rs
index 6509cea26b716..3553f9ccfb197 100644
--- a/datafusion/spark/src/function/aggregate/try_sum.rs
+++ b/datafusion/spark/src/function/aggregate/try_sum.rs
@@ -15,22 +15,20 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::array::{ArrayRef, ArrowNumericType, AsArray, BooleanArray, PrimitiveArray};
-use arrow::datatypes::{
-    DECIMAL128_MAX_PRECISION, DataType, Decimal128Type, Field, FieldRef, Float64Type,
-    Int64Type,
-};
-use datafusion_common::{Result, ScalarValue, downcast_value, exec_err, not_impl_err};
+use arrow::datatypes::{DataType, FieldRef};
+use datafusion_common::Result;
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
-use datafusion_expr::utils::format_state_name;
-use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
+use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature};
+use datafusion_functions_aggregate::sum::Sum;
 use std::any::Any;
-use std::fmt::{Debug, Formatter};
-use std::mem::size_of_val;
+use std::fmt::Debug;
 
-#[derive(PartialEq, Eq, Hash)]
+/// Thin wrapper over the native DataFusion [`Sum`], configured in try-sum mode so
+/// that overflow returns `null` instead of an error. The wrapper exists to expose
+/// the function under the `try_sum` name used by Spark.
+#[derive(PartialEq, Eq, Hash, Debug)]
 pub struct SparkTrySum {
-    signature: Signature,
+    inner: Sum,
 }
 
 impl Default for SparkTrySum {
@@ -42,211 +40,11 @@ impl SparkTrySum {
     pub fn new() -> Self {
         Self {
-            signature: Signature::user_defined(Volatility::Immutable),
+            inner: Sum::try_sum(),
         }
     }
 }
 
-impl Debug for SparkTrySum {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("SparkTrySum")
-            .field("signature", &self.signature)
-            .finish()
-    }
-}
-
-/// Accumulator for try_sum that detects overflow
-struct TrySumAccumulator<T: ArrowNumericType> {
-    sum: Option<T::Native>,
-    data_type: DataType,
-    failed: bool,
-    // Only used if data_type is Decimal128(p, s)
-    dec_precision: Option<u8>,
-}
-
-impl<T: ArrowNumericType> Debug for TrySumAccumulator<T> {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        write!(f, "TrySumAccumulator({})", self.data_type)
-    }
-}
-
-impl<T: ArrowNumericType> TrySumAccumulator<T> {
-    fn new(data_type: DataType) -> Self {
-        let dec_precision = match &data_type {
-            DataType::Decimal128(p, _) => Some(*p),
-            _ => None,
-        };
-        Self {
-            sum: None,
-            data_type,
-            failed: false,
-            dec_precision,
-        }
-    }
-}
-
-impl<T: ArrowNumericType> Accumulator for TrySumAccumulator<T> {
-    fn state(&mut self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![
-            self.evaluate()?,
-            ScalarValue::Boolean(Some(self.failed)),
-        ])
-    }
-
-    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        update_batch_internal(self, values)
-    }
-
-    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        // Check if any partition has failed
-        if downcast_value!(states[1], BooleanArray)
-            .iter()
-            .flatten()
-            .any(|f| f)
-        {
-            self.failed = true;
-            return Ok(());
-        }
-
-        // Merge the sum values using the same logic as update_batch
-        update_batch_internal(self, states)
-    }
-
-    fn evaluate(&mut self) -> Result<ScalarValue> {
-        evaluate_internal(self)
-    }
-
-    fn size(&self) -> usize {
-        size_of_val(self)
-    }
-}
-
-// Specialized implementations for update_batch for each type
-
-fn update_batch_internal<T: ArrowNumericType>(
-    acc: &mut TrySumAccumulator<T>,
-    values: &[ArrayRef],
-) -> Result<()> {
-    if values.is_empty() || acc.failed {
-        return Ok(());
-    }
-
-    let array: &PrimitiveArray<T> = values[0].as_primitive::<T>();
-
-    match acc.data_type {
-        DataType::Int64 => update_int64(acc, array),
-        DataType::Float64 => update_float64(acc, array),
-        DataType::Decimal128(_, _) => update_decimal128(acc, array),
-        _ => exec_err!(
-            "try_sum: unsupported type in update_batch: {:?}",
-            acc.data_type
-        ),
-    }
-}
-
-fn update_int64<T: ArrowNumericType>(
-    acc: &mut TrySumAccumulator<T>,
-    array: &PrimitiveArray<T>,
-) -> Result<()> {
-    for v in array.iter().flatten() {
-        // Cast to i64 for checked_add
-        let v_i64 = unsafe { std::mem::transmute_copy::<T::Native, i64>(&v) };
-        let sum_i64 = acc
-            .sum
-            .map(|s| unsafe { std::mem::transmute_copy::<T::Native, i64>(&s) });
-
-        let new_sum = match sum_i64 {
-            None => v_i64,
-            Some(s) => match s.checked_add(v_i64) {
-                Some(result) => result,
-                None => {
-                    acc.failed = true;
-                    return Ok(());
-                }
-            },
-        };
-
-        acc.sum = Some(unsafe { std::mem::transmute_copy::<i64, T::Native>(&new_sum) });
-    }
-    Ok(())
-}
-
-fn update_float64<T: ArrowNumericType>(
-    acc: &mut TrySumAccumulator<T>,
-    array: &PrimitiveArray<T>,
-) -> Result<()> {
-    for v in array.iter().flatten() {
-        let v_f64 = unsafe { std::mem::transmute_copy::<T::Native, f64>(&v) };
-        let sum_f64 = acc
-            .sum
-            .map(|s| unsafe { std::mem::transmute_copy::<T::Native, f64>(&s) })
-            .unwrap_or(0.0);
-        let new_sum = sum_f64 + v_f64;
-        acc.sum = Some(unsafe { std::mem::transmute_copy::<f64, T::Native>(&new_sum) });
-    }
-    Ok(())
-}
-
-fn update_decimal128<T: ArrowNumericType>(
-    acc: &mut TrySumAccumulator<T>,
-    array: &PrimitiveArray<T>,
-) -> Result<()> {
-    let precision = acc.dec_precision.unwrap_or(38);
-
-    for v in array.iter().flatten() {
-        let v_i128 = unsafe { std::mem::transmute_copy::<T::Native, i128>(&v) };
-        let sum_i128 = acc
-            .sum
-            .map(|s| unsafe { std::mem::transmute_copy::<T::Native, i128>(&s) });
-
-        let new_sum = match sum_i128 {
-            None => v_i128,
-            Some(s) => match s.checked_add(v_i128) {
-                Some(result) => result,
-                None => {
-                    acc.failed = true;
-                    return Ok(());
-                }
-            },
-        };
-
-        if exceeds_decimal128_precision(new_sum, precision) {
-            acc.failed = true;
-            return Ok(());
-        }
-
-        acc.sum = Some(unsafe { std::mem::transmute_copy::<i128, T::Native>(&new_sum) });
-    }
-    Ok(())
-}
-
-fn evaluate_internal<T: ArrowNumericType>(
-    acc: &mut TrySumAccumulator<T>,
-) -> Result<ScalarValue> {
-    if acc.failed {
-        return ScalarValue::new_primitive::<T>(None, &acc.data_type);
-    }
-    ScalarValue::new_primitive::<T>(acc.sum, &acc.data_type)
-}
-
-// Helpers to determine if it exceeds decimal precision
-fn pow10_i128(p: u8) -> Option<i128> {
-    let mut v: i128 = 1;
-    for _ in 0..p {
-        v = v.checked_mul(10)?;
-    }
-    Some(v)
-}
-
-fn exceeds_decimal128_precision(sum: i128, p: u8) -> bool {
-    if let Some(max_plus_one) = pow10_i128(p) {
-        let max = max_plus_one - 1;
-        sum > max || sum < -max
-    } else {
-        true
-    }
-}
-
 impl AggregateUDFImpl for SparkTrySum {
     fn as_any(&self) -> &dyn Any {
         self
@@ -257,404 +55,18 @@ impl AggregateUDFImpl for SparkTrySum {
     }
 
     fn signature(&self) -> &Signature {
-        &self.signature
+        self.inner.signature()
     }
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        use DataType::*;
-
-        let dt = &arg_types[0];
-        let result_type = match dt {
-            Null => Float64,
-            Decimal128(p, s) => {
-                let new_precision = DECIMAL128_MAX_PRECISION.min(p + 10);
-                Decimal128(new_precision, *s)
-            }
-            Int8 | Int16 | Int32 | Int64 => Int64,
-            Float16 | Float32 | Float64 => Float64,
-
-            other => return exec_err!("try_sum: unsupported type: {other:?}"),
-        };
-
-        Ok(result_type)
+        self.inner.return_type(arg_types)
     }
 
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
-        macro_rules! helper {
-            ($t:ty, $dt:expr) => {
-                Ok(Box::new(TrySumAccumulator::<$t>::new($dt.clone())))
-            };
-        }
-
-        match acc_args.return_field.data_type() {
-            DataType::Int64 => helper!(Int64Type, acc_args.return_field.data_type()),
-            DataType::Float64 => helper!(Float64Type, acc_args.return_field.data_type()),
-            DataType::Decimal128(_, _) => {
-                helper!(Decimal128Type, acc_args.return_field.data_type())
-            }
-            _ => not_impl_err!(
-                "try_sum: unsupported type for accumulator: {}",
-                acc_args.return_field.data_type()
-            ),
-        }
+        self.inner.accumulator(acc_args)
     }
 
     fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
-        let sum_dt = args.return_field.data_type().clone();
-        Ok(vec![
-            Field::new(format_state_name(args.name, "sum"), sum_dt, true).into(),
-            Field::new(
-                format_state_name(args.name, "failed"),
-                DataType::Boolean,
-                false,
-            )
-            .into(),
-        ])
-    }
-
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        use DataType::*;
-        if arg_types.len() != 1 {
-            return exec_err!(
-                "try_sum: exactly 1 argument expected, got {}",
-                arg_types.len()
-            );
-        }
-
-        let dt = &arg_types[0];
-        let coerced = match dt {
-            Null => Float64,
-            Decimal128(p, s) => Decimal128(*p, *s),
-            Int8 | Int16 | Int32 | Int64 => Int64,
-            Float16 | Float32 | Float64 => Float64,
-            other => return exec_err!("try_sum: unsupported type: {other:?}"),
-        };
-        Ok(vec![coerced])
-    }
-
-    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
-        Ok(ScalarValue::Null)
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use arrow::array::{BooleanArray, Decimal128Array, Float64Array, Int64Array};
-    use datafusion_common::{DataFusionError, ScalarValue};
-    use std::sync::Arc;
-
-    use super::*;
-    // -------- Helpers --------
-
-    fn int64(values: Vec<Option<i64>>) -> ArrayRef {
-        Arc::new(Int64Array::from(values)) as ArrayRef
-    }
-
-    fn f64(values: Vec<Option<f64>>) -> ArrayRef {
-        Arc::new(Float64Array::from(values)) as ArrayRef
-    }
-
-    fn dec128(p: u8, s: i8, vals: Vec<Option<i128>>) -> Result<ArrayRef> {
-        let base = Decimal128Array::from(vals);
-        let arr = base.with_precision_and_scale(p, s).map_err(|e| {
-            DataFusionError::Execution(format!("invalid precision/scale ({p},{s}): {e}"))
-        })?;
-        Ok(Arc::new(arr) as ArrayRef)
-    }
-
-    // -------- update_batch + evaluate --------
-
-    #[test]
-    fn try_sum_int_basic() -> Result<()> {
-        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
-        acc.update_batch(&[int64((0..10).map(Some).collect())])?;
-        let out = acc.evaluate()?;
-        assert_eq!(out, ScalarValue::Int64(Some(45)));
-        Ok(())
-    }
-
-    #[test]
-    fn try_sum_int_with_nulls() -> Result<()> {
-        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
-        acc.update_batch(&[int64(vec![None, Some(2), Some(3), None, Some(5)])])?;
-        let out = acc.evaluate()?;
-        assert_eq!(out, ScalarValue::Int64(Some(10)));
-        Ok(())
-    }
-
-    #[test]
-    fn try_sum_float_basic() -> Result<()> {
-        let mut acc = TrySumAccumulator::<Float64Type>::new(DataType::Float64);
-        acc.update_batch(&[f64(vec![Some(1.5), Some(2.5), None, Some(3.0)])])?;
-        let out = acc.evaluate()?;
-        assert_eq!(out, ScalarValue::Float64(Some(7.0)));
-        Ok(())
-    }
-
-    #[test]
-    fn float_overflow_behaves_like_spark_sum_infinite() -> Result<()> {
-        let mut acc = TrySumAccumulator::<Float64Type>::new(DataType::Float64);
-        acc.update_batch(&[f64(vec![Some(1e308), Some(1e308)])])?;
-
-        let out = acc.evaluate()?;
-        assert!(
-            matches!(out, ScalarValue::Float64(Some(v)) if v.is_infinite() && v.is_sign_positive()),
-            "waiting +Infinity, got: {out:?}"
-        );
-        Ok(())
-    }
-
-    #[test]
-    fn try_sum_float_negative_zero_normalizes_to_positive_zero() -> Result<()> {
-        let mut acc = TrySumAccumulator::<Float64Type>::new(DataType::Float64);
-        // -0.0 + 0.0 should normalize to 0.0 (positive zero), not -0.0
-        acc.update_batch(&[f64(vec![Some(-0.0), Some(0.0)])])?;
-        let out = acc.evaluate()?;
-        assert_eq!(out, ScalarValue::Float64(Some(0.0)));
-        // Verify it's positive zero using is_sign_positive
-        if let ScalarValue::Float64(Some(v)) = out {
-            assert!(v.is_sign_positive() || v == 0.0);
-        }
-        Ok(())
-    }
-
-    #[test]
-    fn try_sum_decimal_basic() -> Result<()> {
-        let p = 10u8;
-        let s = 2i8;
-        let mut acc =
-            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
-        acc.update_batch(&[dec128(p, s, vec![Some(123), Some(477)])?])?;
-        let out = acc.evaluate()?;
-        assert_eq!(out, ScalarValue::Decimal128(Some(600), p, s));
-        Ok(())
-    }
-
-    #[test]
-    fn try_sum_decimal_with_nulls() -> Result<()> {
-        let p = 10u8;
-        let s = 2i8;
-        let mut acc =
-            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
-        acc.update_batch(&[dec128(p, s, vec![Some(150), None, Some(200)])?])?;
-        let out = acc.evaluate()?;
-        assert_eq!(out, ScalarValue::Decimal128(Some(350), p, s));
-        Ok(())
-    }
-
-    #[test]
-    fn try_sum_decimal_overflow_sets_failed() -> Result<()> {
-        let p = 5u8;
-        let s = 0i8;
-        let mut acc =
-            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
-        acc.update_batch(&[dec128(p, s, vec![Some(90_000), Some(20_000)])?])?;
-        let out = acc.evaluate()?;
-        assert_eq!(out, ScalarValue::Decimal128(None, p, s));
-        assert!(acc.failed);
-        Ok(())
-    }
-
-    #[test]
-    fn try_sum_decimal_merge_ok_and_failure_propagation() -> Result<()> {
-        let p = 10u8;
-        let s = 2i8;
-
-        let mut p_ok =
-            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
-        p_ok.update_batch(&[dec128(p, s, vec![Some(100), Some(200)])?])?;
-        let s_ok = p_ok
-            .state()?
-            .into_iter()
-            .map(|sv| sv.to_array())
-            .collect::<Result<Vec<_>>>()?;
-
-        let mut p_fail =
-            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
-        p_fail.update_batch(&[dec128(p, s, vec![Some(i128::MAX), Some(1)])?])?;
-        let s_fail = p_fail
-            .state()?
-            .into_iter()
-            .map(|sv| sv.to_array())
-            .collect::<Result<Vec<_>>>()?;
-
-        let mut final_acc =
-            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
-        final_acc.merge_batch(&s_ok)?;
-        final_acc.merge_batch(&s_fail)?;
-
-        assert!(final_acc.failed);
-        assert_eq!(final_acc.evaluate()?, ScalarValue::Decimal128(None, p, s));
-        Ok(())
-    }
-
-    #[test]
-    fn try_sum_int_overflow_sets_failed() -> Result<()> {
-        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
-        // i64::MAX + 1 => overflow => failed => result NULL
-        acc.update_batch(&[int64(vec![Some(i64::MAX), Some(1)])])?;
-        let out = acc.evaluate()?;
-        assert_eq!(out, ScalarValue::Int64(None));
-        assert!(acc.failed);
-        Ok(())
-    }
-
-    #[test]
-    fn try_sum_int_negative_overflow_sets_failed() -> Result<()> {
-        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
-        // i64::MIN - 1 → negative overflow
-        acc.update_batch(&[int64(vec![Some(i64::MIN), Some(-1)])])?;
-        assert_eq!(acc.evaluate()?, ScalarValue::Int64(None));
-        assert!(acc.failed);
-        Ok(())
-    }
-
-    // -------- state + merge_batch --------
-
-    #[test]
-    fn try_sum_state_two_fields_and_merge_ok() -> Result<()> {
-        // accumulator 1: [10, 5] -> sum=15
-        let mut acc1 = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
-        acc1.update_batch(&[int64(vec![Some(10), Some(5)])])?;
-        let state1 = acc1.state()?; // [sum, failed]
-        assert_eq!(state1.len(), 2);
-
-        // accumulator 2: [20, NULL] -> sum=20
-        let mut acc2 = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
-        acc2.update_batch(&[int64(vec![Some(20), None])])?;
-        let state2 = acc2.state()?; // [sum, failed]
-
-        let state1_arrays: Vec<ArrayRef> = state1
-            .into_iter()
-            .map(|sv| sv.to_array())
-            .collect::<Result<_>>()?;
-
-        let state2_arrays: Vec<ArrayRef> = state2
-            .into_iter()
-            .map(|sv| sv.to_array())
-            .collect::<Result<_>>()?;
-
-        // final accumulator
-        let mut final_acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
-
-        final_acc.merge_batch(&state1_arrays)?;
-        final_acc.merge_batch(&state2_arrays)?;
-
-        // total sum = 15 + 20 = 35
-        assert!(!final_acc.failed);
-        assert_eq!(final_acc.evaluate()?, ScalarValue::Int64(Some(35)));
-        Ok(())
-    }
-
-    #[test]
-    fn try_sum_merge_propagates_failure() -> Result<()> {
-        // sum=NULL, failed=true
-        let failed_sum = Arc::new(Int64Array::from(vec![None])) as ArrayRef;
-        let failed_flag = Arc::new(BooleanArray::from(vec![Some(true)])) as ArrayRef;
-
-        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
-        acc.merge_batch(&[failed_sum, failed_flag])?;
-
-        assert!(acc.failed);
-        assert_eq!(acc.evaluate()?, ScalarValue::Int64(None));
-        Ok(())
-    }
-
-    #[test]
-    fn try_sum_merge_empty_partition_is_not_failure() -> Result<()> {
-        // sum=NULL, failed=false
-        let empty_sum = Arc::new(Int64Array::from(vec![None])) as ArrayRef;
-        let ok_flag = Arc::new(BooleanArray::from(vec![Some(false)])) as ArrayRef;
-
-        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
-        acc.update_batch(&[int64(vec![Some(7), Some(8)])])?; // 15
-
-        acc.merge_batch(&[empty_sum, ok_flag])?;
-
-        assert!(!acc.failed);
-        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(15)));
-        Ok(())
-    }
-
-    // -------- signature --------
-
-    #[test]
-    fn try_sum_return_type_matches_input() -> Result<()> {
-        let f = SparkTrySum::new();
-        assert_eq!(f.return_type(&[DataType::Int64])?, DataType::Int64);
-        assert_eq!(f.return_type(&[DataType::Float64])?, DataType::Float64);
-        Ok(())
-    }
-
-    #[test]
-    fn try_sum_state_and_evaluate_consistency() -> Result<()> {
-        let mut acc = TrySumAccumulator::<Float64Type>::new(DataType::Float64);
-        acc.update_batch(&[f64(vec![Some(1.0), Some(2.0)])])?;
-        let eval = acc.evaluate()?;
-        let state = acc.state()?;
-        assert_eq!(state[0], eval);
-        assert_eq!(state[1], ScalarValue::Boolean(Some(false)));
-        Ok(())
-    }
-
-    // -------------------------
-    // DECIMAL
-    // -------------------------
-
-    #[test]
-    fn decimal_10_2_sum_and_schema_widened() -> Result<()> {
-        // input: DECIMAL(10,2) -> result: DECIMAL(20,2)
-        let f = SparkTrySum::new();
-        assert_eq!(
-            f.return_type(&[DataType::Decimal128(10, 2)])?,
-            DataType::Decimal128(20, 2),
-            "Spark needs +10 more digits of precision"
-        );
-
-        let mut acc =
-            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(20, 2));
-        acc.update_batch(&[dec128(10, 2, vec![Some(123), Some(477)])?])?;
-        assert_eq!(acc.evaluate()?, ScalarValue::Decimal128(Some(600), 20, 2));
-        Ok(())
-    }
-
-    #[test]
-    fn decimal_5_0_fits_after_widening() -> Result<()> {
-        // input: DECIMAL(5,0) -> result: DECIMAL(15,0)
-        let f = SparkTrySum::new();
-        assert_eq!(
-            f.return_type(&[DataType::Decimal128(5, 0)])?,
-            DataType::Decimal128(15, 0)
-        );
-
-        let mut acc =
-            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(15, 0));
-        acc.update_batch(&[dec128(5, 0, vec![Some(90_000), Some(20_000)])?])?;
-        assert_eq!(
-            acc.evaluate()?,
-            ScalarValue::Decimal128(Some(110_000), 15, 0)
-        );
-        Ok(())
-    }
-
-    #[test]
-    fn decimal_38_0_max_precision_overflows_to_null() -> Result<()> {
-        let f = SparkTrySum::new();
-        assert_eq!(
-            f.return_type(&[DataType::Decimal128(38, 0)])?,
-            DataType::Decimal128(38, 0)
-        );
-        let ten_pow_38_minus_1 = {
-            let p10 = pow10_i128(38)
-                .ok_or_else(|| DataFusionError::Internal("10^38 overflow".into()))?;
-            p10 - 1
-        };
-        let mut acc =
-            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(38, 0));
-        acc.update_batch(&[dec128(38, 0, vec![Some(ten_pow_38_minus_1), Some(1)])?])?;
-
-        assert!(acc.failed, "need fail in overflow p=38");
-        assert_eq!(acc.evaluate()?, ScalarValue::Decimal128(None, 38, 0));
-        Ok(())
+        self.inner.state_fields(args)
     }
 }
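
Illustrative note (not part of the patch): the sketch below shows, in unit-test style, the overflow-to-NULL behavior that the new try-sum mode is intended to provide. It exercises the TrySumAccumulator added to sum.rs above; since that type is private to the module, the snippet assumes it lives in sum.rs's own test module, and the test name is hypothetical.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::{DataType, Int64Type};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;

#[test]
fn try_sum_overflow_is_null_sketch() -> Result<()> {
    let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);

    // Normal values accumulate exactly like an ordinary SUM.
    let ok: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None]));
    acc.update_batch(&[ok])?;
    assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));

    // Adding i64::MAX overflows the running sum: the accumulator latches the
    // Overflow state and the final result is NULL rather than an error.
    let overflow: ArrayRef = Arc::new(Int64Array::from(vec![Some(i64::MAX)]));
    acc.update_batch(&[overflow])?;
    assert_eq!(acc.evaluate()?, ScalarValue::Int64(None));
    Ok(())
}

At the SQL level the same behavior surfaces as try_sum(col) evaluating to NULL for a group whose sum overflows, instead of raising an arithmetic error.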