1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15- use std:: borrow:: Cow ;
1615use std:: hash:: Hash ;
1716use std:: ops:: Range ;
1817use std:: sync:: Arc ;
@@ -1160,65 +1159,55 @@ pub fn register(registry: &mut FunctionRegistry) {
11601159 ) ;
11611160}
11621161
1163- struct ArrayAggEvaluator {
1164- func : AggregateFunctionRef ,
1165- state_layout : Arc < StatesLayout > ,
1162+ struct ArrayAggEvaluator < ' a > {
1163+ func : & ' a AggregateFunctionRef ,
1164+ state_layout : & ' a StatesLayout ,
11661165 addr : StateAddr ,
11671166 need_manual_drop_state : bool ,
11681167 _arena : Bump ,
11691168}
11701169
1171- impl ArrayAggEvaluator {
1172- fn new ( func : AggregateFunctionRef , state_layout : Arc < StatesLayout > ) -> Self {
1173- let _arena = Bump :: new ( ) ;
1174- let addr = _arena . alloc_layout ( state_layout. layout ) . into ( ) ;
1175- let mut evaluator = Self {
1176- func ,
1170+ impl < ' a > ArrayAggEvaluator < ' a > {
1171+ fn new ( func : & ' a AggregateFunctionRef , state_layout : & ' a StatesLayout ) -> Self {
1172+ let arena = Bump :: new ( ) ;
1173+ let addr = arena . alloc_layout ( state_layout. layout ) . into ( ) ;
1174+ func . init_state ( AggrState :: new ( addr , & state_layout . states_loc [ 0 ] ) ) ;
1175+ Self {
11771176 state_layout,
11781177 addr,
1179- need_manual_drop_state : false ,
1180- _arena,
1181- } ;
1182- let state = evaluator. state ( ) ;
1183- evaluator. func . init_state ( state) ;
1184- evaluator. need_manual_drop_state = evaluator. func . need_manual_drop_state ( ) ;
1185- evaluator
1178+ need_manual_drop_state : func. need_manual_drop_state ( ) ,
1179+ func,
1180+ _arena : arena,
1181+ }
11861182 }
11871183
11881184 fn state ( & self ) -> AggrState {
11891185 AggrState :: new ( self . addr , & self . state_layout . states_loc [ 0 ] )
11901186 }
11911187
1192- fn reset_state ( & mut self ) {
1188+ fn eval ( & mut self , entry : BlockEntry , builder : & mut ColumnBuilder ) -> Result < ( ) > {
11931189 let state = self . state ( ) ;
11941190 if self . need_manual_drop_state {
11951191 unsafe {
11961192 self . func . drop_state ( state) ;
11971193 }
11981194 }
11991195 self . func . init_state ( state) ;
1200- }
1201-
1202- fn eval_column ( & mut self , column : Column , builder : & mut ColumnBuilder ) -> Result < ( ) > {
1203- self . reset_state ( ) ;
1204- let rows = column. len ( ) ;
1205- let entries = [ BlockEntry :: Column ( column) ] ;
1206- self . func
1207- . accumulate ( self . state ( ) , ( & entries) . into ( ) , None , rows) ?;
1208- self . func . merge_result ( self . state ( ) , false , builder) ?;
1196+ let rows = entry. len ( ) ;
1197+ let entries = & [ entry] ;
1198+ self . func . accumulate ( state, entries. into ( ) , None , rows) ?;
1199+ self . func . merge_result ( state, false , builder) ?;
12091200 Ok ( ( ) )
12101201 }
12111202}
12121203
1213- impl Drop for ArrayAggEvaluator {
1204+ impl Drop for ArrayAggEvaluator < ' _ > {
12141205 fn drop ( & mut self ) {
1215- let need_drop = self . need_manual_drop_state ;
1216- drop_guard ( move || {
1217- if need_drop {
1218- unsafe {
1219- self . func . drop_state ( self . state ( ) ) ;
1220- }
1221- }
1206+ if !self . need_manual_drop_state {
1207+ return ;
1208+ }
1209+ drop_guard ( move || unsafe {
1210+ self . func . drop_state ( self . state ( ) ) ;
12221211 } )
12231212 }
12241213}
@@ -1230,7 +1219,7 @@ struct ArrayAggDesc {
12301219}
12311220
12321221impl ArrayAggDesc {
1233- fn try_create ( name : & str , array_type : & DataType ) -> Result < Self > {
1222+ fn new ( name : & str , array_type : & DataType ) -> Result < Self > {
12341223 let factory = AggregateFunctionFactory :: instance ( ) ;
12351224 let func = factory. get ( name, vec ! [ ] , vec ! [ array_type. clone( ) ] , vec ! [ ] ) ?;
12361225 let return_type = func. return_type ( ) ?;
@@ -1244,97 +1233,133 @@ impl ArrayAggDesc {
12441233 }
12451234
12461235 fn create_evaluator ( & self ) -> ArrayAggEvaluator {
1247- ArrayAggEvaluator :: new ( self . func . clone ( ) , self . state_layout . clone ( ) )
1236+ ArrayAggEvaluator :: new ( & self . func , & self . state_layout )
12481237 }
12491238}
12501239
12511240struct ArrayAggFunctionImpl {
12521241 desc : Option < ArrayAggDesc > ,
1253- is_count : bool ,
1242+ return_type : DataType ,
12541243}
12551244
12561245impl ArrayAggFunctionImpl {
1257- fn try_create ( name : & ' static str , arg_type : & DataType ) -> Option < Self > {
1258- let desc = match arg_type {
1259- DataType :: Nullable ( box DataType :: EmptyArray ) | DataType :: EmptyArray => None ,
1260- DataType :: Nullable ( box DataType :: Array ( inner) ) | DataType :: Array ( inner) => {
1261- ArrayAggDesc :: try_create ( name, inner) . ok ( )
1262- }
1263- DataType :: Nullable ( box DataType :: Variant ) | DataType :: Variant => {
1264- ArrayAggDesc :: try_create ( name, & DataType :: Variant ) . ok ( )
1246+ fn new ( name : & ' static str , arg_type : & DataType ) -> Option < Self > {
1247+ let ( desc, return_type) = match arg_type {
1248+ DataType :: Nullable ( box DataType :: EmptyArray ) | DataType :: EmptyArray => (
1249+ None ,
1250+ if name == "count" {
1251+ UInt64Type :: data_type ( )
1252+ } else {
1253+ DataType :: Null
1254+ } ,
1255+ ) ,
1256+ DataType :: Nullable ( box DataType :: Array ( box array_type) )
1257+ | DataType :: Array ( box array_type)
1258+ | DataType :: Nullable ( box array_type @ DataType :: Variant )
1259+ | array_type @ DataType :: Variant => {
1260+ let desc = ArrayAggDesc :: new ( name, array_type) . ok ( ) ?;
1261+ let return_type = desc. return_type . clone ( ) ;
1262+ ( Some ( desc) , return_type)
12651263 }
12661264 _ => return None ,
12671265 } ;
12681266 Some ( Self {
12691267 desc,
1270- is_count : name == "count" ,
1268+ return_type : if arg_type. is_nullable ( ) {
1269+ return_type. wrap_nullable ( )
1270+ } else {
1271+ return_type
1272+ } ,
12711273 } )
12721274 }
12731275
1274- fn return_type ( & self ) -> Cow < DataType > {
1275- match & self . desc {
1276- Some ( desc) => Cow :: Borrowed ( & desc. return_type ) ,
1277- None => {
1278- let data_type = if self . is_count {
1279- DataType :: Number ( NumberDataType :: UInt64 )
1280- } else {
1281- DataType :: Null
1282- } ;
1283- Cow :: Owned ( data_type)
1284- }
1285- }
1286- }
1287-
12881276 fn eval ( & self , args : & [ Value < AnyType > ] , ctx : & mut EvalContext ) -> Value < AnyType > {
1289- match & self . desc {
1290- None => match args {
1291- [ _] => Value :: Scalar ( Scalar :: default_value ( & self . return_type ( ) ) ) ,
1277+ let Some ( desc ) = & self . desc else {
1278+ return match args {
1279+ [ _] => Value :: Scalar ( Scalar :: default_value ( & self . return_type ) ) ,
12921280 _ => unreachable ! ( ) ,
1293- } ,
1294- Some ( desc) => match args {
1295- [ Value :: Scalar ( scalar) ] => match scalar {
1296- Scalar :: EmptyArray | Scalar :: Null => {
1297- Value :: Scalar ( Scalar :: default_value ( & self . return_type ( ) ) )
1281+ } ;
1282+ } ;
1283+
1284+ match args {
1285+ [ Value :: Scalar ( Scalar :: Null | Scalar :: EmptyArray ) ] => {
1286+ Value :: Scalar ( Scalar :: default_value ( & self . return_type ) )
1287+ }
1288+ [ Value :: Scalar ( scalar @ Scalar :: Array ( _) | scalar @ Scalar :: Variant ( _) ) ] => {
1289+ scalar_to_array_column ( scalar. as_ref ( ) )
1290+ . and_then ( |col| {
1291+ let mut evaluator = desc. create_evaluator ( ) ;
1292+ let mut builder = ColumnBuilder :: with_capacity ( & desc. return_type , 1 ) ;
1293+ evaluator. eval ( col. into ( ) , & mut builder) ?;
1294+ Ok ( Value :: Scalar ( builder. build_scalar ( ) ) )
1295+ } )
1296+ . unwrap_or_else ( |err| {
1297+ ctx. set_error ( 0 , err. to_string ( ) ) ;
1298+ Value :: Scalar ( Scalar :: default_value ( & self . return_type ) )
1299+ } )
1300+ }
1301+ [ Value :: Scalar ( _) ] => unreachable ! ( ) ,
1302+ [ Value :: Column ( Column :: Nullable ( box column) ) ]
1303+ if desc. return_type != self . return_type =>
1304+ {
1305+ let mut builder = ColumnBuilder :: with_capacity ( & self . return_type , column. len ( ) ) ;
1306+ let mut evaluator = desc. create_evaluator ( ) ;
1307+ let ColumnBuilder :: Nullable ( box nullable) = & mut builder else {
1308+ unreachable ! ( )
1309+ } ;
1310+ for ( row_index, scalar) in column. iter ( ) . enumerate ( ) {
1311+ let Some ( scalar) = scalar else {
1312+ nullable. push_null ( ) ;
1313+ continue ;
1314+ } ;
1315+
1316+ let col = match scalar_to_array_column ( scalar) {
1317+ Ok ( col) => col,
1318+ Err ( err) => {
1319+ ctx. set_error ( row_index, err. to_string ( ) ) ;
1320+ nullable. push_null ( ) ;
1321+ continue ;
1322+ }
1323+ } ;
1324+
1325+ nullable. validity . push ( true ) ;
1326+ if let Err ( err) = evaluator. eval ( col. into ( ) , & mut nullable. builder ) {
1327+ ctx. set_error ( row_index, err. to_string ( ) ) ;
1328+ if nullable. builder . len ( ) == row_index {
1329+ nullable. builder . push_default ( ) ;
1330+ }
12981331 }
1299- Scalar :: Array ( _) | Scalar :: Variant ( _) => {
1300- scalar_to_array_column ( scalar. as_ref ( ) )
1301- . and_then ( |col| {
1302- let mut evaluator = desc. create_evaluator ( ) ;
1303- let mut result_builder =
1304- ColumnBuilder :: with_capacity ( & desc. return_type , 1 ) ;
1305-
1306- evaluator. eval_column ( col, & mut result_builder) ?;
1307- Ok ( Value :: Scalar ( result_builder. build_scalar ( ) ) )
1308- } )
1309- . unwrap_or_else ( |err| {
1310- ctx. set_error ( 0 , err. to_string ( ) ) ;
1311- Value :: Scalar ( Scalar :: default_value ( & desc. return_type ) )
1312- } )
1332+ }
1333+ Value :: Column ( builder. build ( ) )
1334+ }
1335+ [ Value :: Column ( column) ] => {
1336+ let mut builder = ColumnBuilder :: with_capacity ( & self . return_type , column. len ( ) ) ;
1337+ let mut evaluator = desc. create_evaluator ( ) ;
1338+ for ( row_index, scalar) in column. iter ( ) . enumerate ( ) {
1339+ if scalar == ScalarRef :: Null {
1340+ builder. push_default ( ) ;
1341+ continue ;
13131342 }
1314- _ => unreachable ! ( ) ,
1315- } ,
1316- [ Value :: Column ( column) ] => {
1317- let mut builder =
1318- ColumnBuilder :: with_capacity ( & self . return_type ( ) , column. len ( ) ) ;
1319- let mut evaluator = desc. create_evaluator ( ) ;
1320- for ( row_index, scalar) in column. iter ( ) . enumerate ( ) {
1321- if scalar == ScalarRef :: Null {
1343+
1344+ let col = match scalar_to_array_column ( scalar) {
1345+ Ok ( col) => col,
1346+ Err ( err) => {
1347+ ctx. set_error ( row_index, err. to_string ( ) ) ;
13221348 builder. push_default ( ) ;
13231349 continue ;
13241350 }
1325- if let Err ( err) = scalar_to_array_column ( scalar)
1326- . and_then ( |col| evaluator. eval_column ( col, & mut builder) )
1327- {
1328- ctx. set_error ( row_index, err. to_string ( ) ) ;
1329- if builder. len ( ) == row_index {
1330- builder. push_default ( ) ;
1331- }
1351+ } ;
1352+
1353+ if let Err ( err) = evaluator. eval ( col. into ( ) , & mut builder) {
1354+ ctx. set_error ( row_index, err. to_string ( ) ) ;
1355+ if builder. len ( ) == row_index {
1356+ builder. push_default ( ) ;
13321357 }
13331358 }
1334- Value :: Column ( builder. build ( ) )
13351359 }
1336- _ => unreachable ! ( ) ,
1337- } ,
1360+ Value :: Column ( builder. build ( ) )
1361+ }
1362+ _ => unreachable ! ( ) ,
13381363 }
13391364 }
13401365}
@@ -1363,8 +1388,8 @@ fn register_array_aggr(registry: &mut FunctionRegistry) {
13631388 let [ arg] = args_type else {
13641389 return None ;
13651390 } ;
1366- let impl_info = ArrayAggFunctionImpl :: try_create ( name, arg) ?;
1367- let return_type = impl_info. return_type ( ) . into_owned ( ) ;
1391+ let impl_info = ArrayAggFunctionImpl :: new ( name, arg) ?;
1392+ let return_type = impl_info. return_type . clone ( ) ;
13681393 Some ( Arc :: new ( Function {
13691394 signature : FunctionSignature {
13701395 name : fn_name. to_string ( ) ,
0 commit comments