diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs
index b804efe59106d..74856130b4506 100644
--- a/datafusion/functions/src/datetime/now.rs
+++ b/datafusion/functions/src/datetime/now.rs
@@ -23,10 +23,10 @@ use std::sync::Arc;
 
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::{Result, ScalarValue, internal_err};
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
 use datafusion_expr::{
-    ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDF, ScalarUDFImpl,
-    Signature, Volatility,
+    ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
+    ScalarUDFImpl, Signature, Volatility,
 };
 use datafusion_macros::user_doc;
 
@@ -112,12 +112,16 @@ impl ScalarUDFImpl for NowFunc {
         internal_err!("return_field_from_args should be called instead")
     }
 
-    fn invoke_with_args(
-        &self,
-        _args: datafusion_expr::ScalarFunctionArgs,
-    ) -> Result<ColumnarValue> {
-        internal_err!("invoke should not be called on a simplified now() function")
-    }
+    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        // `simplify` replaces `now()` with the query start time during planning;
+        // this wall-clock fallback only serves callers that skip simplification.
+        // An out-of-range instant is passed through as NULL, never as epoch 0.
+        let now_ts = chrono::Utc::now().timestamp_nanos_opt();
+        Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
+            now_ts,
+            self.timezone.clone(),
+        )))
+    }
 
     fn simplify(
         &self,
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 01de44cee1f60..e47b36ff6fb9f 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -749,10 +749,49 @@ struct Simplifier<'a, S> {
     info: &'a S,
 }
 
-impl<'a, S> Simplifier<'a, S> {
+impl<'a, S: SimplifyInfo> Simplifier<'a, S> {
     pub fn new(info: &'a S) -> Self {
         Self { info }
     }
+
+    /// Simplifies a scalar function call.
+    ///
+    /// The UDF's own `simplify` hook runs first so function-specific rewrites
+    /// keep working. If it leaves the call unchanged, the function is not
+    /// volatile, and every argument is a literal, the call is evaluated once
+    /// against an empty batch and replaced by the resulting scalar.
+    fn simplify_scalar_function(
+        &self,
+        func: Arc<ScalarUDF>,
+        args: Vec<Expr>,
+    ) -> Result<Transformed<Expr>> {
+        let args = match func.simplify(args, self.info)? {
+            ExprSimplifyResult::Simplified(expr) => return Ok(Transformed::yes(expr)),
+            ExprSimplifyResult::Original(args) => args,
+        };
+
+        let foldable = func.signature().volatility != Volatility::Volatile
+            && args.iter().all(|arg| matches!(arg, Expr::Literal(..)));
+        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(func, args));
+        if !foldable {
+            return Ok(Transformed::no(expr));
+        }
+
+        let schema = Schema::new(Vec::<Field>::new());
+        let df_schema = DFSchema::try_from(schema.clone())?;
+        let batch = RecordBatch::new_empty(Arc::new(schema));
+
+        // A failed evaluation must not fail planning: keep the call so any
+        // error is reported at execution time, if the call is reached at all.
+        let evaluated =
+            create_physical_expr(&expr, &df_schema, self.info.execution_props())
+                .and_then(|phys_expr| phys_expr.evaluate(&batch));
+
+        match evaluated {
+            Ok(ColumnarValue::Scalar(s)) => Ok(Transformed::yes(Expr::Literal(s, None))),
+            Ok(ColumnarValue::Array(_)) | Err(_) => Ok(Transformed::no(expr)),
+        }
+    }
 }
 
 impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
@@ -1569,16 +1608,9 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
                         .not(),
                 )
             }
+
             Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
-                match udf.simplify(args, info)? {
-                    ExprSimplifyResult::Original(args) => {
-                        Transformed::no(Expr::ScalarFunction(ScalarFunction {
-                            func: udf,
-                            args,
-                        }))
-                    }
-                    ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
-                }
+                self.simplify_scalar_function(udf, args)?
             }
 
             Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
@@ -5019,4 +5051,59 @@ mod tests {
             else_expr: None,
         })
     }
+
+    #[test]
+    fn test_simplify_scalar_udf_invoke() {
+        use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
+
+        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
+        struct ConstantUDF {
+            signature: Signature,
+        }
+
+        impl ConstantUDF {
+            fn new() -> Self {
+                Self {
+                    signature: Signature::exact(
+                        vec![DataType::Int32],
+                        Volatility::Immutable,
+                    ),
+                }
+            }
+        }
+
+        impl ScalarUDFImpl for ConstantUDF {
+            fn as_any(&self) -> &dyn std::any::Any {
+                self
+            }
+            fn name(&self) -> &str {
+                "constant_udf"
+            }
+            fn signature(&self) -> &Signature {
+                &self.signature
+            }
+            fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+                Ok(DataType::Int32)
+            }
+
+            fn invoke_with_args(
+                &self,
+                _args: ScalarFunctionArgs,
+            ) -> Result<ColumnarValue> {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
+            }
+        }
+
+        let udf = Arc::new(ScalarUDF::from(ConstantUDF::new()));
+        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(udf, vec![lit(1)]));
+
+        let schema = test_schema();
+        let props = ExecutionProps::new();
+        let simplifier =
+            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema));
+
+        let simplified = simplifier.simplify(expr).unwrap();
+
+        assert_eq!(simplified, lit(100));
+    }
 }