Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 96 additions & 4 deletions datafusion/functions-aggregate/src/percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,18 @@ use arrow::{

use arrow::array::ArrowNativeTypeOp;

use crate::min_max::{max_udaf, min_udaf};
use datafusion_common::{
assert_eq_or_internal_err, internal_datafusion_err, plan_err, DataFusionError,
Result, ScalarValue,
assert_eq_or_internal_err, internal_datafusion_err, plan_err,
utils::take_function_args, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::expr::{AggregateFunction, Sort};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
expr::{AggregateFunction, Cast, Sort},
function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
simplify::SimplifyInfo,
};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
Volatility,
Expand Down Expand Up @@ -358,6 +362,12 @@ impl AggregateUDFImpl for PercentileCont {
}
}

fn simplify(&self) -> Option<AggregateFunctionSimplification> {
Some(Box::new(|aggregate_function, info| {
simplify_percentile_cont_aggregate(aggregate_function, info)
}))
}

/// `percentile_cont` accepts the SQL `WITHIN GROUP (ORDER BY ...)` syntax.
fn supports_within_group_clause(&self) -> bool {
    true
}
Expand All @@ -367,6 +377,88 @@ impl AggregateUDFImpl for PercentileCont {
}
}

/// The aggregate a `percentile_cont` call is reduced to when the requested
/// percentile is exactly `0.0` or `1.0` (taking the sort direction into
/// account — see `simplify_percentile_cont_aggregate`).
#[derive(Clone, Copy)]
enum PercentileRewriteTarget {
    // Rewrite to `min(...)`.
    Min,
    // Rewrite to `max(...)`.
    Max,
}

/// Simplify `percentile_cont` aggregates whose percentile argument is the
/// literal `0.0` or `1.0` into an equivalent `min`/`max` aggregate, which
/// can be evaluated without buffering all input values.
///
/// Returns the (possibly rewritten) aggregate expression. Any condition
/// that prevents the rewrite — non-literal or mid-range percentile,
/// unresolvable input type — yields the original expression unchanged.
/// An incorrect argument count propagates as an error, matching the
/// behavior of evaluating the aggregate directly.
fn simplify_percentile_cont_aggregate(
    aggregate_function: AggregateFunction,
    info: &dyn SimplifyInfo,
) -> Result<Expr> {
    // Probe with a borrow first so the original expression is only moved
    // (never cloned) when the rewrite does not apply.
    let rewritten = rewrite_percentile_cont_as_min_max(&aggregate_function, info)?;
    Ok(rewritten.unwrap_or_else(|| Expr::AggregateFunction(aggregate_function)))
}

/// Attempt the `percentile_cont` → `min`/`max` rewrite.
///
/// Returns `Ok(None)` when the aggregate cannot (or should not) be
/// rewritten; returns an error only for a malformed argument list.
fn rewrite_percentile_cont_as_min_max(
    aggregate_function: &AggregateFunction,
    info: &dyn SimplifyInfo,
) -> Result<Option<Expr>> {
    let params = &aggregate_function.params;

    let [value, percentile] = take_function_args("percentile_cont", &params.args)?;

    // WITHIN GROUP (ORDER BY ... DESC) flips which extreme each percentile
    // endpoint selects; with no ORDER BY the sort is treated as ascending.
    let is_descending = params.order_by.first().is_some_and(|sort| !sort.asc);

    let rewrite_target = match extract_percentile_literal(percentile) {
        Some(0.0) => {
            if is_descending {
                PercentileRewriteTarget::Max
            } else {
                PercentileRewriteTarget::Min
            }
        }
        Some(1.0) => {
            if is_descending {
                PercentileRewriteTarget::Min
            } else {
                PercentileRewriteTarget::Max
            }
        }
        // Non-literal or mid-range percentile: leave the aggregate as-is.
        _ => return Ok(None),
    };

    let Ok(input_type) = info.get_data_type(value) else {
        return Ok(None);
    };

    let Ok(expected_return_type) =
        percentile_cont_udaf().return_type(std::slice::from_ref(&input_type))
    else {
        return Ok(None);
    };

    let mut agg_arg = value.clone();
    if expected_return_type != input_type {
        // min/max return the same type as their input. percentile_cont widens
        // integers to Float64 (and preserves float/decimal types), so ensure the
        // rewritten aggregate sees an input of the final return type.
        agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), expected_return_type));
    }

    let udaf = match rewrite_target {
        PercentileRewriteTarget::Min => min_udaf(),
        PercentileRewriteTarget::Max => max_udaf(),
    };

    // Preserve DISTINCT, FILTER and null treatment; the ORDER BY is dropped
    // because min/max are order-insensitive.
    Ok(Some(Expr::AggregateFunction(AggregateFunction::new_udf(
        udaf,
        vec![agg_arg],
        params.distinct,
        params.filter.clone(),
        vec![],
        params.null_treatment,
    ))))
}

