Skip to content

Commit 3d63dd0

Browse files
committed
Further refactoring of type coercion function code
1 parent: 0db668b · commit: 3d63dd0

File tree

10 files changed

+184
-334
lines changed

10 files changed

+184
-334
lines changed

datafusion/core/tests/user_defined/user_defined_window_functions.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -536,7 +536,7 @@ impl OddCounter {
536536
impl SimpleWindowUDF {
537537
fn new(test_state: Arc<TestState>) -> Self {
538538
let signature =
539-
Signature::exact(vec![DataType::Float64], Volatility::Immutable);
539+
Signature::exact(vec![DataType::Int64], Volatility::Immutable);
540540
Self {
541541
signature,
542542
test_state: test_state.into(),

datafusion/expr/src/expr.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -953,7 +953,7 @@ impl AggregateFunction {
953953
pub enum WindowFunctionDefinition {
954954
/// A user defined aggregate function
955955
AggregateUDF(Arc<AggregateUDF>),
956-
/// A user defined aggregate function
956+
/// A user defined window function
957957
WindowUDF(Arc<WindowUDF>),
958958
}
959959

datafusion/expr/src/expr_schema.rs

Lines changed: 68 additions & 197 deletions
Original file line number | Diff line number | Diff line change
@@ -21,11 +21,11 @@ use crate::expr::{
2121
InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
2222
WindowFunctionParams,
2323
};
24-
use crate::type_coercion::functions::fields_with_udf;
24+
use crate::type_coercion::functions::{UDFCoercionExt, fields_with_udf};
2525
use crate::udf::ReturnFieldArgs;
2626
use crate::{LogicalPlan, Projection, Subquery, WindowFunctionDefinition, utils};
2727
use arrow::compute::can_cast_types;
28-
use arrow::datatypes::{DataType, Field};
28+
use arrow::datatypes::{DataType, Field, FieldRef};
2929
use datafusion_common::datatype::FieldExt;
3030
use datafusion_common::metadata::FieldMetadata;
3131
use datafusion_common::{
@@ -152,43 +152,10 @@ impl ExprSchemable for Expr {
152152
}
153153
}
154154
}
155-
Expr::ScalarFunction(_func) => {
156-
let return_type = self.to_field(schema)?.1.data_type().clone();
157-
Ok(return_type)
158-
}
159-
Expr::WindowFunction(window_function) => self
160-
.data_type_and_nullable_with_window_function(schema, window_function)
161-
.map(|(return_type, _)| return_type),
162-
Expr::AggregateFunction(AggregateFunction {
163-
func,
164-
params: AggregateFunctionParams { args, .. },
165-
}) => {
166-
let fields = args
167-
.iter()
168-
.map(|e| e.to_field(schema).map(|(_, f)| f))
169-
.collect::<Result<Vec<_>>>()?;
170-
let new_fields = fields_with_udf(&fields, func.as_ref())
171-
.map_err(|err| {
172-
let data_types = fields
173-
.iter()
174-
.map(|f| f.data_type().clone())
175-
.collect::<Vec<_>>();
176-
plan_datafusion_err!(
177-
"{} {}",
178-
match err {
179-
DataFusionError::Plan(msg) => msg,
180-
err => err.to_string(),
181-
},
182-
utils::generate_signature_error_msg(
183-
func.name(),
184-
func.signature().clone(),
185-
&data_types
186-
)
187-
)
188-
})?
189-
.into_iter()
190-
.collect::<Vec<_>>();
191-
Ok(func.return_field(&new_fields)?.data_type().clone())
155+
Expr::ScalarFunction(_)
156+
| Expr::WindowFunction(_)
157+
| Expr::AggregateFunction(_) => {
158+
Ok(self.to_field(schema)?.1.data_type().clone())
192159
}
193160
Expr::Not(_)
194161
| Expr::IsNull(_)
@@ -348,21 +315,9 @@ impl ExprSchemable for Expr {
348315
}
349316
}
350317
Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
351-
Expr::ScalarFunction(_func) => {
352-
let field = self.to_field(input_schema)?.1;
353-
354-
let nullable = field.is_nullable();
355-
Ok(nullable)
356-
}
357-
Expr::AggregateFunction(AggregateFunction { func, .. }) => {
358-
Ok(func.is_nullable())
359-
}
360-
Expr::WindowFunction(window_function) => self
361-
.data_type_and_nullable_with_window_function(
362-
input_schema,
363-
window_function,
364-
)
365-
.map(|(_, nullable)| nullable),
318+
Expr::ScalarFunction(_)
319+
| Expr::AggregateFunction(_)
320+
| Expr::WindowFunction(_) => Ok(self.to_field(input_schema)?.1.is_nullable()),
366321
Expr::ScalarVariable(field, _) => Ok(field.is_nullable()),
367322
Expr::TryCast { .. } | Expr::Unnest(_) | Expr::Placeholder(_) => Ok(true),
368323
Expr::IsNull(_)
@@ -534,73 +489,49 @@ impl ExprSchemable for Expr {
534489
)))
535490
}
536491
Expr::WindowFunction(window_function) => {
537-
let (dt, nullable) = self.data_type_and_nullable_with_window_function(
538-
schema,
539-
window_function,
540-
)?;
541-
Ok(Arc::new(Field::new(&schema_name, dt, nullable)))
542-
}
543-
Expr::AggregateFunction(aggregate_function) => {
544-
let AggregateFunction {
545-
func,
546-
params: AggregateFunctionParams { args, .. },
492+
let WindowFunction {
493+
fun,
494+
params: WindowFunctionParams { args, .. },
547495
..
548-
} = aggregate_function;
496+
} = window_function.as_ref();
549497

550498
let fields = args
551499
.iter()
552500
.map(|e| e.to_field(schema).map(|(_, f)| f))
553501
.collect::<Result<Vec<_>>>()?;
554-
// Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
555-
let new_fields = fields_with_udf(&fields, func.as_ref())
556-
.map_err(|err| {
557-
let arg_types = fields
558-
.iter()
559-
.map(|f| f.data_type())
560-
.cloned()
561-
.collect::<Vec<_>>();
562-
plan_datafusion_err!(
563-
"{} {}",
564-
match err {
565-
DataFusionError::Plan(msg) => msg,
566-
err => err.to_string(),
567-
},
568-
utils::generate_signature_error_msg(
569-
func.name(),
570-
func.signature().clone(),
571-
&arg_types,
572-
)
573-
)
574-
})?
575-
.into_iter()
576-
.collect::<Vec<_>>();
577-
502+
match fun {
503+
WindowFunctionDefinition::AggregateUDF(udaf) => {
504+
let new_fields =
505+
verify_function_arguments(udaf.as_ref(), &fields)?;
506+
let return_field = udaf.return_field(&new_fields)?;
507+
Ok(return_field)
508+
}
509+
WindowFunctionDefinition::WindowUDF(udwf) => {
510+
let new_fields =
511+
verify_function_arguments(udwf.as_ref(), &fields)?;
512+
let return_field = udwf
513+
.field(WindowUDFFieldArgs::new(&new_fields, &schema_name))?;
514+
Ok(return_field)
515+
}
516+
}
517+
}
518+
Expr::AggregateFunction(AggregateFunction {
519+
func,
520+
params: AggregateFunctionParams { args, .. },
521+
}) => {
522+
let fields = args
523+
.iter()
524+
.map(|e| e.to_field(schema).map(|(_, f)| f))
525+
.collect::<Result<Vec<_>>>()?;
526+
let new_fields = verify_function_arguments(func.as_ref(), &fields)?;
578527
func.return_field(&new_fields)
579528
}
580529
Expr::ScalarFunction(ScalarFunction { func, args }) => {
581-
let (arg_types, fields): (Vec<DataType>, Vec<Arc<Field>>) = args
530+
let fields = args
582531
.iter()
583532
.map(|e| e.to_field(schema).map(|(_, f)| f))
584-
.collect::<Result<Vec<_>>>()?
585-
.into_iter()
586-
.map(|f| (f.data_type().clone(), f))
587-
.unzip();
588-
// Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
589-
let new_fields =
590-
fields_with_udf(&fields, func.as_ref()).map_err(|err| {
591-
plan_datafusion_err!(
592-
"{} {}",
593-
match err {
594-
DataFusionError::Plan(msg) => msg,
595-
err => err.to_string(),
596-
},
597-
utils::generate_signature_error_msg(
598-
func.name(),
599-
func.signature().clone(),
600-
&arg_types,
601-
)
602-
)
603-
})?;
533+
.collect::<Result<Vec<_>>>()?;
534+
let new_fields = verify_function_arguments(func.as_ref(), &fields)?;
604535

605536
let arguments = args
606537
.iter()
@@ -678,6 +609,33 @@ impl ExprSchemable for Expr {
678609
}
679610
}
680611

612+
/// Verify that function is invoked with correct number and type of arguments as
613+
/// defined in `TypeSignature`.
614+
fn verify_function_arguments<F: UDFCoercionExt>(
615+
function: &F,
616+
input_fields: &[FieldRef],
617+
) -> Result<Vec<FieldRef>> {
618+
fields_with_udf(input_fields, function).map_err(|err| {
619+
let data_types = input_fields
620+
.iter()
621+
.map(|f| f.data_type())
622+
.cloned()
623+
.collect::<Vec<_>>();
624+
plan_datafusion_err!(
625+
"{} {}",
626+
match err {
627+
DataFusionError::Plan(msg) => msg,
628+
err => err.to_string(),
629+
},
630+
utils::generate_signature_error_message(
631+
function.name(),
632+
function.signature(),
633+
&data_types
634+
)
635+
)
636+
})
637+
}
638+
681639
/// Returns the innermost [Expr] that is provably null if `expr` is null.
682640
fn unwrap_certainly_null_expr(expr: &Expr) -> &Expr {
683641
match expr {
@@ -688,93 +646,6 @@ fn unwrap_certainly_null_expr(expr: &Expr) -> &Expr {
688646
}
689647
}
690648

691-
impl Expr {
692-
/// Common method for window functions that applies type coercion
693-
/// to all arguments of the window function to check if it matches
694-
/// its signature.
695-
///
696-
/// If successful, this method returns the data type and
697-
/// nullability of the window function's result.
698-
///
699-
/// Otherwise, returns an error if there's a type mismatch between
700-
/// the window function's signature and the provided arguments.
701-
fn data_type_and_nullable_with_window_function(
702-
&self,
703-
schema: &dyn ExprSchema,
704-
window_function: &WindowFunction,
705-
) -> Result<(DataType, bool)> {
706-
let WindowFunction {
707-
fun,
708-
params: WindowFunctionParams { args, .. },
709-
..
710-
} = window_function;
711-
712-
let fields = args
713-
.iter()
714-
.map(|e| e.to_field(schema).map(|(_, f)| f))
715-
.collect::<Result<Vec<_>>>()?;
716-
match fun {
717-
WindowFunctionDefinition::AggregateUDF(udaf) => {
718-
let data_types = fields
719-
.iter()
720-
.map(|f| f.data_type())
721-
.cloned()
722-
.collect::<Vec<_>>();
723-
let new_fields = fields_with_udf(&fields, udaf.as_ref())
724-
.map_err(|err| {
725-
plan_datafusion_err!(
726-
"{} {}",
727-
match err {
728-
DataFusionError::Plan(msg) => msg,
729-
err => err.to_string(),
730-
},
731-
utils::generate_signature_error_msg(
732-
fun.name(),
733-
fun.signature(),
734-
&data_types
735-
)
736-
)
737-
})?
738-
.into_iter()
739-
.collect::<Vec<_>>();
740-
741-
let return_field = udaf.return_field(&new_fields)?;
742-
743-
Ok((return_field.data_type().clone(), return_field.is_nullable()))
744-
}
745-
WindowFunctionDefinition::WindowUDF(udwf) => {
746-
let data_types = fields
747-
.iter()
748-
.map(|f| f.data_type())
749-
.cloned()
750-
.collect::<Vec<_>>();
751-
let new_fields = fields_with_udf(&fields, udwf.as_ref())
752-
.map_err(|err| {
753-
plan_datafusion_err!(
754-
"{} {}",
755-
match err {
756-
DataFusionError::Plan(msg) => msg,
757-
err => err.to_string(),
758-
},
759-
utils::generate_signature_error_msg(
760-
fun.name(),
761-
fun.signature(),
762-
&data_types
763-
)
764-
)
765-
})?
766-
.into_iter()
767-
.collect::<Vec<_>>();
768-
let (_, function_name) = self.qualified_name();
769-
let field_args = WindowUDFFieldArgs::new(&new_fields, &function_name);
770-
771-
udwf.field(field_args)
772-
.map(|field| (field.data_type().clone(), field.is_nullable()))
773-
}
774-
}
775-
}
776-
}
777-
778649
/// Cast subquery in InSubquery/ScalarSubquery to a given type.
779650
///
780651
/// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific

0 commit comments

Comments
 (0)