diff --git a/datafusion/core/tests/dataframe/describe.rs b/datafusion/core/tests/dataframe/describe.rs index c61fe4fed1615..9aa8a49c97ae3 100644 --- a/datafusion/core/tests/dataframe/describe.rs +++ b/datafusion/core/tests/dataframe/describe.rs @@ -44,7 +44,7 @@ async fn describe() -> Result<()> { | std | 2107.472815166704 | null | 2.8724780750809518 | 2.8724780750809518 | 2.8724780750809518 | 28.724780750809533 | 3.1597258182544645 | 29.012028558317645 | null | null | null | 0.5000342500942125 | 3.44808750051728 | | min | 0.0 | null | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 01/01/09 | 0 | 2008-12-31T23:00:00 | 2009.0 | 1.0 | | max | 7299.0 | null | 9.0 | 9.0 | 9.0 | 90.0 | 9.899999618530273 | 90.89999999999999 | 12/31/10 | 9 | 2010-12-31T04:09:13.860 | 2010.0 | 12.0 | - | median | 3649.0 | null | 4.0 | 4.0 | 4.0 | 45.0 | 4.949999809265137 | 45.45 | null | null | null | 2009.0 | 7.0 | + | median | 3649.5 | null | 4.5 | 4.5 | 4.5 | 45.0 | 4.949999809265137 | 45.45 | null | null | null | 2009.5 | 7.0 | +------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+ "); Ok(()) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index c09db371912b0..9cbe6fe4fe7b7 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1102,26 +1102,26 @@ async fn window_using_aggregates() -> Result<()> { | first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 | +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ | | | | | | | | 1 | -85 | - | -85 | -101 | 14 | -12 | -12 | 83 | -101 | 4 | -54 | - | -85 | -101 | 17 | -25 | -25 | 83 | -101 | 5 | -31 | - | -85 | -12 | 10 | -32 | -34 | 83 | -85 | 3 | 13 | - | -85 | -25 | 3 | -56 | -56 | -25 | -85 | 1 | -5 | - | -85 | -31 | 18 | -29 | -28 | 83 | -101 | 5 | 36 | - | -85 | -38 | 16 | -25 | -25 | 83 | -101 | 4 | 65 | - | -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 | - | -85 | -48 | 6 | -35 | -36 | 83 | -85 | 2 | -43 | - | -85 | -5 | 4 | -37 | -40 | -5 | -85 | 1 | 83 | - | -85 | -54 | 15 | -17 | -18 | 83 | -101 | 4 | -38 | - | -85 | -56 | 2 | -70 | -70 | -56 | -85 | 1 | -25 | - | -85 | -72 | 9 | -43 | -43 | 83 | -85 | 3 | -12 | - | -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 | - | -85 | 13 | 11 | -17 | -18 | 83 | -85 | 3 | 14 | - | -85 | 13 | 11 | -25 | -25 | 83 | -85 | 3 | 13 | - | -85 | 14 | 12 | -12 | -12 | 83 | -85 | 3 | 17 | - | -85 | 17 | 13 | -11 | -8 | 83 | -85 | 4 | -101 | - | -85 | 45 | 8 | -34 | -34 | 83 | -85 | 3 | -72 | - | -85 | 65 | 17 | -17 | -18 | 83 | -101 | 5 | -101 | - | -85 | 83 | 5 | -25 | -25 | 83 | -85 | 2 | -48 | + | -85 | -101 | 14 | -12 | -12.0 | 83 | -101 | 4 | -54 | + | -85 | -101 | 17 | -25 | -25.0 | 83 | -101 | 5 | -31 | + | -85 | -12 | 10 | -32 | -34.0 | 83 | -85 | 3 | 13 | + | -85 | -25 | 3 | -56 | -56.0 | -25 | -85 | 1 | -5 | + | -85 | -31 | 18 | -29 | -28.0 | 83 | -101 | 5 | 36 | + | -85 | -38 | 16 | -25 | -25.0 | 83 | -101 | 4 | 65 | + | -85 | -43 | 7 | -43 | -43.0 | 83 | -85 | 2 | 45 | + | -85 | -48 | 6 | -35 | -36.5 | 83 | -85 | 2 | -43 | + | -85 | -5 | 4 | -37 | -40.5 | -5 | -85 | 1 | 83 | + | -85 | -54 | 15 | -17 | -18.5 | 83 | -101 | 4 | -38 | + | -85 | -56 | 2 | -70 | -70.5 | -56 | -85 | 1 | -25 | + | -85 | -72 | 9 | -43 | -43.0 | 83 | -85 | 3 | -12 | + | -85 | -85 | 1 | -85 | -85.0 | -85 | -85 | 1 | -56 | + | -85 | 13 | 11 | -17 | -18.5 | 83 | -85 | 3 | 14 | + | -85 | 13 | 11 | -25 | -25.0 | 83 | -85 | 3 | 13 | + | -85 | 14 | 12 | -12 | -12.0 | 83 | -85 | 3 | 17 | + | -85 | 17 | 13 | -11 | -8.5 | 83 | -85 | 4 | -101 | + | -85 | 45 | 8 | -34 | -34.0 | 83 | -85 | 3 | -72 | + | -85 | 65 | 17 | -17 | -18.5 | 83 | -101 | 5 | -101 | + | -85 | 83 | 5 | -25 | -25.0 | 83 | -85 | 2 | -48 | +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ " ); diff --git a/datafusion/core/tests/sql/aggregates/dict_nulls.rs b/datafusion/core/tests/sql/aggregates/dict_nulls.rs index f9e15a71a20f8..8733b9e87b57a 100644 --- a/datafusion/core/tests/sql/aggregates/dict_nulls.rs +++ b/datafusion/core/tests/sql/aggregates/dict_nulls.rs @@ -91,10 +91,10 @@ async fn test_aggregates_null_handling_comprehensive() -> Result<()> { +----------------+--------------+ | dict_null_vals | median_value | +----------------+--------------+ - | | 3 | - | group_x | 1 | - | group_y | 5 | - | group_z | 7 | + | | 3.0 | + | group_x | 1.0 | + | group_y | 5.0 | + | group_z | 7.0 | +----------------+--------------+ "); @@ -437,16 +437,16 @@ async fn test_median_distinct_with_fuzz_table_dict_nulls() -> Result<()> { assert_snapshot!( batches_to_string(&results), @r" - +--------+---------------------+------+------+------+--------+--------+ - | u8_low | dictionary_utf8_low | col1 | col2 | col3 | col4 | col5 | - +--------+---------------------+------+------+------+--------+--------+ - | 50 | | | 30 | | 987.65 | 400000 | - | 50 | group_three | 5000 | 50 | 5000 | 555.55 | 500000 | - | 75 | | 4000 | | 4000 | | 450000 | - | 100 | group_one | 1100 | 11 | 1000 | 123.45 | 110000 | - | 100 | group_two | 1500 | 15 | 1500 | 111.11 | 150000 | - | 200 | | 2500 | 22 | 2500 | 506.11 | 250000 | - +--------+---------------------+------+------+------+--------+--------+ + +--------+---------------------+--------+------+--------+--------+----------+ + | u8_low | dictionary_utf8_low | col1 | col2 | col3 | col4 | col5 | + +--------+---------------------+--------+------+--------+--------+----------+ + | 50 | | | 30.0 | | 987.65 | 400000.0 | + | 50 | group_three | 5000.0 | 50.0 | 5000.0 | 555.55 | 500000.0 | + | 75 | | 4000.0 | | 4000.0 | | 450000.0 | + | 100 | group_one | 1100.0 | 11.0 | 1000.0 | 123.45 | 110000.0 | + | 100 | group_two | 1500.0 | 15.0 | 1500.0 | 111.11 | 150000.0 | + | 200 | | 2500.0 | 22.5 | 2500.0 | 506.11 | 250000.0 | + +--------+---------------------+--------+------+--------+--------+----------+ " ); diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index f137ae0801f09..2d0d3d7aad299 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -15,14 +15,12 @@ // specific language governing permissions and limitations // under the License. -use std::cmp::Ordering; use std::fmt::{Debug, Formatter}; use std::mem::{size_of, size_of_val}; use std::sync::Arc; use arrow::array::{ ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder, - downcast_integer, }; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::{ @@ -34,10 +32,7 @@ use arrow::{ }; use arrow::array::Array; -use arrow::array::ArrowNativeTypeOp; -use arrow::datatypes::{ - ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type, FieldRef, -}; +use arrow::datatypes::{Decimal32Type, Decimal64Type, FieldRef}; use datafusion_common::{ DataFusionError, Result, ScalarValue, assert_eq_or_internal_err, @@ -55,6 +50,8 @@ use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer; use datafusion_macros::user_doc; use std::collections::HashMap; +use crate::percentile_cont::calculate_percentile; + make_udaf_expr_and_func!( Median, median, @@ -127,12 +124,47 @@ impl AggregateUDFImpl for Median { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + // MEDIAN performs linear interpolation for even-length arrays and should return a float type + // For integer inputs, return Float64 (matching PostgreSQL/DuckDB/Spark behavior) + // For float/decimal inputs, preserve the input type + match &arg_types[0] { + DataType::Float16 | DataType::Float32 | DataType::Float64 => { + Ok(arg_types[0].clone()) + } + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => Ok(arg_types[0].clone()), + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => Ok(DataType::Float64), + dt => Err(DataFusionError::NotImplemented(format!( + "median does not support input type {dt}" + ))), + } } fn state_fields(&self, args: StateFieldsArgs) -> Result> { //Intermediate state is a list of the elements we have collected so far - let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true); + let input_type = args.input_fields[0].data_type().clone(); + // For integer types, we store as Float64 internally + let storage_type = match &input_type { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => DataType::Float64, + _ => input_type, + }; + let field = Field::new_list_field(storage_type.clone(), true); let state_name = if args.is_distinct { "distinct_median" } else { @@ -166,20 +198,27 @@ impl AggregateUDFImpl for Median { }; } - let dt = acc_args.expr_fields[0].data_type().clone(); - downcast_integer! { - dt => (helper, dt), - DataType::Float16 => helper!(Float16Type, dt), - DataType::Float32 => helper!(Float32Type, dt), - DataType::Float64 => helper!(Float64Type, dt), - DataType::Decimal32(_, _) => helper!(Decimal32Type, dt), - DataType::Decimal64(_, _) => helper!(Decimal64Type, dt), - DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), - DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), + let input_dt = acc_args.expr_fields[0].data_type().clone(); + match input_dt { + // For integer types, use Float64 internally since median returns Float64 + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => helper!(Float64Type, DataType::Float64), + DataType::Float16 => helper!(Float16Type, input_dt), + DataType::Float32 => helper!(Float32Type, input_dt), + DataType::Float64 => helper!(Float64Type, input_dt), + DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt), + DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt), + DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt), + DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt), _ => Err(DataFusionError::NotImplemented(format!( "MedianAccumulator not supported for {} with {}", - acc_args.name, - dt, + acc_args.name, input_dt, ))), } } @@ -200,7 +239,7 @@ impl AggregateUDFImpl for Median { num_args ); - let dt = args.expr_fields[0].data_type().clone(); + let input_dt = args.expr_fields[0].data_type().clone(); macro_rules! helper { ($t:ty, $dt:expr) => { @@ -208,19 +247,26 @@ impl AggregateUDFImpl for Median { }; } - downcast_integer! { - dt => (helper, dt), - DataType::Float16 => helper!(Float16Type, dt), - DataType::Float32 => helper!(Float32Type, dt), - DataType::Float64 => helper!(Float64Type, dt), - DataType::Decimal32(_, _) => helper!(Decimal32Type, dt), - DataType::Decimal64(_, _) => helper!(Decimal64Type, dt), - DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), - DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), + match input_dt { + // For integer types, use Float64 internally since median returns Float64 + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => helper!(Float64Type, DataType::Float64), + DataType::Float16 => helper!(Float16Type, input_dt), + DataType::Float32 => helper!(Float32Type, input_dt), + DataType::Float64 => helper!(Float64Type, input_dt), + DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt), + DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt), + DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt), + DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt), _ => Err(DataFusionError::NotImplemented(format!( "MedianGroupsAccumulator not supported for {} with {}", - args.name, - dt, + args.name, input_dt, ))), } } @@ -275,7 +321,14 @@ impl Accumulator for MedianAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); + // Cast to target type if needed (e.g., integer to Float64) + let values = if values[0].data_type() != &self.data_type { + arrow::compute::cast(&values[0], &self.data_type)? + } else { + Arc::clone(&values[0]) + }; + + let values = values.as_primitive::(); self.all_values.reserve(values.len() - values.null_count()); self.all_values.extend(values.iter().flatten()); Ok(()) @@ -290,7 +343,10 @@ impl Accumulator for MedianAccumulator { } fn evaluate(&mut self) -> Result { - let median = calculate_median::(&mut self.all_values); + // Clone values since calculate_percentile modifies them in-place, + // and we need to preserve them for window functions that call evaluate() multiple times + let values = self.all_values.clone(); + let median = calculate_percentile::(values, 0.5); ScalarValue::new_primitive::(median, &self.data_type) } @@ -301,9 +357,14 @@ impl Accumulator for MedianAccumulator { fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let mut to_remove: HashMap = HashMap::new(); - let arr = &values[0]; + // Cast to target type if needed (e.g., integer to Float64) + let arr = if values[0].data_type() != &self.data_type { + arrow::compute::cast(&values[0], &self.data_type)? + } else { + Arc::clone(&values[0]) + }; for i in 0..arr.len() { - let v = ScalarValue::try_from_array(arr, i)?; + let v = ScalarValue::try_from_array(&arr, i)?; if !v.is_null() { *to_remove.entry(v).or_default() += 1; } @@ -367,7 +428,15 @@ impl GroupsAccumulator for MedianGroupsAccumulator Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = values[0].as_primitive::(); + + // Cast to target type if needed (e.g., integer to Float64) + let values = if values[0].data_type() != &self.data_type { + arrow::compute::cast(&values[0], &self.data_type)? + } else { + Arc::clone(&values[0]) + }; + + let values = values.as_primitive::(); // Push the `not nulls + not filtered` row into its group self.group_values.resize(total_num_groups, Vec::new()); @@ -478,11 +547,11 @@ impl GroupsAccumulator for MedianGroupsAccumulator::new().with_data_type(self.data_type.clone()); - for mut values in emit_group_values { - let median = calculate_median::(&mut values); + for values in emit_group_values { + let median = calculate_percentile::(values, 0.5); evaluate_result_builder.append_option(median); } @@ -496,7 +565,14 @@ impl GroupsAccumulator for MedianGroupsAccumulator Result> { assert_eq!(values.len(), 1, "one argument to merge_batch"); - let input_array = values[0].as_primitive::(); + // Cast to target type if needed (e.g., integer to Float64) + let values_array = if values[0].data_type() != &self.data_type { + arrow::compute::cast(&values[0], &self.data_type)? + } else { + Arc::clone(&values[0]) + }; + + let input_array = values_array.as_primitive::(); // Directly convert the input array to states, each row will be // seen as a respective group. @@ -558,7 +634,13 @@ impl Accumulator for DistinctMedianAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.distinct_values.update_batch(values) + // Cast to target type if needed (e.g., integer to Float64) + let values = if values[0].data_type() != &self.data_type { + vec![arrow::compute::cast(&values[0], &self.data_type)?] + } else { + values.to_vec() + }; + self.distinct_values.update_batch(&values) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { @@ -566,11 +648,11 @@ impl Accumulator for DistinctMedianAccumulator { } fn evaluate(&mut self) -> Result { - let mut d = std::mem::take(&mut self.distinct_values.values) + let d = std::mem::take(&mut self.distinct_values.values) .into_iter() .map(|v| v.0) .collect::>(); - let median = calculate_median::(&mut d); + let median = calculate_percentile::(d, 0.5); ScalarValue::new_primitive::(median, &self.data_type) } @@ -578,54 +660,3 @@ impl Accumulator for DistinctMedianAccumulator { size_of_val(self) + self.distinct_values.size() } } - -/// Get maximum entry in the slice, -fn slice_max(array: &[T::Native]) -> T::Native -where - T: ArrowPrimitiveType, - T::Native: PartialOrd, // Ensure the type supports PartialOrd for comparison -{ - // Make sure that, array is not empty. - debug_assert!(!array.is_empty()); - // `.unwrap()` is safe here as the array is supposed to be non-empty - *array - .iter() - .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Less)) - .unwrap() -} - -fn calculate_median(values: &mut [T::Native]) -> Option { - let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); - - let len = values.len(); - if len == 0 { - None - } else if len % 2 == 0 { - let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp); - // Get the maximum of the low (left side after bi-partitioning) - let left_max = slice_max::(low); - // Calculate median as the average of the two middle values. - // Use checked arithmetic to detect overflow and fall back to safe formula. - let two = T::Native::usize_as(2); - let median = match left_max.add_checked(*high) { - Ok(sum) => sum.div_wrapping(two), - Err(_) => { - // Overflow detected - use safe midpoint formula: - // a/2 + b/2 + ((a%2 + b%2) / 2) - // This avoids overflow by dividing before adding. - let half_left = left_max.div_wrapping(two); - let half_right = (*high).div_wrapping(two); - let rem_left = left_max.mod_wrapping(two); - let rem_right = (*high).mod_wrapping(two); - // The sum of remainders (0, 1, or 2 for unsigned; -2 to 2 for signed) - // divided by 2 gives the correction factor (0 or 1 for unsigned; -1, 0, or 1 for signed) - let correction = rem_left.add_wrapping(rem_right).div_wrapping(two); - half_left.add_wrapping(half_right).add_wrapping(correction) - } - }; - Some(median) - } else { - let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp); - Some(*median) - } -} diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index d6c8eabb459e6..6cbec83d2c308 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -67,7 +67,7 @@ use crate::utils::validate_percentile_expr; /// The interpolation formula: `lower + (upper - lower) * fraction` /// is computed as: `lower + ((upper - lower) * (fraction * PRECISION)) / PRECISION` /// to avoid floating-point operations on integer types while maintaining precision. -const INTERPOLATION_PRECISION: usize = 1_000_000; +pub(crate) const INTERPOLATION_PRECISION: usize = 1_000_000; create_func!(PercentileCont, percentile_cont_udaf); @@ -788,7 +788,9 @@ impl Accumulator for DistinctPercentileContAccumula /// For percentile p and n values: /// - If p * (n-1) is an integer, return the value at that position /// - Otherwise, interpolate between the two closest values -fn calculate_percentile( +/// +/// This function is also used by median (which is percentile_cont at 0.5). +pub(crate) fn calculate_percentile( mut values: Vec, percentile: f64, ) -> Option { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 2a4daeb92979d..7e185ad632954 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -874,16 +874,16 @@ statement error SELECT approx_median(c1) FROM aggregate_test_100 # csv_query_median_1 -query I +query R SELECT median(c2) FROM aggregate_test_100 ---- 3 # csv_query_median_2 -query I +query R SELECT median(c6) FROM aggregate_test_100 ---- -1125553990140691277 +1125553990140691200 # csv_query_median_3 query R @@ -892,18 +892,18 @@ SELECT median(c12) FROM aggregate_test_100 0.551390054439 # median_i8 -query I +query R SELECT median(col_i8) FROM median_table ---- -14 # distinct_median_i8 -query I +query R SELECT median(distinct col_i8) FROM median_table ---- 100 -query II +query RR SELECT median(col_i8), median(distinct col_i8) FROM median_table ---- -14 100 @@ -925,43 +925,43 @@ query error DataFusion error: Error during planning: \[IGNORE \| RESPECT\] NULLS SELECT median(c2) RESPECT NULLS FROM aggregate_test_100 # median_i16 -query I +query R SELECT median(col_i16) FROM median_table ---- -16334 # median_i32 -query I +query R SELECT median(col_i32) FROM median_table ---- -1073741774 # median_i64 -query I +query R SELECT median(col_i64) FROM median_table ---- --4611686018427387854 +-4611686018427388000 # median_u8 -query I +query R SELECT median(col_u8) FROM median_table ---- 50 # median_u16 -query I +query R SELECT median(col_u16) FROM median_table ---- 50 # median_u32 -query I +query R SELECT median(col_u32) FROM median_table ---- 50 # median_u64 -query I +query R SELECT median(col_u64) FROM median_table ---- 50 @@ -992,56 +992,56 @@ NaN # median_i8_overflow_negative -query I +query R SELECT median(v) FROM (VALUES (arrow_cast(-85, 'Int8')), (arrow_cast(-56, 'Int8'))) AS t(v); ---- --70 +-70.5 # median_i8_overflow_positive # Test overflow with positive values: 100 + 120 = 220 > 127 (max i8) -query I +query R SELECT median(v) FROM (VALUES (arrow_cast(100, 'Int8')), (arrow_cast(120, 'Int8'))) AS t(v); ---- 110 # median_u8_overflow # Test unsigned overflow: 200 + 250 = 450 > 255 (max u8) -query I +query R SELECT median(v) FROM (VALUES (arrow_cast(200, 'UInt8')), (arrow_cast(250, 'UInt8'))) AS t(v); ---- 225 # median_i8_no_overflow_normal_case # Normal case that doesn't overflow for comparison -query I +query R SELECT median(v) FROM (VALUES (arrow_cast(4, 'Int8')), (arrow_cast(5, 'Int8'))) AS t(v); ---- -4 +4.5 # median_i8_max_values # Test with both i8::MAX values: 127 + 127 = 254 > 127, overflow -query I +query R SELECT median(v) FROM (VALUES (arrow_cast(127, 'Int8')), (arrow_cast(127, 'Int8'))) AS t(v); ---- 127 # median_i8_min_values # Test with both i8::MIN values: -128 + -128 = -256 < -128, underflow -query I +query R SELECT median(v) FROM (VALUES (arrow_cast(-128, 'Int8')), (arrow_cast(-128, 'Int8'))) AS t(v); ---- -128 # median_i8_min_max_values -# Test with i8::MIN and i8::MAX: -128 + 127 = -1, no overflow, median = 0 (truncated from -0.5) -query I +# Test with i8::MIN and i8::MAX: -128 + 127 = -1, median = -0.5 +query R SELECT median(v) FROM (VALUES (arrow_cast(-128, 'Int8')), (arrow_cast(127, 'Int8'))) AS t(v); ---- -0 +-0.5 # median_u8_max_values # Test with both u8::MAX values: 255 + 255 = 510 > 255, overflow -query I +query R SELECT median(v) FROM (VALUES (arrow_cast(255, 'UInt8')), (arrow_cast(255, 'UInt8'))) AS t(v); ---- 255 @@ -1184,7 +1184,7 @@ drop table t; statement ok create table t(c int) as values (1), (2), (3), (4), (5); -query I +query R select median(c) from t; ---- 3 @@ -1196,10 +1196,10 @@ drop table t; statement ok create table t(c int) as values (1), (2), (3), (4), (5), (6); -query I +query R select median(c) from t; ---- -3 +3.5 statement ok drop table t; @@ -1208,10 +1208,10 @@ drop table t; statement ok create table t(c int) as values (1), (null), (3), (4), (5); -query I +query R select median(c) from t; ---- -3 +3.5 statement ok drop table t; @@ -1220,7 +1220,7 @@ drop table t; statement ok create table t(c int) as values (null), (null), (null); -query I +query R select median(c) from t; ---- NULL @@ -1232,7 +1232,7 @@ drop table t; statement ok create table t(c int unsigned) as values (1), (2), (3), (4), (5); -query I +query R select median(c) from t; ---- 3 @@ -1280,7 +1280,7 @@ drop table t; statement ok create table t(c int) as values (2), (1), (1), (2), (1), (3); -query I +query R select median(distinct c) from t; ---- 2 @@ -1292,7 +1292,7 @@ drop table t; statement ok create table t(c int) as values (1), (1), (3), (1), (1); -query I +query R select median(distinct c) from t; ---- 2 @@ -1304,7 +1304,7 @@ drop table t; statement ok create table t(c int) as values (1), (null), (1), (1), (3); -query I +query R select median(distinct c) from t; ---- 2 @@ -1316,7 +1316,7 @@ drop table t; statement ok create table t(c int unsigned) as values (1), (1), (2), (1), (3); -query I +query R select median(distinct c) from t; ---- 2 @@ -1328,7 +1328,7 @@ drop table t; statement ok create table t(c int unsigned) as values (1), (1), (1), (1), (3), (3); -query I +query R select median(distinct c) from t; ---- 2 @@ -1388,7 +1388,7 @@ drop table t; statement ok create table t(c int) as values (1), (1), (1), (1), (2), (2), (3), (3); -query I +query R select median(distinct c) from t; ---- 2 @@ -3699,7 +3699,7 @@ SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c2) FROM aggregate_test_100 4 # Test that percentile_cont(0.5) equals median -query I +query R SELECT median(c2) FROM aggregate_test_100 ---- 3 @@ -5719,7 +5719,7 @@ statement ok drop table t; -query I +query R select median(a) from (select 1 as a where 1=0); ---- NULL @@ -7827,56 +7827,56 @@ drop table t; ####### # group median i8 non-nullable -query TI rowsort +query TR rowsort SELECT col_group, median(col_i8) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 -14 group1 100 # group median i16 non-nullable -query TI +query TR SELECT col_group, median(col_i16) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 -16334 group1 100 # group median i32 non-nullable -query TI +query TR SELECT col_group, median(col_i32) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 -1073741774 group1 100 # group median i64 non-nullable -query TI +query TR SELECT col_group, median(col_i64) FROM group_median_table_non_nullable GROUP BY col_group ---- -group0 -4611686018427387854 +group0 -4611686018427388000 group1 100 # group median u8 non-nullable -query TI rowsort +query TR rowsort SELECT col_group, median(col_u8) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 50 group1 100 # group median u16 non-nullable -query TI +query TR SELECT col_group, median(col_u16) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 50 group1 100 # group median u32 non-nullable -query TI +query TR SELECT col_group, median(col_u32) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 50 group1 100 # group median u64 non-nullable -query TI +query TR SELECT col_group, median(col_u64) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 50 @@ -7918,56 +7918,56 @@ group0 0.0002 group1 0.0003 # group median i8 nullable -query TI rowsort +query TR rowsort SELECT col_group, median(col_i8) FROM group_median_table_nullable GROUP BY col_group ---- group0 -14 group1 100 # group median i16 nullable -query TI rowsort +query TR rowsort SELECT col_group, median(col_i16) FROM group_median_table_nullable GROUP BY col_group ---- group0 -16334 group1 100 # group median i32 nullable -query TI rowsort +query TR rowsort SELECT col_group, median(col_i32) FROM group_median_table_nullable GROUP BY col_group ---- group0 -1073741774 group1 100 # group median i64 nullable -query TI rowsort +query TR rowsort SELECT col_group, median(col_i64) FROM group_median_table_nullable GROUP BY col_group ---- -group0 -4611686018427387854 +group0 -4611686018427388000 group1 100 # group median u8 nullable -query TI rowsort +query TR rowsort SELECT col_group, median(col_u8) FROM group_median_table_nullable GROUP BY col_group ---- group0 50 group1 100 # group median u16 nullable -query TI rowsort +query TR rowsort SELECT col_group, median(col_u16) FROM group_median_table_nullable GROUP BY col_group ---- group0 50 group1 100 # group median u32 nullable -query TI rowsort +query TR rowsort SELECT col_group, median(col_u32) FROM group_median_table_nullable GROUP BY col_group ---- group0 50 group1 100 # group median u64 nullable -query TI rowsort +query TR rowsort SELECT col_group, median(col_u64) FROM group_median_table_nullable GROUP BY col_group ---- group0 50 @@ -8021,11 +8021,11 @@ create table group_median_all_nulls( ( 'group1', NULL), ( 'group1', NULL) -query TIT rowsort +query TRT rowsort SELECT a, median(b), arrow_typeof(median(b)) FROM group_median_all_nulls GROUP BY a ---- -group0 NULL Int32 -group1 NULL Int32 +group0 NULL Float64 +group1 NULL Float64 statement ok create table t_decimal (c decimal(10, 4)) as values (100.00), (125.00), (175.00), (200.00), (200.00), (300.00), (null), (null); diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt index 0885a6a7d663e..0af1186258a6d 100644 --- a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt +++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt @@ -130,7 +130,7 @@ GROUP BY 1, 2 ORDER BY 1 LIMIT 5; -2117946883 d -2117946883 NULL NULL NULL -2098805236 c -2098805236 NULL NULL NULL -query ITIIII +query ITRRRR SELECT c5, c1, MEDIAN(c5), MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END), @@ -265,7 +265,7 @@ SELECT c2, sum(c5), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; 5 6449337880 7.074412226677 # Test median for int / float -query IIR +query IRR SELECT c2, median(c5), median(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; ---- 1 23971150 0.5922606 @@ -346,14 +346,14 @@ SELECT c2, sum(c3), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; 5 -194 7.074412226677 # Test median with nullable fields -query IIR +query IRR SELECT c2, median(c3), median(c11) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; ---- 1 12 0.6067944 2 1 0.46076488 3 14 0.40154034 -4 -17 0.48515016 -5 -35 0.5536642 +4 -17.5 0.48515016 +5 -35.5 0.5536642 # Test approx_median with nullable fields query IIR @@ -472,7 +472,7 @@ FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; 5 5 0 # Test median with filter -query III +query IRR SELECT c2, median(c3) FILTER (WHERE c3 > 0), @@ -607,7 +607,7 @@ FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; 5 6 # Test median with nullable fields and filter -query IIR +query IRR SELECT c2, median(c3) FILTER (WHERE c5 > 0), median(c11) FILTER (WHERE c5 < 0) @@ -615,21 +615,21 @@ FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; ---- 1 -5 0.6623719 2 15 0.52930677 -3 13 0.32792538 +3 13.5 0.32792538 4 -38 0.49774808 5 -18 0.49842384 # Test min / max with nullable fields and nullable filter -query II +query IR SELECT c2, median(c3) FILTER (WHERE c11 > 0.5) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; ---- -1 33 +1 33.5 2 -29 3 22 4 -90 -5 -22 +5 -22.5 # Test approx_median with nullable fields and filter query IIR