@@ -21,11 +21,11 @@ use crate::expr::{
2121 InSubquery , Placeholder , ScalarFunction , TryCast , Unnest , WindowFunction ,
2222 WindowFunctionParams ,
2323} ;
24- use crate :: type_coercion:: functions:: fields_with_udf;
24+ use crate :: type_coercion:: functions:: { UDFCoercionExt , fields_with_udf} ;
2525use crate :: udf:: ReturnFieldArgs ;
2626use crate :: { LogicalPlan , Projection , Subquery , WindowFunctionDefinition , utils} ;
2727use arrow:: compute:: can_cast_types;
28- use arrow:: datatypes:: { DataType , Field } ;
28+ use arrow:: datatypes:: { DataType , Field , FieldRef } ;
2929use datafusion_common:: datatype:: FieldExt ;
3030use datafusion_common:: metadata:: FieldMetadata ;
3131use datafusion_common:: {
@@ -152,43 +152,10 @@ impl ExprSchemable for Expr {
152152 }
153153 }
154154 }
155- Expr :: ScalarFunction ( _func) => {
156- let return_type = self . to_field ( schema) ?. 1 . data_type ( ) . clone ( ) ;
157- Ok ( return_type)
158- }
159- Expr :: WindowFunction ( window_function) => self
160- . data_type_and_nullable_with_window_function ( schema, window_function)
161- . map ( |( return_type, _) | return_type) ,
162- Expr :: AggregateFunction ( AggregateFunction {
163- func,
164- params : AggregateFunctionParams { args, .. } ,
165- } ) => {
166- let fields = args
167- . iter ( )
168- . map ( |e| e. to_field ( schema) . map ( |( _, f) | f) )
169- . collect :: < Result < Vec < _ > > > ( ) ?;
170- let new_fields = fields_with_udf ( & fields, func. as_ref ( ) )
171- . map_err ( |err| {
172- let data_types = fields
173- . iter ( )
174- . map ( |f| f. data_type ( ) . clone ( ) )
175- . collect :: < Vec < _ > > ( ) ;
176- plan_datafusion_err ! (
177- "{} {}" ,
178- match err {
179- DataFusionError :: Plan ( msg) => msg,
180- err => err. to_string( ) ,
181- } ,
182- utils:: generate_signature_error_msg(
183- func. name( ) ,
184- func. signature( ) . clone( ) ,
185- & data_types
186- )
187- )
188- } ) ?
189- . into_iter ( )
190- . collect :: < Vec < _ > > ( ) ;
191- Ok ( func. return_field ( & new_fields) ?. data_type ( ) . clone ( ) )
155+ Expr :: ScalarFunction ( _)
156+ | Expr :: WindowFunction ( _)
157+ | Expr :: AggregateFunction ( _) => {
158+ Ok ( self . to_field ( schema) ?. 1 . data_type ( ) . clone ( ) )
192159 }
193160 Expr :: Not ( _)
194161 | Expr :: IsNull ( _)
@@ -348,21 +315,9 @@ impl ExprSchemable for Expr {
348315 }
349316 }
350317 Expr :: Cast ( Cast { expr, .. } ) => expr. nullable ( input_schema) ,
351- Expr :: ScalarFunction ( _func) => {
352- let field = self . to_field ( input_schema) ?. 1 ;
353-
354- let nullable = field. is_nullable ( ) ;
355- Ok ( nullable)
356- }
357- Expr :: AggregateFunction ( AggregateFunction { func, .. } ) => {
358- Ok ( func. is_nullable ( ) )
359- }
360- Expr :: WindowFunction ( window_function) => self
361- . data_type_and_nullable_with_window_function (
362- input_schema,
363- window_function,
364- )
365- . map ( |( _, nullable) | nullable) ,
318+ Expr :: ScalarFunction ( _)
319+ | Expr :: AggregateFunction ( _)
320+ | Expr :: WindowFunction ( _) => Ok ( self . to_field ( input_schema) ?. 1 . is_nullable ( ) ) ,
366321 Expr :: ScalarVariable ( field, _) => Ok ( field. is_nullable ( ) ) ,
367322 Expr :: TryCast { .. } | Expr :: Unnest ( _) | Expr :: Placeholder ( _) => Ok ( true ) ,
368323 Expr :: IsNull ( _)
@@ -534,73 +489,49 @@ impl ExprSchemable for Expr {
534489 ) ) )
535490 }
536491 Expr :: WindowFunction ( window_function) => {
537- let ( dt, nullable) = self . data_type_and_nullable_with_window_function (
538- schema,
539- window_function,
540- ) ?;
541- Ok ( Arc :: new ( Field :: new ( & schema_name, dt, nullable) ) )
542- }
543- Expr :: AggregateFunction ( aggregate_function) => {
544- let AggregateFunction {
545- func,
546- params : AggregateFunctionParams { args, .. } ,
492+ let WindowFunction {
493+ fun,
494+ params : WindowFunctionParams { args, .. } ,
547495 ..
548- } = aggregate_function ;
496+ } = window_function . as_ref ( ) ;
549497
550498 let fields = args
551499 . iter ( )
552500 . map ( |e| e. to_field ( schema) . map ( |( _, f) | f) )
553501 . collect :: < Result < Vec < _ > > > ( ) ?;
554- // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
555- let new_fields = fields_with_udf ( & fields, func. as_ref ( ) )
556- . map_err ( |err| {
557- let arg_types = fields
558- . iter ( )
559- . map ( |f| f. data_type ( ) )
560- . cloned ( )
561- . collect :: < Vec < _ > > ( ) ;
562- plan_datafusion_err ! (
563- "{} {}" ,
564- match err {
565- DataFusionError :: Plan ( msg) => msg,
566- err => err. to_string( ) ,
567- } ,
568- utils:: generate_signature_error_msg(
569- func. name( ) ,
570- func. signature( ) . clone( ) ,
571- & arg_types,
572- )
573- )
574- } ) ?
575- . into_iter ( )
576- . collect :: < Vec < _ > > ( ) ;
577-
502+ match fun {
503+ WindowFunctionDefinition :: AggregateUDF ( udaf) => {
504+ let new_fields =
505+ verify_function_arguments ( udaf. as_ref ( ) , & fields) ?;
506+ let return_field = udaf. return_field ( & new_fields) ?;
507+ Ok ( return_field)
508+ }
509+ WindowFunctionDefinition :: WindowUDF ( udwf) => {
510+ let new_fields =
511+ verify_function_arguments ( udwf. as_ref ( ) , & fields) ?;
512+ let return_field = udwf
513+ . field ( WindowUDFFieldArgs :: new ( & new_fields, & schema_name) ) ?;
514+ Ok ( return_field)
515+ }
516+ }
517+ }
518+ Expr :: AggregateFunction ( AggregateFunction {
519+ func,
520+ params : AggregateFunctionParams { args, .. } ,
521+ } ) => {
522+ let fields = args
523+ . iter ( )
524+ . map ( |e| e. to_field ( schema) . map ( |( _, f) | f) )
525+ . collect :: < Result < Vec < _ > > > ( ) ?;
526+ let new_fields = verify_function_arguments ( func. as_ref ( ) , & fields) ?;
578527 func. return_field ( & new_fields)
579528 }
580529 Expr :: ScalarFunction ( ScalarFunction { func, args } ) => {
581- let ( arg_types , fields) : ( Vec < DataType > , Vec < Arc < Field > > ) = args
530+ let fields = args
582531 . iter ( )
583532 . map ( |e| e. to_field ( schema) . map ( |( _, f) | f) )
584- . collect :: < Result < Vec < _ > > > ( ) ?
585- . into_iter ( )
586- . map ( |f| ( f. data_type ( ) . clone ( ) , f) )
587- . unzip ( ) ;
588- // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
589- let new_fields =
590- fields_with_udf ( & fields, func. as_ref ( ) ) . map_err ( |err| {
591- plan_datafusion_err ! (
592- "{} {}" ,
593- match err {
594- DataFusionError :: Plan ( msg) => msg,
595- err => err. to_string( ) ,
596- } ,
597- utils:: generate_signature_error_msg(
598- func. name( ) ,
599- func. signature( ) . clone( ) ,
600- & arg_types,
601- )
602- )
603- } ) ?;
533+ . collect :: < Result < Vec < _ > > > ( ) ?;
534+ let new_fields = verify_function_arguments ( func. as_ref ( ) , & fields) ?;
604535
605536 let arguments = args
606537 . iter ( )
@@ -678,6 +609,33 @@ impl ExprSchemable for Expr {
678609 }
679610}
680611
612+ /// Verify that function is invoked with correct number and type of arguments as
613+ /// defined in `TypeSignature`.
614+ fn verify_function_arguments < F : UDFCoercionExt > (
615+ function : & F ,
616+ input_fields : & [ FieldRef ] ,
617+ ) -> Result < Vec < FieldRef > > {
618+ fields_with_udf ( input_fields, function) . map_err ( |err| {
619+ let data_types = input_fields
620+ . iter ( )
621+ . map ( |f| f. data_type ( ) )
622+ . cloned ( )
623+ . collect :: < Vec < _ > > ( ) ;
624+ plan_datafusion_err ! (
625+ "{} {}" ,
626+ match err {
627+ DataFusionError :: Plan ( msg) => msg,
628+ err => err. to_string( ) ,
629+ } ,
630+ utils:: generate_signature_error_message(
631+ function. name( ) ,
632+ function. signature( ) ,
633+ & data_types
634+ )
635+ )
636+ } )
637+ }
638+
681639/// Returns the innermost [Expr] that is provably null if `expr` is null.
682640fn unwrap_certainly_null_expr ( expr : & Expr ) -> & Expr {
683641 match expr {
@@ -688,93 +646,6 @@ fn unwrap_certainly_null_expr(expr: &Expr) -> &Expr {
688646 }
689647}
690648
691- impl Expr {
692- /// Common method for window functions that applies type coercion
693- /// to all arguments of the window function to check if it matches
694- /// its signature.
695- ///
696- /// If successful, this method returns the data type and
697- /// nullability of the window function's result.
698- ///
699- /// Otherwise, returns an error if there's a type mismatch between
700- /// the window function's signature and the provided arguments.
701- fn data_type_and_nullable_with_window_function (
702- & self ,
703- schema : & dyn ExprSchema ,
704- window_function : & WindowFunction ,
705- ) -> Result < ( DataType , bool ) > {
706- let WindowFunction {
707- fun,
708- params : WindowFunctionParams { args, .. } ,
709- ..
710- } = window_function;
711-
712- let fields = args
713- . iter ( )
714- . map ( |e| e. to_field ( schema) . map ( |( _, f) | f) )
715- . collect :: < Result < Vec < _ > > > ( ) ?;
716- match fun {
717- WindowFunctionDefinition :: AggregateUDF ( udaf) => {
718- let data_types = fields
719- . iter ( )
720- . map ( |f| f. data_type ( ) )
721- . cloned ( )
722- . collect :: < Vec < _ > > ( ) ;
723- let new_fields = fields_with_udf ( & fields, udaf. as_ref ( ) )
724- . map_err ( |err| {
725- plan_datafusion_err ! (
726- "{} {}" ,
727- match err {
728- DataFusionError :: Plan ( msg) => msg,
729- err => err. to_string( ) ,
730- } ,
731- utils:: generate_signature_error_msg(
732- fun. name( ) ,
733- fun. signature( ) ,
734- & data_types
735- )
736- )
737- } ) ?
738- . into_iter ( )
739- . collect :: < Vec < _ > > ( ) ;
740-
741- let return_field = udaf. return_field ( & new_fields) ?;
742-
743- Ok ( ( return_field. data_type ( ) . clone ( ) , return_field. is_nullable ( ) ) )
744- }
745- WindowFunctionDefinition :: WindowUDF ( udwf) => {
746- let data_types = fields
747- . iter ( )
748- . map ( |f| f. data_type ( ) )
749- . cloned ( )
750- . collect :: < Vec < _ > > ( ) ;
751- let new_fields = fields_with_udf ( & fields, udwf. as_ref ( ) )
752- . map_err ( |err| {
753- plan_datafusion_err ! (
754- "{} {}" ,
755- match err {
756- DataFusionError :: Plan ( msg) => msg,
757- err => err. to_string( ) ,
758- } ,
759- utils:: generate_signature_error_msg(
760- fun. name( ) ,
761- fun. signature( ) ,
762- & data_types
763- )
764- )
765- } ) ?
766- . into_iter ( )
767- . collect :: < Vec < _ > > ( ) ;
768- let ( _, function_name) = self . qualified_name ( ) ;
769- let field_args = WindowUDFFieldArgs :: new ( & new_fields, & function_name) ;
770-
771- udwf. field ( field_args)
772- . map ( |field| ( field. data_type ( ) . clone ( ) , field. is_nullable ( ) ) )
773- }
774- }
775- }
776- }
777-
778649/// Cast subquery in InSubquery/ScalarSubquery to a given type.
779650///
780651/// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific
0 commit comments