Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 99 additions & 30 deletions datafusion/functions-aggregate/src/median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use std::sync::Arc;

use arrow::array::{
ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder,
downcast_integer,
};
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow::{
Expand Down Expand Up @@ -127,12 +126,47 @@ impl AggregateUDFImpl for Median {
}

/// Computes the result type of `median` from the argument types.
///
/// Only `arg_types[0]` is inspected (it is indexed directly, so an empty slice
/// would panic):
/// - `Float16` / `Float32` / `Float64` and the four decimal widths are returned unchanged;
/// - every signed and unsigned integer width maps to `Float64`;
/// - any other type produces a `DataFusionError::NotImplemented`.
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
// NOTE(review): this text is a flattened diff view without +/- markers; the next line
// looks like the removed pre-change body (return the input type as-is) shown next to
// its replacement below — confirm against the actual file.
Ok(arg_types[0].clone())
// MEDIAN performs linear interpolation for even-length arrays and should return a float type
// For integer inputs, return Float64 (matching PostgreSQL/DuckDB/Spark behavior)
// For float/decimal inputs, preserve the input type
match &arg_types[0] {
// Floating-point inputs keep their own type.
DataType::Float16 | DataType::Float32 | DataType::Float64 => {
Ok(arg_types[0].clone())
}
// Decimal inputs also keep their own type (the cloned type carries its two parameters).
DataType::Decimal32(_, _)
| DataType::Decimal64(_, _)
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _) => Ok(arg_types[0].clone()),
// All integer widths, signed and unsigned, are widened to Float64.
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64 => Ok(DataType::Float64),
// Anything else is rejected with the offending type named in the message.
dt => Err(DataFusionError::NotImplemented(format!(
"median does not support input type {dt}"
))),
}
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
//Intermediate state is a list of the elements we have collected so far
let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true);
let input_type = args.input_fields[0].data_type().clone();
// For integer types, we store as Float64 internally
let storage_type = match &input_type {
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64 => DataType::Float64,
_ => input_type,
};
let field = Field::new_list_field(storage_type.clone(), true);
let state_name = if args.is_distinct {
"distinct_median"
} else {
Expand Down Expand Up @@ -166,20 +200,27 @@ impl AggregateUDFImpl for Median {
};
}

let dt = acc_args.expr_fields[0].data_type().clone();
downcast_integer! {
dt => (helper, dt),
DataType::Float16 => helper!(Float16Type, dt),
DataType::Float32 => helper!(Float32Type, dt),
DataType::Float64 => helper!(Float64Type, dt),
DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
let input_dt = acc_args.expr_fields[0].data_type().clone();
match input_dt {
// For integer types, use Float64 internally since median returns Float64
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64 => helper!(Float64Type, DataType::Float64),
DataType::Float16 => helper!(Float16Type, input_dt),
DataType::Float32 => helper!(Float32Type, input_dt),
DataType::Float64 => helper!(Float64Type, input_dt),
DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt),
DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt),
DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt),
DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt),
_ => Err(DataFusionError::NotImplemented(format!(
"MedianAccumulator not supported for {} with {}",
acc_args.name,
dt,
acc_args.name, input_dt,
))),
}
}
Expand All @@ -200,27 +241,34 @@ impl AggregateUDFImpl for Median {
num_args
);

let dt = args.expr_fields[0].data_type().clone();
let input_dt = args.expr_fields[0].data_type().clone();

// Local shorthand used by the type dispatch that follows: expands to
// `Ok(Box::new(..))` around a `MedianGroupsAccumulator` instantiated for the
// type `$t`, with `$dt` forwarded to its `new` constructor.
macro_rules! helper {
($t:ty, $dt:expr) => {
Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
};
}

downcast_integer! {
dt => (helper, dt),
DataType::Float16 => helper!(Float16Type, dt),
DataType::Float32 => helper!(Float32Type, dt),
DataType::Float64 => helper!(Float64Type, dt),
DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
match input_dt {
// For integer types, use Float64 internally since median returns Float64
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64 => helper!(Float64Type, DataType::Float64),
DataType::Float16 => helper!(Float16Type, input_dt),
DataType::Float32 => helper!(Float32Type, input_dt),
DataType::Float64 => helper!(Float64Type, input_dt),
DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt),
DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt),
DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt),
DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt),
_ => Err(DataFusionError::NotImplemented(format!(
"MedianGroupsAccumulator not supported for {} with {}",
args.name,
dt,
args.name, input_dt,
))),
}
}
Expand Down Expand Up @@ -275,7 +323,14 @@ impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = values[0].as_primitive::<T>();
// Cast to target type if needed (e.g., integer to Float64)
let values = if values[0].data_type() != &self.data_type {
arrow::compute::cast(&values[0], &self.data_type)?
} else {
Arc::clone(&values[0])
};

