Skip to content

Commit

Permalink
Normalize date_trunc and date_part tokens during initial AST rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
mcheshkov committed Aug 26, 2024
1 parent e88b8e9 commit 2b988b4
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 3 deletions.
83 changes: 83 additions & 0 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6931,6 +6931,89 @@ ORDER BY
Ok(())
}

// Verifies that an incoming query using 'qtr' (or another synonym) in a position
// that no egraph rewrite can reach is still executable
// TODO implement and test more complex queries, like dynamic granularity
#[tokio::test]
async fn test_nonrewritable_date_trunc() {
    // Nothing to check unless SQL push down is enabled
    if !Rewriter::sql_push_down_enabled() {
        return;
    }
    init_testing_logger();

    let context = TestContext::new(DatabaseProtocol::PostgreSQL).await;

    // The inner CTE groups by month; the outer query regroups those months
    // using the 'qtr' alias.
    // language=PostgreSQL
    let query = r#"
WITH count_by_month AS (
SELECT
DATE_TRUNC('month', dim_date0) month0,
COUNT(*) month_count
FROM MultiTypeCube
GROUP BY month0
)
SELECT
DATE_TRUNC('qtr', count_by_month.month0) quarter0,
MIN(month_count) min_month_count
FROM count_by_month
GROUP BY quarter0
"#;

    // The load request the test expects to find in the plan: a count measure
    // with a single month-granularity time dimension.
    let month_time_dimension = V1LoadRequestQueryTimeDimension {
        dimension: "MultiTypeCube.dim_date0".to_owned(),
        granularity: Some("month".to_string()),
        date_range: None,
    };
    let expected_cube_scan = V1LoadRequestQuery {
        measures: Some(vec!["MultiTypeCube.count".to_string()]),
        segments: Some(vec![]),
        dimensions: Some(vec![]),
        time_dimensions: Some(vec![month_time_dimension]),
        order: None,
        limit: None,
        offset: None,
        filters: None,
        ungrouped: None,
    };

    // Mocked monthly rows as (month, count) pairs: three months of Q1 and one month of Q2
    let mocked_rows = [
        ("2024-01-01T00:00:00", "3"),
        ("2024-02-01T00:00:00", "2"),
        ("2024-03-01T00:00:00", "1"),
        ("2024-04-01T00:00:00", "10"),
    ]
    .iter()
    .map(|&(month, count)| {
        json!({
            "MultiTypeCube.dim_date0.month": month,
            "MultiTypeCube.count": count,
        })
    })
    .collect::<Vec<_>>();

    context
        .add_cube_load_mock(
            expected_cube_scan.clone(),
            simple_load_response(mocked_rows),
        )
        .await;

    // The planned cube scan must carry exactly the expected load request
    let query_plan = context.convert_sql_to_cube_query(&query).await.unwrap();
    assert_eq!(
        query_plan.as_logical_plan().find_cube_scan().request,
        expected_cube_scan
    );

    // Expect that query is executable, and properly groups months by quarter
    insta::assert_snapshot!(context.execute_query(query).await.unwrap());
}

