Skip to content

Commit

Permalink
chore(cubesql): Support PERCENTILE_CONT planning
Browse files Browse the repository at this point in the history
  • Loading branch information
MazterQyou committed Sep 18, 2024
1 parent ab3771b commit f47d340
Show file tree
Hide file tree
Showing 27 changed files with 558 additions and 113 deletions.
14 changes: 7 additions & 7 deletions packages/cubejs-backend-native/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions packages/cubejs-schema-compiler/src/adapter/BaseQuery.js
Original file line number Diff line number Diff line change
Expand Up @@ -3193,6 +3193,8 @@ export class BaseQuery {
// DATEADD is being rewritten to DATE_ADD
// DATEADD: 'DATEADD({{ date_part }}, {{ interval }}, {{ args[2] }})',
DATE: 'DATE({{ args_concat }})',

PERCENTILECONT: 'PERCENTILE_CONT({{ args_concat }})',
},
statements: {
select: 'SELECT {% if distinct %}DISTINCT {% endif %}' +
Expand Down Expand Up @@ -3228,6 +3230,7 @@ export class BaseQuery {
like: '{{ expr }} {% if negated %}NOT {% endif %}LIKE {{ pattern }}',
ilike: '{{ expr }} {% if negated %}NOT {% endif %}ILIKE {{ pattern }}',
like_escape: '{{ like_expr }} ESCAPE {{ escape_char }}',
within_group: '{{ fun_sql }} WITHIN GROUP (ORDER BY {{ within_group_concat }})',
},
quotes: {
identifiers: '"',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ export class BigqueryQuery extends BaseQuery {
// templates.functions.DATEADD = 'DATETIME_ADD(CAST({{ args[2] }} AS DATETIME), INTERVAL {{ interval }} {{ date_part }})';
templates.functions.CURRENTDATE = 'CURRENT_DATE';
delete templates.functions.TO_CHAR;
delete templates.functions.PERCENTILECONT;
templates.expressions.binary = '{% if op == \'%\' %}MOD({{ left }}, {{ right }}){% else %}({{ left }} {{ op }} {{ right }}){% endif %}';
templates.expressions.interval = 'INTERVAL {{ interval }}';
templates.expressions.extract = 'EXTRACT({% if date_part == \'DOW\' %}DAYOFWEEK{% elif date_part == \'DOY\' %}DAYOFYEAR{% else %}{{ date_part }}{% endif %} FROM {{ expr }})';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ export class ClickHouseQuery extends BaseQuery {
templates.functions.DATETRUNC = 'DATE_TRUNC({{ args_concat }})';
// TODO: Introduce additional filter in jinja? or parseDateTimeBestEffort?
// https://github.com/ClickHouse/ClickHouse/issues/19351
delete templates.functions.PERCENTILECONT;
templates.expressions.timestamp_literal = 'parseDateTimeBestEffort(\'{{ value }}\')';
delete templates.expressions.like_escape;
templates.quotes.identifiers = '`';
Expand Down
2 changes: 2 additions & 0 deletions packages/cubejs-schema-compiler/src/adapter/MssqlQuery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ export class MssqlQuery extends BaseQuery {
const templates = super.sqlTemplates();
templates.functions.LEAST = 'LEAST({{ args_concat }})';
templates.functions.GREATEST = 'GREATEST({{ args_concat }})';
// PERCENTILE_CONT works but requires PARTITION BY (i.e. it is only usable as a window function), so the aggregate template is removed
delete templates.functions.PERCENTILECONT;
delete templates.expressions.ilike;
templates.types.string = 'VARCHAR';
templates.types.boolean = 'BIT';
Expand Down
2 changes: 2 additions & 0 deletions packages/cubejs-schema-compiler/src/adapter/MysqlQuery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ export class MysqlQuery extends BaseQuery {

public sqlTemplates() {
const templates = super.sqlTemplates();
// PERCENTILE_CONT works but requires PARTITION BY (i.e. it is only usable as a window function), so the aggregate template is removed
delete templates.functions.PERCENTILECONT;
templates.quotes.identifiers = '`';
templates.quotes.escape = '\\`';
delete templates.expressions.ilike;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ export class PrestodbQuery extends BaseQuery {
const templates = super.sqlTemplates();
templates.functions.DATETRUNC = 'DATE_TRUNC({{ args_concat }})';
templates.functions.DATEPART = 'DATE_PART({{ args_concat }})';
delete templates.functions.PERCENTILECONT;
templates.statements.select = 'SELECT {{ select_concat | map(attribute=\'aliased\') | join(\', \') }} \n' +
'FROM (\n {{ from }}\n) AS {{ from_alias }} \n' +
'{% if group_by %} GROUP BY {{ group_by }}{% endif %}' +
Expand Down
14 changes: 7 additions & 7 deletions rust/cubesql/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions rust/cubesql/cubesql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ homepage = "https://cube.dev"

[dependencies]
arc-swap = "1"
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "dcf3e4aa26fd112043ef26fa4a78db5dbd443c86", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "11a4ed10b184b2f1b22f7458702ae0c63f011241", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
anyhow = "1.0"
thiserror = "1.0.50"
cubeclient = { path = "../cubeclient" }
pg-srv = { path = "../pg-srv" }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "5fe1b77d1a91b80529a0b7af0b89411d3cba5137" }
base64 = "0.13.0"
tokio = { version = "^1.35", features = ["full", "rt", "tracing"] }
serde = { version = "^1.0", features = ["derive"] }
Expand Down
34 changes: 25 additions & 9 deletions rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,31 @@ pub fn rewrite(expr: &Expr, map: &HashMap<Column, Option<Expr>>) -> Result<Optio
fun,
args,
distinct,
} => args
.iter()
.map(|arg| rewrite(arg, map))
.collect::<Result<Option<Vec<_>>>>()?
.map(|args| Expr::AggregateFunction {
fun: fun.clone(),
args,
distinct: distinct.clone(),
}),
within_group,
} => {
let args = args
.iter()
.map(|arg| rewrite(arg, map))
.collect::<Result<Option<Vec<_>>>>()?;
let within_group = match within_group.as_ref() {
Some(within_group) => within_group
.iter()
.map(|expr| rewrite(expr, map))
.collect::<Result<Option<Vec<_>>>>()?
.map(|within_group| Some(within_group)),
None => Some(None),
};
if let (Some(args), Some(within_group)) = (args, within_group) {
Some(Expr::AggregateFunction {
fun: fun.clone(),
args,
distinct: distinct.clone(),
within_group,
})
} else {
None
}
}
Expr::WindowFunction {
fun,
args,
Expand Down
19 changes: 18 additions & 1 deletion rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1921,8 +1921,10 @@ impl CubeScanWrapperNode {
fun,
args,
distinct,
within_group,
} => {
let mut sql_args = Vec::new();
let mut sql_within_group = Vec::new();
for arg in args {
if let AggregateFunction::Count = fun {
if !distinct {
Expand All @@ -1944,10 +1946,25 @@ impl CubeScanWrapperNode {
sql_query = query;
sql_args.push(sql);
}
if let Some(within_group) = within_group {
for expr in within_group {
let (sql, query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
expr,
ungrouped_scan_node.clone(),
subqueries.clone(),
)
.await?;
sql_query = query;
sql_within_group.push(sql);
}
}
Ok((
sql_generator
.get_sql_templates()
.aggregate_function(fun, sql_args, distinct)
.aggregate_function(fun, sql_args, distinct, sql_within_group)
.map_err(|e| {
DataFusionError::Internal(format!(
"Can't generate SQL for aggregate function: {}",
Expand Down
32 changes: 32 additions & 0 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18588,4 +18588,36 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),

Ok(())
}

/// Plans a `PERCENTILE_CONT(..) WITHIN GROUP (ORDER BY ..)` query and asserts that
/// the SQL attached to the cube scan wrapper contains a `WITHIN GROUP (ORDER BY`
/// clause, then prints the physical plan.
///
/// The test body is skipped (early return) when `Rewriter::sql_push_down_enabled()`
/// reports that SQL push down is disabled.
#[tokio::test]
async fn test_within_group_push_down() {
    if !Rewriter::sql_push_down_enabled() {
        return;
    }
    init_testing_logger();

    // The raw string is kept exactly as written; it is the query text sent to the planner.
    let query = r#"
SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY taxful_total_price) AS pc
FROM KibanaSampleDataEcommerce
"#
    .to_string();
    let plan = convert_select_to_query_plan(query, DatabaseProtocol::PostgreSQL).await;

    // Walk from the logical plan to the wrapped SQL step by step; a missing
    // wrapped SQL panics on `unwrap`, which fails the test.
    let logical = plan.as_logical_plan();
    let wrapper = logical.find_cube_scan_wrapper();
    let wrapped = wrapper.wrapped_sql.unwrap();
    assert!(wrapped.sql.contains("WITHIN GROUP (ORDER BY"));

    // Building the physical plan must succeed as well; its rendering is printed for debugging.
    let physical = plan.as_physical_plan().await.unwrap();
    println!(
        "Physical plan: {}",
        displayable(physical.as_ref()).indent()
    );
}
}
23 changes: 22 additions & 1 deletion rust/cubesql/cubesql/src/compile/rewrite/converter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ impl LogicalPlanToLanguageConverter {
fun,
args,
distinct,
within_group,
} => {
let fun = add_expr_data_node!(graph, fun, AggregateFunctionExprFun);
let args = add_expr_list_node!(
Expand All @@ -434,8 +435,18 @@ impl LogicalPlanToLanguageConverter {
flat_list
);
let distinct = add_expr_data_node!(graph, distinct, AggregateFunctionExprDistinct);
let within_group = add_expr_list_node!(
graph,
within_group.as_ref().unwrap_or(&vec![]),
query_params,
AggregateFunctionExprWithinGroup,
flat_list
);
graph.add(LogicalPlanLanguage::AggregateFunctionExpr([
fun, args, distinct,
fun,
args,
distinct,
within_group,
]))
}
Expr::WindowFunction {
Expand Down Expand Up @@ -1145,10 +1156,20 @@ pub fn node_to_expr(
let args =
match_expr_list_node!(node_by_id, to_expr, params[1], AggregateFunctionExprArgs);
let distinct = match_data_node!(node_by_id, params[2], AggregateFunctionExprDistinct);
let within_group = match_expr_list_node!(
node_by_id,
to_expr,
params[3],
AggregateFunctionExprWithinGroup
);
Expr::AggregateFunction {
fun,
args,
distinct,
within_group: match within_group.len() {
0 => None,
_ => Some(within_group),
},
}
}
LogicalPlanLanguage::WindowFunctionExpr(params) => {
Expand Down
1 change: 1 addition & 0 deletions rust/cubesql/cubesql/src/compile/rewrite/language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ macro_rules! variant_field_struct {
AggregateFunction::ApproxMedian => "ApproxMedian",
AggregateFunction::BoolAnd => "BoolAnd",
AggregateFunction::BoolOr => "BoolOr",
AggregateFunction::PercentileCont => "PercentileCont",
}
);
};
Expand Down
Loading

0 comments on commit f47d340

Please sign in to comment.