-
Notifications
You must be signed in to change notification settings - Fork 0
18450: feat: support named variables & defaults for CREATE FUNCTION
#3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,6 +47,7 @@ use datafusion_expr::{ | |
| LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, | ||
| ScalarUDF, ScalarUDFImpl, Signature, Volatility, | ||
| }; | ||
| use datafusion_expr_common::signature::TypeSignature; | ||
| use datafusion_functions_nested::range::range_udf; | ||
| use parking_lot::Mutex; | ||
| use regex::Regex; | ||
|
|
@@ -945,6 +946,7 @@ struct ScalarFunctionWrapper { | |
| expr: Expr, | ||
| signature: Signature, | ||
| return_type: DataType, | ||
| defaults: Vec<Option<Expr>>, | ||
| } | ||
|
|
||
| impl ScalarUDFImpl for ScalarFunctionWrapper { | ||
|
|
@@ -973,15 +975,19 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { | |
| args: Vec<Expr>, | ||
| _info: &dyn SimplifyInfo, | ||
| ) -> Result<ExprSimplifyResult> { | ||
| let replacement = Self::replacement(&self.expr, &args)?; | ||
| let replacement = Self::replacement(&self.expr, &args, &self.defaults)?; | ||
|
|
||
| Ok(ExprSimplifyResult::Simplified(replacement)) | ||
| } | ||
| } | ||
|
|
||
| impl ScalarFunctionWrapper { | ||
| // replaces placeholders with actual arguments | ||
| fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> { | ||
| fn replacement( | ||
| expr: &Expr, | ||
| args: &[Expr], | ||
| defaults: &[Option<Expr>], | ||
| ) -> Result<Expr> { | ||
| let result = expr.clone().transform(|e| { | ||
| let r = match e { | ||
| Expr::Placeholder(placeholder) => { | ||
|
|
@@ -990,10 +996,13 @@ impl ScalarFunctionWrapper { | |
| if placeholder_position < args.len() { | ||
| Transformed::yes(args[placeholder_position].clone()) | ||
| } else { | ||
| exec_err!( | ||
| "Function argument {} not provided, argument missing!", | ||
| placeholder.id | ||
| )? | ||
| match defaults[placeholder_position] { | ||
| Some(ref default) => Transformed::yes(default.clone()), | ||
| None => exec_err!( | ||
| "Function argument {} not provided, argument missing!", | ||
| placeholder.id | ||
| )?, | ||
| } | ||
| } | ||
| } | ||
| _ => Transformed::no(e), | ||
|
|
@@ -1021,6 +1030,32 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper { | |
| type Error = DataFusionError; | ||
|
|
||
| fn try_from(definition: CreateFunction) -> std::result::Result<Self, Self::Error> { | ||
| let args = definition.args.unwrap_or_default(); | ||
| let defaults: Vec<Option<Expr>> = | ||
| args.iter().map(|a| a.default_expr.clone()).collect(); | ||
| let signature: Signature = match defaults.iter().position(|v| v.is_some()) { | ||
| Some(pos) => { | ||
| let mut type_signatures: Vec<TypeSignature> = vec![]; | ||
| // Generate all valid signatures | ||
| for n in pos..defaults.len() + 1 { | ||
| if n == 0 { | ||
| type_signatures.push(TypeSignature::Nullary) | ||
| } else { | ||
| type_signatures.push(TypeSignature::Exact( | ||
| args.iter().take(n).map(|a| a.data_type.clone()).collect(), | ||
| )) | ||
| } | ||
| } | ||
| Signature::one_of( | ||
| type_signatures, | ||
| definition.params.behavior.unwrap_or(Volatility::Volatile), | ||
| ) | ||
| } | ||
| None => Signature::exact( | ||
| args.iter().map(|a| a.data_type.clone()).collect(), | ||
| definition.params.behavior.unwrap_or(Volatility::Volatile), | ||
| ), | ||
| }; | ||
| Ok(Self { | ||
| name: definition.name, | ||
| expr: definition | ||
|
|
@@ -1030,15 +1065,8 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper { | |
| return_type: definition | ||
| .return_type | ||
| .expect("Return type has to be defined!"), | ||
| signature: Signature::exact( | ||
| definition | ||
| .args | ||
| .unwrap_or_default() | ||
| .into_iter() | ||
| .map(|a| a.data_type) | ||
| .collect(), | ||
| definition.params.behavior.unwrap_or(Volatility::Volatile), | ||
| ), | ||
| signature, | ||
| defaults, | ||
| }) | ||
| } | ||
| } | ||
|
|
@@ -1112,6 +1140,127 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { | |
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn create_scalar_function_from_sql_statement_named_arguments() -> Result<()> { | ||
| let function_factory = Arc::new(CustomFunctionFactory::default()); | ||
| let ctx = SessionContext::new().with_function_factory(function_factory.clone()); | ||
|
|
||
| let sql = r#" | ||
| CREATE FUNCTION better_add(a DOUBLE, b DOUBLE) | ||
| RETURNS DOUBLE | ||
| RETURN $a + $b | ||
| "#; | ||
|
|
||
| assert!(ctx.sql(sql).await.is_ok()); | ||
|
|
||
| let result = ctx | ||
| .sql("select better_add(2.0, 2.0)") | ||
| .await? | ||
| .collect() | ||
| .await?; | ||
|
|
||
| assert_batches_eq!( | ||
| &[ | ||
| "+-----------------------------------+", | ||
| "| better_add(Float64(2),Float64(2)) |", | ||
| "+-----------------------------------+", | ||
| "| 4.0 |", | ||
| "+-----------------------------------+", | ||
| ], | ||
| &result | ||
| ); | ||
|
|
||
| // cannot mix named and positional style | ||
| let bad_expression_sql = r#" | ||
| CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE) | ||
| RETURNS DOUBLE | ||
| RETURN $1 + $b | ||
| "#; | ||
| let err = ctx | ||
| .sql(bad_expression_sql) | ||
| .await | ||
| .expect_err("cannot mix named and positional style"); | ||
| let expected = "Error during planning: All function arguments must use either named or positional style."; | ||
| assert!(expected.starts_with(&err.strip_backtrace())); | ||
|
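There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug: Swapped Prefix Check for Error Messages. The arguments to `starts_with` are reversed: `expected.starts_with(&err.strip_backtrace())` tests whether the expected message begins with the actual error text, so the assertion also passes when the actual error is empty or only a prefix of the expected message. The check should be `err.strip_backtrace().starts_with(expected)`. |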
||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn create_scalar_function_from_sql_statement_default_arguments() -> Result<()> { | ||
| let function_factory = Arc::new(CustomFunctionFactory::default()); | ||
| let ctx = SessionContext::new().with_function_factory(function_factory.clone()); | ||
|
|
||
| let sql = r#" | ||
| CREATE FUNCTION better_add(a DOUBLE DEFAULT 2.0, b DOUBLE DEFAULT 2.0) | ||
| RETURNS DOUBLE | ||
| RETURN $a + $b | ||
| "#; | ||
|
|
||
| assert!(ctx.sql(sql).await.is_ok()); | ||
|
|
||
| // Check all function arity supported | ||
| let result = ctx.sql("select better_add()").await?.collect().await?; | ||
|
|
||
| assert_batches_eq!( | ||
| &[ | ||
| "+--------------+", | ||
| "| better_add() |", | ||
| "+--------------+", | ||
| "| 4.0 |", | ||
| "+--------------+", | ||
| ], | ||
| &result | ||
| ); | ||
|
|
||
| let result = ctx.sql("select better_add(2.0)").await?.collect().await?; | ||
|
|
||
| assert_batches_eq!( | ||
| &[ | ||
| "+------------------------+", | ||
| "| better_add(Float64(2)) |", | ||
| "+------------------------+", | ||
| "| 4.0 |", | ||
| "+------------------------+", | ||
| ], | ||
| &result | ||
| ); | ||
|
|
||
| let result = ctx | ||
| .sql("select better_add(2.0, 2.0)") | ||
| .await? | ||
| .collect() | ||
| .await?; | ||
|
|
||
| assert_batches_eq!( | ||
| &[ | ||
| "+-----------------------------------+", | ||
| "| better_add(Float64(2),Float64(2)) |", | ||
| "+-----------------------------------+", | ||
| "| 4.0 |", | ||
| "+-----------------------------------+", | ||
| ], | ||
| &result | ||
| ); | ||
|
|
||
| assert!(ctx.sql("select better_add(2.0, 2.0, 2.0)").await.is_err()); | ||
|
|
||
| // non-default argument cannot follow default argument | ||
| let bad_expression_sql = r#" | ||
| CREATE FUNCTION bad_expression_fun(a DOUBLE DEFAULT 2.0, b DOUBLE) | ||
| RETURNS DOUBLE | ||
| RETURN $a + $b | ||
| "#; | ||
| let err = ctx | ||
| .sql(bad_expression_sql) | ||
| .await | ||
| .expect_err("non-default argument cannot follow default argument"); | ||
| let expected = | ||
| "Error during planning: Non-default arguments cannot follow default arguments."; | ||
| assert!(expected.starts_with(&err.strip_backtrace())); | ||
|
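There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug: Reversed prefixes in `starts_with` assertion. The arguments to `starts_with` are reversed: `expected.starts_with(&err.strip_backtrace())` tests whether the expected message begins with the actual error text, so the assertion also passes when the actual error is empty or only a prefix of the expected message. The check should be `err.strip_backtrace().starts_with(expected)`. |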
||
| Ok(()) | ||
| } | ||
|
|
||
| /// Saves whatever is passed to it as a scalar function | ||
| #[derive(Debug, Default)] | ||
| struct RecordingFunctionFactory { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug: Guarded Access to Parameter Defaults
Potential out-of-bounds array access in
`ScalarFunctionWrapper::replacement()`. The code accesses `defaults[placeholder_position]` without checking if `placeholder_position < defaults.len()`. If a function body references a placeholder (e.g., `$3`) that doesn't correspond to a defined parameter, this will panic with an out-of-bounds error. A bounds check should be added before accessing the `defaults` array.