30 changes: 7 additions & 23 deletions datafusion/functions/src/datetime/now.rs
@@ -23,10 +23,9 @@ use std::sync::Arc;

use datafusion_common::config::ConfigOptions;
use datafusion_common::{Result, ScalarValue, internal_err};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDF, ScalarUDFImpl,
Signature, Volatility,
ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;

@@ -112,26 +111,11 @@ impl ScalarUDFImpl for NowFunc {
internal_err!("return_field_from_args should be called instead")
}

fn invoke_with_args(
&self,
_args: datafusion_expr::ScalarFunctionArgs,
) -> Result<ColumnarValue> {
internal_err!("invoke should not be called on a simplified now() function")
}

fn simplify(
&self,
_args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
let now_ts = info
.execution_props()
.query_execution_start_time
.timestamp_nanos_opt();

Ok(ExprSimplifyResult::Simplified(Expr::Literal(
ScalarValue::TimestampNanosecond(now_ts, self.timezone.clone()),
None,
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let now = chrono::Utc::now();
Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
Some(now.timestamp_nanos_opt().unwrap_or(0)),
self.timezone.clone(),
)))
}

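With the `simplify()` rewrite removed, `now()` is no longer pinned to `query_execution_start_time`; the new `invoke_with_args` reads the wall clock each time it runs. A minimal sketch (not part of the diff) of what that new body computes, assuming the chrono and datafusion-common APIs used above; the helper name `wall_clock_now` is invented for illustration:

```rust
use std::sync::Arc;

use datafusion_common::ScalarValue;

/// Invented helper mirroring the new `invoke_with_args` body above: the wall
/// clock is read at invocation time instead of being fixed to
/// `query_execution_start_time` by the removed `simplify()` rewrite.
fn wall_clock_now(tz: Option<Arc<str>>) -> ScalarValue {
    // `timestamp_nanos_opt()` returns `None` for dates past ~2262; the diff
    // falls back to 0 nanoseconds in that case.
    let nanos = chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0);
    ScalarValue::TimestampNanosecond(Some(nanos), tz)
}

fn main() {
    // Each call reads the clock again, so repeated evaluations can differ.
    println!("{}", wall_clock_now(Some(Arc::from("UTC"))));
    println!("{}", wall_clock_now(None));
}
```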
109 changes: 98 additions & 11 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -40,7 +40,7 @@ use datafusion_expr::{
BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, and,
binary::BinaryTypeCoercer, lit, or,
};
use datafusion_expr::{Cast, TryCast, simplify::ExprSimplifyResult};
use datafusion_expr::{Cast, TryCast};
use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
use datafusion_expr::{
expr::{InList, InSubquery},
@@ -749,10 +749,49 @@ struct Simplifier<'a, S> {
info: &'a S,
}

impl<'a, S> Simplifier<'a, S> {
impl<'a, S: SimplifyInfo> Simplifier<'a, S> {
pub fn new(info: &'a S) -> Self {
Self { info }
}

fn simplify_scalar_function(
&self,
func: Arc<datafusion_expr::ScalarUDF>,
args: Vec<Expr>,
) -> Result<Transformed<Expr>> {
if func.signature().volatility == Volatility::Volatile {
return Ok(Transformed::no(Expr::ScalarFunction(
ScalarFunction::new_udf(func, args),
)));
}

if !args.iter().all(|arg| matches!(arg, Expr::Literal(..))) {
return Ok(Transformed::no(Expr::ScalarFunction(
ScalarFunction::new_udf(func, args),
)));
}

let schema = Schema::new(Vec::<Field>::new());
let df_schema = DFSchema::try_from(schema.clone())?;
let batch = RecordBatch::new_empty(Arc::new(schema));

let expr = Expr::ScalarFunction(ScalarFunction::new_udf(
Arc::clone(&func),
args.clone(),
));

let phys_expr =
create_physical_expr(&expr, &df_schema, self.info.execution_props())?;

let result = phys_expr.evaluate(&batch)?;

match result {
ColumnarValue::Scalar(s) => Ok(Transformed::yes(Expr::Literal(s, None))),
ColumnarValue::Array(_) => Ok(Transformed::no(Expr::ScalarFunction(
ScalarFunction::new_udf(func, args),
))),
}
}
}

impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
@@ -1569,16 +1608,9 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
.not(),
)
}

Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
match udf.simplify(args, info)? {
ExprSimplifyResult::Original(args) => {
Transformed::no(Expr::ScalarFunction(ScalarFunction {
func: udf,
args,
}))
}
ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
}
self.simplify_scalar_function(udf, args)?
}

Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
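The rewritten match arm above delegates to `simplify_scalar_function`, which folds a non-volatile call whose arguments are all literals by planning it with `create_physical_expr` and evaluating it against a zero-row batch. A standalone sketch (not part of the diff) of that evaluation trick; the `lit(2) + lit(3)` expression and the crate paths are assumptions for illustration, and whether a given physical expression comes back as a `Scalar` or an `Array` is exactly the distinction the helper's final `match` makes:

```rust
use std::sync::Arc;

use arrow::datatypes::{Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DFSchema, Result};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{lit, ColumnarValue};
use datafusion_physical_expr::create_physical_expr;

fn main() -> Result<()> {
    // A literal-only expression references no input columns, so a zero-row
    // batch built from an empty schema is enough to evaluate it.
    let expr = lit(2i64) + lit(3i64);

    // Empty schema / empty batch, as in simplify_scalar_function above.
    let schema = Schema::new(Vec::<Field>::new());
    let df_schema = DFSchema::try_from(schema.clone())?;
    let batch = RecordBatch::new_empty(Arc::new(schema));

    // Plan and evaluate the expression once, outside of any query.
    let phys = create_physical_expr(&expr, &df_schema, &ExecutionProps::new())?;
    match phys.evaluate(&batch)? {
        // A scalar result is what the simplifier turns into Expr::Literal.
        ColumnarValue::Scalar(s) => println!("folded to {s}"),
        // An array result would be left as the original expression.
        ColumnarValue::Array(a) => println!("array result with {} rows", a.len()),
    }
    Ok(())
}
```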
@@ -5019,4 +5051,59 @@ mod tests {
else_expr: None,
})
}

#[test]
fn test_simplify_scalar_udf_invoke() {
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ConstantUDF {
signature: Signature,
}

impl ConstantUDF {
fn new() -> Self {
Self {
signature: Signature::exact(
vec![DataType::Int32],
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for ConstantUDF {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn name(&self) -> &str {
"constant_udf"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int32)
}

fn invoke_with_args(
&self,
_args: ScalarFunctionArgs,
) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
}
}

let udf = Arc::new(ScalarUDF::from(ConstantUDF::new()));
let expr = Expr::ScalarFunction(ScalarFunction::new_udf(udf, vec![lit(1)]));

let schema = test_schema();
let props = ExecutionProps::new();
let simplifier =
ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema));

let simplified = simplifier.simplify(expr).unwrap();

assert_eq!(simplified, lit(100));
}
}
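A hypothetical companion test (not in the PR) sketching the other side of the volatility guard: the same simplifier should leave a `Volatility::Volatile` UDF call unfolded. It assumes the same test-module scope and helpers (`test_schema`, `ExecutionProps`, etc.) as `test_simplify_scalar_udf_invoke` above; the `VolatileUdf` name and the test name are invented.

```rust
#[test]
fn test_volatile_udf_not_folded() {
    use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};

    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct VolatileUdf {
        signature: Signature,
    }

    impl ScalarUDFImpl for VolatileUdf {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
        fn name(&self) -> &str {
            "volatile_udf"
        }
        fn signature(&self) -> &Signature {
            &self.signature
        }
        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }
        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            // Stand-in for a truly volatile result such as random().
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
        }
    }

    let udf = Arc::new(ScalarUDF::from(VolatileUdf {
        signature: Signature::exact(vec![DataType::Int32], Volatility::Volatile),
    }));
    let expr = Expr::ScalarFunction(ScalarFunction::new_udf(udf, vec![lit(1)]));

    let schema = test_schema();
    let props = ExecutionProps::new();
    let simplifier =
        ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema));

    // Volatile calls are returned unchanged rather than folded to a literal.
    let simplified = simplifier.simplify(expr.clone()).unwrap();
    assert_eq!(simplified, expr);
}
```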