Skip to content
Merged
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion/catalog-listing/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool {
| Expr::Exists(_)
| Expr::InSubquery(_)
| Expr::ScalarSubquery(_)
| Expr::SetComparison(_)
| Expr::GroupingSet(_)
| Expr::Case(_) => Ok(TreeNodeRecursion::Continue),

Expand Down
1 change: 1 addition & 0 deletions datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ insta = { workspace = true }
paste = { workspace = true }
rand = { workspace = true, features = ["small_rng"] }
rand_distr = "0.5"
recursive = { workspace = true }
regex = { workspace = true }
rstest = { workspace = true }
serde_json = { workspace = true }
Expand Down
193 changes: 193 additions & 0 deletions datafusion/core/tests/set_comparison.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::array::{Int32Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::prelude::SessionContext;
use datafusion_common::{Result, assert_batches_eq, assert_contains};

fn build_table(values: &[i32]) -> Result<RecordBatch> {
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
let array =
Arc::new(Int32Array::from(values.to_vec())) as Arc<dyn arrow::array::Array>;
RecordBatch::try_new(schema, vec![array]).map_err(Into::into)
}

#[tokio::test]
async fn set_comparison_any() -> Result<()> {
let ctx = SessionContext::new();

ctx.register_batch("t", build_table(&[1, 6, 10])?)?;
// Include a NULL in the subquery input to ensure we propagate UNKNOWN correctly.
ctx.register_batch("s", {
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
let array = Arc::new(Int32Array::from(vec![Some(5), None]))
as Arc<dyn arrow::array::Array>;
RecordBatch::try_new(schema, vec![array])?
})?;

let df = ctx
.sql("select v from t where v > any(select v from s)")
.await?;
let results = df.collect().await?;

assert_batches_eq!(
&["+----+", "| v |", "+----+", "| 6 |", "| 10 |", "+----+",],
&results
);
Ok(())
}

#[tokio::test]
async fn set_comparison_any_aggregate_subquery() -> Result<()> {
let ctx = SessionContext::new();

ctx.register_batch("t", build_table(&[1, 7])?)?;
ctx.register_batch("s", build_table(&[1, 2, 3])?)?;

let df = ctx
.sql(
"select v from t where v > any(select sum(v) from s group by v % 2) order by v",
)
.await?;
let results = df.collect().await?;

assert_batches_eq!(&["+---+", "| v |", "+---+", "| 7 |", "+---+",], &results);
Ok(())
}

#[tokio::test]
async fn set_comparison_all_empty() -> Result<()> {
let ctx = SessionContext::new();
Comment on lines +77 to +78
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about adding

async fn set_comparison_type_mismatch() -> Result<()> {
    // SELECT v FROM t WHERE v > ANY (SELECT s FROM strings)
    // INT > STRING should error with clear message

...

too?


ctx.register_batch("t", build_table(&[1, 6, 10])?)?;
ctx.register_batch(
"e",
RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new(
"v",
DataType::Int32,
true,
)]))),
)?;

let df = ctx
.sql("select v from t where v < all(select v from e)")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think tests for:

  • Multiple operators (=, !=, >=, <=)
  • NULL semantics (e.g., 5 != ALL (1, NULL)
    would improve test coverage

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @kosiew, I added above cases in 07e23bd

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@waynexia
Thanks for the ping.
I will check again after you fix the clippy errors.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kosiew, CI is all green now, please take another look, thank you!

.await?;
let results = df.collect().await?;

assert_batches_eq!(
&[
"+----+", "| v |", "+----+", "| 1 |", "| 6 |", "| 10 |", "+----+",
],
&results
);
Ok(())
}

#[tokio::test]
async fn set_comparison_type_mismatch() -> Result<()> {
let ctx = SessionContext::new();

ctx.register_batch("t", build_table(&[1])?)?;
ctx.register_batch("strings", {
let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
let array = Arc::new(StringArray::from(vec![Some("a"), Some("b")]))
as Arc<dyn arrow::array::Array>;
RecordBatch::try_new(schema, vec![array])?
})?;

let df = ctx
.sql("select v from t where v > any(select s from strings)")
.await?;
let err = df.collect().await.unwrap_err();
assert_contains!(
err.to_string(),
"expr type Int32 can't cast to Utf8 in SetComparison"
);
Ok(())
}