let values = values.as_primitive::<T>();
self.all_values.reserve(values.len() - values.null_count());
self.all_values.extend(values.iter().flatten());
Ok(())
Expand Down Expand Up @@ -367,7 +422,15 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = values[0].as_primitive::<T>();

// Cast to target type if needed (e.g., integer to Float64)
let values = if values[0].data_type() != &self.data_type {
arrow::compute::cast(&values[0], &self.data_type)?
} else {
Arc::clone(&values[0])
};

let values = values.as_primitive::<T>();

// Push the `not nulls + not filtered` row into its group
self.group_values.resize(total_num_groups, Vec::new());
Expand Down Expand Up @@ -558,7 +621,13 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctMedianAccumulator<T> {
}

/// Forwards one input batch to the inner `distinct_values` accumulator,
/// casting the first column to `self.data_type` first when its type differs.
///
/// Any error from the cast is propagated via `?`.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
// NOTE(review): this text is a flattened diff view without +/- markers; the next line
// looks like the removed pre-change body (forward the batch untouched) shown next to
// its replacement below — confirm against the actual file.
self.distinct_values.update_batch(values)
// Cast to target type if needed (e.g., integer to Float64)
let values = if values[0].data_type() != &self.data_type {
// Only `values[0]` is cast and kept: the resulting Vec has exactly one element,
// so any additional entries in `values` are not forwarded on this path.
vec![arrow::compute::cast(&values[0], &self.data_type)?]
} else {
// Types already match: copy the slice of array handles into an owned Vec unchanged.
values.to_vec()
};
self.distinct_values.update_batch(&values)
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
Expand Down
26 changes: 13 additions & 13 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -874,13 +874,13 @@ statement error
SELECT approx_median(c1) FROM aggregate_test_100

# csv_query_median_1
query I
query R
SELECT median(c2) FROM aggregate_test_100
----
3

# csv_query_median_2
query I
query R
SELECT median(c6) FROM aggregate_test_100
----
1125553990140691277
Expand All @@ -892,18 +892,18 @@ SELECT median(c12) FROM aggregate_test_100
0.551390054439

# median_i8
query I
query R
SELECT median(col_i8) FROM median_table
----
-14

# distinct_median_i8
query I
query R
SELECT median(distinct col_i8) FROM median_table
----
100

query II
query RR
SELECT median(col_i8), median(distinct col_i8) FROM median_table
----
-14 100
Expand All @@ -925,43 +925,43 @@ query error DataFusion error: Error during planning: \[IGNORE \| RESPECT\] NULLS
SELECT median(c2) RESPECT NULLS FROM aggregate_test_100

# median_i16
query I
query R
SELECT median(col_i16) FROM median_table
----
-16334

# median_i32
query I
query R
SELECT median(col_i32) FROM median_table
----
-1073741774

# median_i64
query I
query R
SELECT median(col_i64) FROM median_table
----
-4611686018427387854

# median_u8
query I
query R
SELECT median(col_u8) FROM median_table
----
50

# median_u16
query I
query R
SELECT median(col_u16) FROM median_table
----
50

# median_u32
query I
query R
SELECT median(col_u32) FROM median_table
----
50

# median_u64
query I
query R
SELECT median(col_u64) FROM median_table
----
50
Expand Down Expand Up @@ -3699,7 +3699,7 @@ SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c2) FROM aggregate_test_100
4

# Test that percentile_cont(0.5) equals median
query I
query R
SELECT median(c2) FROM aggregate_test_100
----
3
Expand Down
10 changes: 5 additions & 5 deletions datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ SELECT c2, sum(c5), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
5 6449337880 7.074412226677

# Test median for int / float
query IIR
query IRR
SELECT c2, median(c5), median(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
----
1 23971150 0.5922606
Expand Down Expand Up @@ -346,7 +346,7 @@ SELECT c2, sum(c3), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
5 -194 7.074412226677

# Test median with nullable fields
query IIR
query IRR
SELECT c2, median(c3), median(c11) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
----
1 12 0.6067944
Expand Down Expand Up @@ -472,7 +472,7 @@ FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
5 5 0

# Test median with filter
query III
query IRR
SELECT
c2,
median(c3) FILTER (WHERE c3 > 0),
Expand Down Expand Up @@ -607,7 +607,7 @@ FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
5 6

# Test median with nullable fields and filter
query IIR
query IRR
SELECT c2,
median(c3) FILTER (WHERE c5 > 0),
median(c11) FILTER (WHERE c5 < 0)
Expand All @@ -620,7 +620,7 @@ FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
5 -18 0.49842384

# Test median with nullable fields and nullable filter
query II
query IR
SELECT c2,
median(c3) FILTER (WHERE c11 > 0.5)
FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
Expand Down
Loading