1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15- use std:: borrow:: Cow ;
1615use std:: hash:: Hash ;
1716use std:: ops:: Range ;
1817use std:: sync:: Arc ;
@@ -1160,65 +1159,54 @@ pub fn register(registry: &mut FunctionRegistry) {
11601159 ) ;
11611160}
11621161
1163- struct ArrayAggEvaluator {
1164- func : AggregateFunctionRef ,
1165- state_layout : Arc < StatesLayout > ,
1162+ struct ArrayAggEvaluator < ' a > {
1163+ func : & ' a AggregateFunctionRef ,
1164+ state_layout : & ' a StatesLayout ,
11661165 addr : StateAddr ,
11671166 need_manual_drop_state : bool ,
11681167 _arena : Bump ,
11691168}
11701169
1171- impl ArrayAggEvaluator {
1172- fn new ( func : AggregateFunctionRef , state_layout : Arc < StatesLayout > ) -> Self {
1173- let _arena = Bump :: new ( ) ;
1174- let addr = _arena. alloc_layout ( state_layout. layout ) . into ( ) ;
1175- let mut evaluator = Self {
1176- func,
1170+ impl < ' a > ArrayAggEvaluator < ' a > {
1171+ fn new ( func : & ' a AggregateFunctionRef , state_layout : & ' a StatesLayout ) -> Self {
1172+ let arena = Bump :: new ( ) ;
1173+ let addr = arena. alloc_layout ( state_layout. layout ) . into ( ) ;
1174+ Self {
11771175 state_layout,
11781176 addr,
1179- need_manual_drop_state : false ,
1180- _arena,
1181- } ;
1182- let state = evaluator. state ( ) ;
1183- evaluator. func . init_state ( state) ;
1184- evaluator. need_manual_drop_state = evaluator. func . need_manual_drop_state ( ) ;
1185- evaluator
1177+ need_manual_drop_state : func. need_manual_drop_state ( ) ,
1178+ func,
1179+ _arena : arena,
1180+ }
11861181 }
11871182
11881183 fn state ( & self ) -> AggrState {
11891184 AggrState :: new ( self . addr , & self . state_layout . states_loc [ 0 ] )
11901185 }
11911186
1192- fn reset_state ( & mut self ) {
1187+ fn eval_column ( & mut self , column : Column , builder : & mut ColumnBuilder ) -> Result < ( ) > {
11931188 let state = self . state ( ) ;
11941189 if self . need_manual_drop_state {
11951190 unsafe {
11961191 self . func . drop_state ( state) ;
11971192 }
11981193 }
11991194 self . func . init_state ( state) ;
1200- }
1201-
1202- fn eval_column ( & mut self , column : Column , builder : & mut ColumnBuilder ) -> Result < ( ) > {
1203- self . reset_state ( ) ;
12041195 let rows = column. len ( ) ;
1205- let entries = [ BlockEntry :: Column ( column) ] ;
1206- self . func
1207- . accumulate ( self . state ( ) , ( & entries) . into ( ) , None , rows) ?;
1208- self . func . merge_result ( self . state ( ) , false , builder) ?;
1196+ let entries = & [ BlockEntry :: Column ( column) ] ;
1197+ self . func . accumulate ( state, entries. into ( ) , None , rows) ?;
1198+ self . func . merge_result ( state, false , builder) ?;
12091199 Ok ( ( ) )
12101200 }
12111201}
12121202
1213- impl Drop for ArrayAggEvaluator {
1203+ impl Drop for ArrayAggEvaluator < ' _ > {
12141204 fn drop ( & mut self ) {
1215- let need_drop = self . need_manual_drop_state ;
1216- drop_guard ( move || {
1217- if need_drop {
1218- unsafe {
1219- self . func . drop_state ( self . state ( ) ) ;
1220- }
1221- }
1205+ if !self . need_manual_drop_state {
1206+ return ;
1207+ }
1208+ drop_guard ( move || unsafe {
1209+ self . func . drop_state ( self . state ( ) ) ;
12221210 } )
12231211 }
12241212}
@@ -1230,7 +1218,7 @@ struct ArrayAggDesc {
12301218}
12311219
12321220impl ArrayAggDesc {
1233- fn try_create ( name : & str , array_type : & DataType ) -> Result < Self > {
1221+ fn new ( name : & str , array_type : & DataType ) -> Result < Self > {
12341222 let factory = AggregateFunctionFactory :: instance ( ) ;
12351223 let func = factory. get ( name, vec ! [ ] , vec ! [ array_type. clone( ) ] , vec ! [ ] ) ?;
12361224 let return_type = func. return_type ( ) ?;
@@ -1244,97 +1232,115 @@ impl ArrayAggDesc {
12441232 }
12451233
12461234 fn create_evaluator ( & self ) -> ArrayAggEvaluator {
1247- ArrayAggEvaluator :: new ( self . func . clone ( ) , self . state_layout . clone ( ) )
1235+ ArrayAggEvaluator :: new ( & self . func , & self . state_layout )
12481236 }
12491237}
12501238
12511239struct ArrayAggFunctionImpl {
12521240 desc : Option < ArrayAggDesc > ,
1253- is_count : bool ,
1241+ return_type : DataType ,
12541242}
12551243
12561244impl ArrayAggFunctionImpl {
1257- fn try_create ( name : & ' static str , arg_type : & DataType ) -> Option < Self > {
1258- let desc = match arg_type {
1259- DataType :: Nullable ( box DataType :: EmptyArray ) | DataType :: EmptyArray => None ,
1260- DataType :: Nullable ( box DataType :: Array ( inner) ) | DataType :: Array ( inner) => {
1261- ArrayAggDesc :: try_create ( name, inner) . ok ( )
1262- }
1263- DataType :: Nullable ( box DataType :: Variant ) | DataType :: Variant => {
1264- ArrayAggDesc :: try_create ( name, & DataType :: Variant ) . ok ( )
1245+ fn new ( name : & ' static str , arg_type : & DataType ) -> Option < Self > {
1246+ let ( desc, return_type) = match arg_type {
1247+ DataType :: Nullable ( box DataType :: EmptyArray ) | DataType :: EmptyArray => (
1248+ None ,
1249+ if name == "count" {
1250+ UInt64Type :: data_type ( )
1251+ } else {
1252+ DataType :: Null
1253+ } ,
1254+ ) ,
1255+ DataType :: Nullable ( box DataType :: Array ( box array_type) )
1256+ | DataType :: Array ( box array_type)
1257+ | DataType :: Nullable ( box array_type @ DataType :: Variant )
1258+ | array_type @ DataType :: Variant => {
1259+ let desc = ArrayAggDesc :: new ( name, array_type) . ok ( ) ?;
1260+ let return_type = desc. return_type . clone ( ) ;
1261+ ( Some ( desc) , return_type)
12651262 }
12661263 _ => return None ,
12671264 } ;
12681265 Some ( Self {
12691266 desc,
1270- is_count : name == "count" ,
1267+ return_type : if arg_type. is_nullable ( ) {
1268+ return_type. wrap_nullable ( )
1269+ } else {
1270+ return_type
1271+ } ,
12711272 } )
12721273 }
12731274
1274- fn return_type ( & self ) -> Cow < DataType > {
1275- match & self . desc {
1276- Some ( desc) => Cow :: Borrowed ( & desc. return_type ) ,
1277- None => {
1278- let data_type = if self . is_count {
1279- DataType :: Number ( NumberDataType :: UInt64 )
1280- } else {
1281- DataType :: Null
1282- } ;
1283- Cow :: Owned ( data_type)
1284- }
1285- }
1286- }
1287-
12881275 fn eval ( & self , args : & [ Value < AnyType > ] , ctx : & mut EvalContext ) -> Value < AnyType > {
1289- match & self . desc {
1290- None => match args {
1291- [ _] => Value :: Scalar ( Scalar :: default_value ( & self . return_type ( ) ) ) ,
1276+ let Some ( desc ) = & self . desc else {
1277+ return match args {
1278+ [ _] => Value :: Scalar ( Scalar :: default_value ( & self . return_type ) ) ,
12921279 _ => unreachable ! ( ) ,
1293- } ,
1294- Some ( desc) => match args {
1295- [ Value :: Scalar ( scalar) ] => match scalar {
1296- Scalar :: EmptyArray | Scalar :: Null => {
1297- Value :: Scalar ( Scalar :: default_value ( & self . return_type ( ) ) )
1298- }
1299- Scalar :: Array ( _) | Scalar :: Variant ( _) => {
1300- scalar_to_array_column ( scalar. as_ref ( ) )
1301- . and_then ( |col| {
1302- let mut evaluator = desc. create_evaluator ( ) ;
1303- let mut result_builder =
1304- ColumnBuilder :: with_capacity ( & desc. return_type , 1 ) ;
1305-
1306- evaluator. eval_column ( col, & mut result_builder) ?;
1307- Ok ( Value :: Scalar ( result_builder. build_scalar ( ) ) )
1308- } )
1309- . unwrap_or_else ( |err| {
1310- ctx. set_error ( 0 , err. to_string ( ) ) ;
1311- Value :: Scalar ( Scalar :: default_value ( & desc. return_type ) )
1312- } )
1280+ } ;
1281+ } ;
1282+
1283+ match args {
1284+ [ Value :: Scalar ( Scalar :: Null | Scalar :: EmptyArray ) ] => {
1285+ Value :: Scalar ( Scalar :: default_value ( & self . return_type ) )
1286+ }
1287+ [ Value :: Scalar ( scalar @ Scalar :: Array ( _) | scalar @ Scalar :: Variant ( _) ) ] => {
1288+ scalar_to_array_column ( scalar. as_ref ( ) )
1289+ . and_then ( |col| {
1290+ let mut evaluator = desc. create_evaluator ( ) ;
1291+ let mut result_builder = ColumnBuilder :: with_capacity ( & desc. return_type , 1 ) ;
1292+
1293+ evaluator. eval_column ( col, & mut result_builder) ?;
1294+ Ok ( Value :: Scalar ( result_builder. build_scalar ( ) ) )
1295+ } )
1296+ . unwrap_or_else ( |err| {
1297+ ctx. set_error ( 0 , err. to_string ( ) ) ;
1298+ Value :: Scalar ( Scalar :: default_value ( & desc. return_type ) )
1299+ } )
1300+ }
1301+ [ Value :: Scalar ( _) ] => unreachable ! ( ) ,
1302+ [ Value :: Column ( column) ] => {
1303+ let mut builder = ColumnBuilder :: with_capacity ( & self . return_type , column. len ( ) ) ;
1304+ let mut evaluator = desc. create_evaluator ( ) ;
1305+ let is_wrap_nullable = desc. return_type != self . return_type ;
1306+ for ( row_index, scalar) in column. iter ( ) . enumerate ( ) {
1307+ if scalar == ScalarRef :: Null {
1308+ builder. push_default ( ) ;
1309+ continue ;
13131310 }
1314- _ => unreachable ! ( ) ,
1315- } ,
1316- [ Value :: Column ( column) ] => {
1317- let mut builder =
1318- ColumnBuilder :: with_capacity ( & self . return_type ( ) , column. len ( ) ) ;
1319- let mut evaluator = desc. create_evaluator ( ) ;
1320- for ( row_index, scalar) in column. iter ( ) . enumerate ( ) {
1321- if scalar == ScalarRef :: Null {
1311+
1312+ let col = match scalar_to_array_column ( scalar) {
1313+ Ok ( col) => col,
1314+ Err ( err) => {
1315+ ctx. set_error ( row_index, err. to_string ( ) ) ;
13221316 builder. push_default ( ) ;
13231317 continue ;
13241318 }
1325- if let Err ( err) = scalar_to_array_column ( scalar)
1326- . and_then ( |col| evaluator. eval_column ( col, & mut builder) )
1327- {
1319+ } ;
1320+
1321+ if !is_wrap_nullable {
1322+ if let Err ( err) = evaluator. eval_column ( col, & mut builder) {
13281323 ctx. set_error ( row_index, err. to_string ( ) ) ;
13291324 if builder. len ( ) == row_index {
13301325 builder. push_default ( ) ;
13311326 }
13321327 }
1328+ } else {
1329+ let ColumnBuilder :: Nullable ( box nullable) = & mut builder else {
1330+ unreachable ! ( )
1331+ } ;
1332+ nullable. validity . push ( true ) ;
1333+ if let Err ( err) = evaluator. eval_column ( col, & mut nullable. builder ) {
1334+ ctx. set_error ( row_index, err. to_string ( ) ) ;
1335+ if nullable. builder . len ( ) == row_index {
1336+ nullable. builder . push_default ( ) ;
1337+ }
1338+ }
13331339 }
1334- Value :: Column ( builder. build ( ) )
13351340 }
1336- _ => unreachable ! ( ) ,
1337- } ,
1341+ Value :: Column ( builder. build ( ) )
1342+ }
1343+ _ => unreachable ! ( ) ,
13381344 }
13391345 }
13401346}
@@ -1363,8 +1369,8 @@ fn register_array_aggr(registry: &mut FunctionRegistry) {
13631369 let [ arg] = args_type else {
13641370 return None ;
13651371 } ;
1366- let impl_info = ArrayAggFunctionImpl :: try_create ( name, arg) ?;
1367- let return_type = impl_info. return_type ( ) . into_owned ( ) ;
1372+ let impl_info = ArrayAggFunctionImpl :: new ( name, arg) ?;
1373+ let return_type = impl_info. return_type . clone ( ) ;
13681374 Some ( Arc :: new ( Function {
13691375 signature : FunctionSignature {
13701376 name : fn_name. to_string ( ) ,
0 commit comments