#[tokio::test]
async fn set_comparison_multiple_operators() -> Result<()> {
let ctx = SessionContext::new();

ctx.register_batch("t", build_table(&[1, 2, 3, 4])?)?;
ctx.register_batch("s", build_table(&[2, 3])?)?;

let df = ctx
.sql("select v from t where v = any(select v from s) order by v")
.await?;
let results = df.collect().await?;
assert_batches_eq!(
&["+---+", "| v |", "+---+", "| 2 |", "| 3 |", "+---+",],
&results
);

let df = ctx
.sql("select v from t where v != all(select v from s) order by v")
.await?;
let results = df.collect().await?;
assert_batches_eq!(
&["+---+", "| v |", "+---+", "| 1 |", "| 4 |", "+---+",],
&results
);

let df = ctx
.sql("select v from t where v >= all(select v from s) order by v")
.await?;
let results = df.collect().await?;
assert_batches_eq!(
&["+---+", "| v |", "+---+", "| 3 |", "| 4 |", "+---+",],
&results
);

let df = ctx
.sql("select v from t where v <= any(select v from s) order by v")
.await?;
let results = df.collect().await?;
assert_batches_eq!(
&[
"+---+", "| v |", "+---+", "| 1 |", "| 2 |", "| 3 |", "+---+",
],
&results
);
Ok(())
}

#[tokio::test]
async fn set_comparison_null_semantics_all() -> Result<()> {
let ctx = SessionContext::new();

ctx.register_batch("t", build_table(&[5])?)?;
ctx.register_batch("s", {
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
let array = Arc::new(Int32Array::from(vec![Some(1), None]))
as Arc<dyn arrow::array::Array>;
RecordBatch::try_new(schema, vec![array])?
})?;

let df = ctx
.sql("select v from t where v != all(select v from s)")
.await?;
let results = df.collect().await?;
let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum();
assert_eq!(0, row_count);
Ok(())
}
4 changes: 4 additions & 0 deletions datafusion/core/tests/sql/unparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ use datafusion_physical_plan::ExecutionPlanProperties;
use datafusion_sql::unparser::Unparser;
use datafusion_sql::unparser::dialect::DefaultDialect;
use itertools::Itertools;
use recursive::{set_minimum_stack_size, set_stack_allocation_size};

/// Paths to benchmark query files (supports running from repo root or different working directories).
const BENCHMARK_PATHS: &[&str] = &["../../benchmarks/", "./benchmarks/"];
Expand Down Expand Up @@ -458,5 +459,8 @@ async fn test_clickbench_unparser_roundtrip() {

#[tokio::test]
async fn test_tpch_unparser_roundtrip() {
// Grow stacker segments earlier to avoid deep unparser recursion overflow in q20.
set_minimum_stack_size(512 * 1024);
set_stack_allocation_size(8 * 1024 * 1024);
Comment on lines +462 to +464
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case got stack overflow after I rebased against the latest main branch, and this is a workaround. I'm thinking of rewriting unparser to avoid deep nested recursion, to some iterative flavor in the follow-up PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's create an issue to track this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

created #19787

run_roundtrip_tests("TPC-H", tpch_queries(), tpch_test_context).await;
}
97 changes: 97 additions & 0 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,8 @@ pub enum Expr {
Exists(Exists),
/// IN subquery
InSubquery(InSubquery),
/// Set comparison subquery (e.g. `= ANY`, `> ALL`)
SetComparison(SetComparison),
/// Scalar subquery
ScalarSubquery(Subquery),
/// Represents a reference to all available fields in a specific schema,
Expand Down Expand Up @@ -1101,6 +1103,54 @@ impl Exists {
}
}

/// Whether the set comparison uses `ANY`/`SOME` or `ALL`
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub enum SetQuantifier {
/// `ANY` (or `SOME`)
Any,
/// `ALL`
All,
}

impl Display for SetQuantifier {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
SetQuantifier::Any => write!(f, "ANY"),
SetQuantifier::All => write!(f, "ALL"),
}
}
}