#[tokio::test]
async fn test_metabase_dow() -> Result<(), CubeError> {
let query_plan = convert_select_to_query_plan(
Expand Down
6 changes: 4 additions & 2 deletions rust/cubesql/cubesql/src/compile/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ use crate::{
sql::{
dataframe,
statement::{
ApproximateCountDistinctVisitor, CastReplacer, RedshiftDatePartReplacer,
SensitiveDataSanitizer, ToTimestampReplacer, UdfWildcardArgReplacer,
ApproximateCountDistinctVisitor, CastReplacer, DateTokenNormalizeReplacer,
RedshiftDatePartReplacer, SensitiveDataSanitizer, ToTimestampReplacer,
UdfWildcardArgReplacer,
},
ColumnFlags, ColumnType, Session, SessionManager, SessionState,
},
Expand Down Expand Up @@ -715,6 +716,7 @@ pub fn rewrite_statement(stmt: &ast::Statement) -> ast::Statement {
let stmt = CastReplacer::new().replace(stmt);
let stmt = ToTimestampReplacer::new().replace(&stmt);
let stmt = UdfWildcardArgReplacer::new().replace(&stmt);
let stmt = DateTokenNormalizeReplacer::new().replace(&stmt);
let stmt = RedshiftDatePartReplacer::new().replace(&stmt);
let stmt = ApproximateCountDistinctVisitor::new().replace(&stmt);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
source: cubesql/src/compile/mod.rs
expression: context.execute_query(query).await.unwrap()
---
+-------------------------+-----------------+
| quarter0 | min_month_count |
+-------------------------+-----------------+
| 2024-01-01T00:00:00.000 | 1 |
| 2024-04-01T00:00:00.000 | 10 |
+-------------------------+-----------------+
66 changes: 65 additions & 1 deletion rust/cubesql/cubesql/src/sql/statement.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::sql::shim::ConnectionError;
use crate::{compile::rewrite::rules::utils::DatePartToken, sql::shim::ConnectionError};
use itertools::Itertools;
use log::trace;
use pg_srv::{
Expand Down Expand Up @@ -802,6 +802,70 @@ impl<'ast> Visitor<'ast, ConnectionError> for CastReplacer {
}
}

// Statement-level rewriter that normalizes the date part token passed as a string
// literal to `date_trunc` / `date_part` (for example an alias such as 'qtr') into the
// spelling produced by `DatePartToken::as_str`.
//
// This approach is limited to literals-in-query, but it's better than nothing.
// It would be simpler to do in rewrite rules, by relying on constant folding,
// but that would require cumbersome top-down extraction.
// TODO remove this if/when DF starts supporting all of PostgreSQL aliases
#[derive(Debug)]
pub struct DateTokenNormalizeReplacer {}

impl DateTokenNormalizeReplacer {
pub fn new() -> Self {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();

self.visit_statement(&mut result).unwrap();

result
}
}

impl<'ast> Visitor<'ast, ConnectionError> for DateTokenNormalizeReplacer {
    // TODO support EXTRACT normalization after support in sqlparser
    /// Visits a function call and, for two-argument `date_trunc` / `date_part` calls
    /// whose first argument is a single-quoted string literal, replaces that literal
    /// with the spelling returned by `DatePartToken::as_str`.
    ///
    /// Literals that do not parse as a `DatePartToken`, non-literal first arguments and
    /// all other functions are left untouched. The function name identifiers, the
    /// arguments and any window `PARTITION BY` / `ORDER BY` expressions are still
    /// visited, so nested calls are handled as well.
    fn visit_function(&mut self, fun: &mut Function) -> Result<(), ConnectionError> {
        for ident in fun.name.0.iter_mut() {
            self.visit_identifier(ident)?;
        }

        // The function name is matched case-insensitively
        let lowered_name = fun.name.to_string().to_lowercase();
        if matches!(lowered_name.as_str(), "date_trunc" | "date_part") {
            // The slice pattern both requires exactly two arguments and picks out the
            // string literal in the first position
            if let [FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
                Value::SingleQuotedString(token),
            ))), _] = fun.args.as_mut_slice()
            {
                if let Ok(parsed) = token.parse::<DatePartToken>() {
                    *token = parsed.as_str().to_string();
                }
            }
        }

        self.visit_function_args(&mut fun.args)?;

        if let Some(window) = fun.over.as_mut() {
            for partition_expr in window.partition_by.iter_mut() {
                self.visit_expr(partition_expr)?;
            }
            for ordering in window.order_by.iter_mut() {
                self.visit_expr(&mut ordering.expr)?;
            }
        }

        Ok(())
    }
}

#[derive(Debug)]
pub struct RedshiftDatePartReplacer {}

Expand Down

0 comments on commit 2b988b4

Please sign in to comment.