/// Return the value of `expr` when it is a non-null `Float64` literal;
/// every other expression (including other literal types) yields `None`.
fn extract_percentile_literal(expr: &Expr) -> Option<f64> {
    if let Expr::Literal(ScalarValue::Float64(Some(value)), _) = expr {
        Some(*value)
    } else {
        None
    }
}

/// The percentile_cont accumulator accumulates the raw input values
/// as native types.
///
Expand Down
53 changes: 53 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -3488,6 +3488,59 @@ SELECT percentile_cont(1.0) WITHIN GROUP (ORDER BY c2) FROM aggregate_test_100
----
5

# Ensure percentile_cont simplification rewrites to min/max plans
query TT
EXPLAIN SELECT percentile_cont(0.0) WITHIN GROUP (ORDER BY c2) FROM aggregate_test_100;
----
logical_plan
01)Aggregate: groupBy=[[]], aggr=[[min(CAST(aggregate_test_100.c2 AS Float64)) AS percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 ASC NULLS LAST]]]
02)--TableScan: aggregate_test_100 projection=[c2]
physical_plan
01)AggregateExec: mode=Final, gby=[], aggr=[percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 ASC NULLS LAST]]
02)--CoalescePartitionsExec
03)----AggregateExec: mode=Partial, gby=[], aggr=[percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 ASC NULLS LAST]]
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2], file_type=csv, has_header=true

query TT
EXPLAIN SELECT percentile_cont(0.0) WITHIN GROUP (ORDER BY c2 DESC) FROM aggregate_test_100;
----
logical_plan
01)Aggregate: groupBy=[[]], aggr=[[max(CAST(aggregate_test_100.c2 AS Float64)) AS percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 DESC NULLS FIRST]]]
02)--TableScan: aggregate_test_100 projection=[c2]
physical_plan
01)AggregateExec: mode=Final, gby=[], aggr=[percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 DESC NULLS FIRST]]
02)--CoalescePartitionsExec
03)----AggregateExec: mode=Partial, gby=[], aggr=[percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 DESC NULLS FIRST]]
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2], file_type=csv, has_header=true

query TT
EXPLAIN SELECT percentile_cont(c2, 0.0) FROM aggregate_test_100;
----
logical_plan
01)Aggregate: groupBy=[[]], aggr=[[min(CAST(aggregate_test_100.c2 AS Float64)) AS percentile_cont(aggregate_test_100.c2,Float64(0))]]
02)--TableScan: aggregate_test_100 projection=[c2]
physical_plan
01)AggregateExec: mode=Final, gby=[], aggr=[percentile_cont(aggregate_test_100.c2,Float64(0))]
02)--CoalescePartitionsExec
03)----AggregateExec: mode=Partial, gby=[], aggr=[percentile_cont(aggregate_test_100.c2,Float64(0))]
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2], file_type=csv, has_header=true

query TT
EXPLAIN SELECT percentile_cont(c2, 1.0) FROM aggregate_test_100;
----
logical_plan
01)Aggregate: groupBy=[[]], aggr=[[max(CAST(aggregate_test_100.c2 AS Float64)) AS percentile_cont(aggregate_test_100.c2,Float64(1))]]
02)--TableScan: aggregate_test_100 projection=[c2]
physical_plan
01)AggregateExec: mode=Final, gby=[], aggr=[percentile_cont(aggregate_test_100.c2,Float64(1))]
02)--CoalescePartitionsExec
03)----AggregateExec: mode=Partial, gby=[], aggr=[percentile_cont(aggregate_test_100.c2,Float64(1))]
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2], file_type=csv, has_header=true

query R
SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY c2) FROM aggregate_test_100
----
Expand Down