/// Set comparison subquery (e.g. `= ANY`, `> ALL`)
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct SetComparison {
/// The expression to compare
pub expr: Box<Expr>,
/// Subquery that will produce a single column of data to compare against
pub subquery: Subquery,
/// Comparison operator (e.g. `=`, `>`, `<`)
pub op: Operator,
/// Quantifier (`ANY`/`ALL`)
pub quantifier: SetQuantifier,
}

impl SetComparison {
/// Create a new set comparison expression
pub fn new(
expr: Box<Expr>,
subquery: Subquery,
op: Operator,
quantifier: SetQuantifier,
) -> Self {
Self {
expr,
subquery,
op,
quantifier,
}
}
}

/// InList expression
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct InList {
Expand Down Expand Up @@ -1503,6 +1553,7 @@ impl Expr {
Expr::GroupingSet(..) => "GroupingSet",
Expr::InList { .. } => "InList",
Expr::InSubquery(..) => "InSubquery",
Expr::SetComparison(..) => "SetComparison",
Expr::IsNotNull(..) => "IsNotNull",
Expr::IsNull(..) => "IsNull",
Expr::Like { .. } => "Like",
Expand Down Expand Up @@ -2058,6 +2109,7 @@ impl Expr {
| Expr::GroupingSet(..)
| Expr::InList(..)
| Expr::InSubquery(..)
| Expr::SetComparison(..)
| Expr::IsFalse(..)
| Expr::IsNotFalse(..)
| Expr::IsNotNull(..)
Expand Down Expand Up @@ -2651,6 +2703,16 @@ impl HashNode for Expr {
subquery.hash(state);
negated.hash(state);
}
Expr::SetComparison(SetComparison {
expr: _,
subquery,
op,
quantifier,
}) => {
subquery.hash(state);
op.hash(state);
quantifier.hash(state);
}
Expr::ScalarSubquery(subquery) => {
subquery.hash(state);
}
Expand Down Expand Up @@ -2841,6 +2903,12 @@ impl Display for SchemaDisplay<'_> {
write!(f, "NOT IN")
}
Expr::InSubquery(InSubquery { negated: false, .. }) => write!(f, "IN"),
Expr::SetComparison(SetComparison {
expr,
op,
quantifier,
..
}) => write!(f, "{} {op} {quantifier}", SchemaDisplay(expr.as_ref())),
Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SchemaDisplay(expr)),
Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SchemaDisplay(expr)),
Expr::IsNotTrue(expr) => {
Expand Down Expand Up @@ -3316,6 +3384,12 @@ impl Display for Expr {
subquery,
negated: false,
}) => write!(f, "{expr} IN ({subquery:?})"),
Expr::SetComparison(SetComparison {
expr,
subquery,
op,
quantifier,
}) => write!(f, "{expr} {op} {quantifier} ({subquery:?})"),
Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"),
Expr::BinaryExpr(expr) => write!(f, "{expr}"),
Expr::ScalarFunction(fun) => {
Expand Down Expand Up @@ -3799,6 +3873,7 @@ mod test {
}

use super::*;
use crate::logical_plan::{EmptyRelation, LogicalPlan};

#[test]
fn test_display_wildcard() {
Expand Down Expand Up @@ -3889,6 +3964,28 @@ mod test {
)
}

#[test]
fn test_display_set_comparison() {
let subquery = Subquery {
subquery: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
})),
outer_ref_columns: vec![],
spans: Spans::new(),
};

let expr = Expr::SetComparison(SetComparison::new(
Box::new(Expr::Column(Column::from_name("a"))),
subquery,
Operator::Gt,
SetQuantifier::Any,
));

assert_eq!(format!("{expr}"), "a > ANY (<subquery>)");
assert_eq!(format!("{}", expr.human_display()), "a > ANY (<subquery>)");
}

#[test]
fn test_schema_display_alias_with_relation() {
assert_eq!(
Expand Down
Loading