Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 164 additions & 15 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ use datafusion_expr::{
LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_expr_common::signature::TypeSignature;
use datafusion_functions_nested::range::range_udf;
use parking_lot::Mutex;
use regex::Regex;
Expand Down Expand Up @@ -945,6 +946,7 @@ struct ScalarFunctionWrapper {
expr: Expr,
signature: Signature,
return_type: DataType,
defaults: Vec<Option<Expr>>,
}

impl ScalarUDFImpl for ScalarFunctionWrapper {
Expand Down Expand Up @@ -973,15 +975,19 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
args: Vec<Expr>,
_info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
let replacement = Self::replacement(&self.expr, &args)?;
let replacement = Self::replacement(&self.expr, &args, &self.defaults)?;

Ok(ExprSimplifyResult::Simplified(replacement))
}
}

impl ScalarFunctionWrapper {
// replaces placeholders with actual arguments
fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> {
fn replacement(
expr: &Expr,
args: &[Expr],
defaults: &[Option<Expr>],
) -> Result<Expr> {
let result = expr.clone().transform(|e| {
let r = match e {
Expr::Placeholder(placeholder) => {
Expand All @@ -990,10 +996,13 @@ impl ScalarFunctionWrapper {
if placeholder_position < args.len() {
Transformed::yes(args[placeholder_position].clone())
} else {
exec_err!(
"Function argument {} not provided, argument missing!",
placeholder.id
)?
match defaults[placeholder_position] {
Some(ref default) => Transformed::yes(default.clone()),
None => exec_err!(
"Function argument {} not provided, argument missing!",
placeholder.id
)?,
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Unguarded Access to Parameter Defaults

Potential out-of-bounds array access in ScalarFunctionWrapper::replacement(). The code accesses defaults[placeholder_position] without checking if placeholder_position < defaults.len(). If a function body references a placeholder (e.g., $3) that doesn't correspond to a defined parameter, this will panic with an out-of-bounds error. A bounds check should be added before accessing the defaults array.

Fix in Cursor Fix in Web

}
}
_ => Transformed::no(e),
Expand Down Expand Up @@ -1021,6 +1030,32 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
type Error = DataFusionError;

fn try_from(definition: CreateFunction) -> std::result::Result<Self, Self::Error> {
let args = definition.args.unwrap_or_default();
let defaults: Vec<Option<Expr>> =
args.iter().map(|a| a.default_expr.clone()).collect();
let signature: Signature = match defaults.iter().position(|v| v.is_some()) {
Some(pos) => {
let mut type_signatures: Vec<TypeSignature> = vec![];
// Generate all valid signatures
for n in pos..defaults.len() + 1 {
if n == 0 {
type_signatures.push(TypeSignature::Nullary)
} else {
type_signatures.push(TypeSignature::Exact(
args.iter().take(n).map(|a| a.data_type.clone()).collect(),
))
}
}
Signature::one_of(
type_signatures,
definition.params.behavior.unwrap_or(Volatility::Volatile),
)
}
None => Signature::exact(
args.iter().map(|a| a.data_type.clone()).collect(),
definition.params.behavior.unwrap_or(Volatility::Volatile),
),
};
Ok(Self {
name: definition.name,
expr: definition
Expand All @@ -1030,15 +1065,8 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
return_type: definition
.return_type
.expect("Return type has to be defined!"),
signature: Signature::exact(
definition
.args
.unwrap_or_default()
.into_iter()
.map(|a| a.data_type)
.collect(),
definition.params.behavior.unwrap_or(Volatility::Volatile),
),
signature,
defaults,
})
}
}
Expand Down Expand Up @@ -1112,6 +1140,127 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> {
Ok(())
}

/// Checks that a function created with `CREATE FUNCTION` can reference its
/// parameters by name (`$a`, `$b`) in the body, and that declaring a mix of
/// named and unnamed parameters is rejected at planning time.
#[tokio::test]
async fn create_scalar_function_from_sql_statement_named_arguments() -> Result<()> {
    let function_factory = Arc::new(CustomFunctionFactory::default());
    let ctx = SessionContext::new().with_function_factory(function_factory.clone());

    let sql = r#"
    CREATE FUNCTION better_add(a DOUBLE, b DOUBLE)
        RETURNS DOUBLE
        RETURN $a + $b
    "#;

    assert!(ctx.sql(sql).await.is_ok());

    let result = ctx
        .sql("select better_add(2.0, 2.0)")
        .await?
        .collect()
        .await?;

    // Value rows are padded to the header's column width; the comparison is
    // line-by-line and exact.
    assert_batches_eq!(
        &[
            "+-----------------------------------+",
            "| better_add(Float64(2),Float64(2)) |",
            "+-----------------------------------+",
            "| 4.0                               |",
            "+-----------------------------------+",
        ],
        &result
    );

    // cannot mix named and positional style
    let bad_expression_sql = r#"
    CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE)
        RETURNS DOUBLE
        RETURN $1 + $b
    "#;
    let err = ctx
        .sql(bad_expression_sql)
        .await
        .expect_err("cannot mix named and positional style");
    let expected = "Error during planning: All function arguments must use either named or positional style.";
    // The *actual* message must begin with the expected text. Checking
    // `expected.starts_with(actual)` instead would also accept any actual
    // message that is merely a prefix of `expected` (including an empty one).
    assert!(err.strip_backtrace().starts_with(expected));

    Ok(())
}

/// Checks that a function created with `CREATE FUNCTION` whose parameters all
/// carry `DEFAULT` values can be called with zero, one, or two arguments, that
/// passing too many arguments fails, and that a non-default parameter declared
/// after a default one is rejected at planning time.
#[tokio::test]
async fn create_scalar_function_from_sql_statement_default_arguments() -> Result<()> {
    let function_factory = Arc::new(CustomFunctionFactory::default());
    let ctx = SessionContext::new().with_function_factory(function_factory.clone());

    let sql = r#"
    CREATE FUNCTION better_add(a DOUBLE DEFAULT 2.0, b DOUBLE DEFAULT 2.0)
        RETURNS DOUBLE
        RETURN $a + $b
    "#;

    assert!(ctx.sql(sql).await.is_ok());

    // Check all function arity supported
    let result = ctx.sql("select better_add()").await?.collect().await?;

    // Value rows are padded to the header's column width; the comparison is
    // line-by-line and exact.
    assert_batches_eq!(
        &[
            "+--------------+",
            "| better_add() |",
            "+--------------+",
            "| 4.0          |",
            "+--------------+",
        ],
        &result
    );

    let result = ctx.sql("select better_add(2.0)").await?.collect().await?;

    assert_batches_eq!(
        &[
            "+------------------------+",
            "| better_add(Float64(2)) |",
            "+------------------------+",
            "| 4.0                    |",
            "+------------------------+",
        ],
        &result
    );

    let result = ctx
        .sql("select better_add(2.0, 2.0)")
        .await?
        .collect()
        .await?;

    assert_batches_eq!(
        &[
            "+-----------------------------------+",
            "| better_add(Float64(2),Float64(2)) |",
            "+-----------------------------------+",
            "| 4.0                               |",
            "+-----------------------------------+",
        ],
        &result
    );

    // One argument more than the function declares must be rejected.
    assert!(ctx.sql("select better_add(2.0, 2.0, 2.0)").await.is_err());

    // non-default argument cannot follow default argument
    let bad_expression_sql = r#"
    CREATE FUNCTION bad_expression_fun(a DOUBLE DEFAULT 2.0, b DOUBLE)
        RETURNS DOUBLE
        RETURN $a + $b
    "#;
    let err = ctx
        .sql(bad_expression_sql)
        .await
        .expect_err("non-default argument cannot follow default argument");
    let expected =
        "Error during planning: Non-default arguments cannot follow default arguments.";
    // The *actual* message must begin with the expected text. Checking
    // `expected.starts_with(actual)` instead would also accept any actual
    // message that is merely a prefix of `expected` (including an empty one).
    assert!(err.strip_backtrace().starts_with(expected));

    Ok(())
}

/// Saves whatever is passed to it as a scalar function
#[derive(Debug, Default)]
struct RecordingFunctionFactory {
Expand Down
21 changes: 16 additions & 5 deletions datafusion/sql/src/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}

/// Create a placeholder expression
/// This is the same as Postgres's prepare statement syntax in which a placeholder starts with `$` sign and then
/// number 1, 2, ... etc. For example, `$1` is the first placeholder; $2 is the second one and so on.
/// Both named (`$foo`) and positional (`$1`, `$2`, ...) placeholder styles are supported.
fn create_placeholder_expr(
param: String,
param_data_types: &[FieldRef],
) -> Result<Expr> {
// Parse the placeholder as a number because it is the only support from sqlparser and postgres
// Try to parse the placeholder as a number. If the placeholder does not have a valid
// positional value, assume we have a named placeholder.
let index = param[1..].parse::<usize>();
let idx = match index {
Ok(0) => {
Expand All @@ -123,8 +123,19 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
return if param_data_types.is_empty() {
Ok(Expr::Placeholder(Placeholder::new_with_field(param, None)))
} else {
// when PREPARE Statement, param_data_types length is always 0
plan_err!("Invalid placeholder, not a number: {param}")
// FIXME: This branch is shared by params from PREPARE and CREATE FUNCTION, but
// only CREATE FUNCTION currently supports named params. For now, we rewrite
// these to positional params.
let named_param_pos = param_data_types
.iter()
.position(|v| v.name() == &param[1..]);
match named_param_pos {
Some(pos) => Ok(Expr::Placeholder(Placeholder::new_with_field(
format!("${}", pos + 1),
param_data_types.get(pos).cloned(),
))),
None => plan_err!("Unknown placeholder: {param}"),
}
};
}
};
Expand Down
38 changes: 37 additions & 1 deletion datafusion/sql/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,28 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}
None => None,
};
// Validate default arguments
let first_default = match args.as_ref() {
Some(arg) => arg.iter().position(|t| t.default_expr.is_some()),
None => None,
};
let last_non_default = match args.as_ref() {
Some(arg) => arg
.iter()
.rev()
.position(|t| t.default_expr.is_none())
.map(|reverse_pos| arg.len() - reverse_pos - 1),
None => None,
};
if let (Some(pos_default), Some(pos_non_default)) =
(first_default, last_non_default)
{
if pos_non_default > pos_default {
return plan_err!(
"Non-default arguments cannot follow default arguments."
);
}
}
// At the moment functions can't be qualified `schema.name`
let name = match &name.0[..] {
[] => exec_err!("Function should have name")?,
Expand All @@ -1233,9 +1255,23 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
//
let arg_types = args.as_ref().map(|arg| {
arg.iter()
.map(|t| Arc::new(Field::new("", t.data_type.clone(), true)))
.map(|t| {
let name = match t.name.clone() {
Some(name) => name.value,
None => "".to_string(),
};
Arc::new(Field::new(name, t.data_type.clone(), true))
})
.collect::<Vec<_>>()
});
// Validate parameter style
if let Some(ref fields) = arg_types {
let count_positional =
fields.iter().filter(|f| f.name() == "").count();
if !(count_positional == 0 || count_positional == fields.len()) {
return plan_err!("All function arguments must use either named or positional style.");
}
}
let mut planner_context = PlannerContext::new()
.with_prepare_param_data_types(arg_types.unwrap_or_default());

Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/tests/cases/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ fn test_prepare_statement_to_plan_panic_param_format() {
assert_snapshot!(
logical_plan(sql).unwrap_err().strip_backtrace(),
@r###"
Error during planning: Invalid placeholder, not a number: $foo
Error during planning: Unknown placeholder: $foo
"###
);
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/prepare.slt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ statement error DataFusion error: SQL error: ParserError
PREPARE AS SELECT id, age FROM person WHERE age = $foo;

# param following a non-number, $foo, not supported
statement error Invalid placeholder, not a number: \$foo
statement error Unknown placeholder: \$foo
PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo;

# not specify table hence cannot specify columns
Expand Down
Loading