Skip to content

Commit 8e985e7

Browse files
committed
fix
1 parent 9d84833 commit 8e985e7

File tree

1 file changed

+105
-99
lines changed

1 file changed

+105
-99
lines changed

src/query/functions/src/scalars/array.rs

Lines changed: 105 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
use std::borrow::Cow;
1615
use std::hash::Hash;
1716
use std::ops::Range;
1817
use std::sync::Arc;
@@ -1160,65 +1159,54 @@ pub fn register(registry: &mut FunctionRegistry) {
11601159
);
11611160
}
11621161

1163-
struct ArrayAggEvaluator {
1164-
func: AggregateFunctionRef,
1165-
state_layout: Arc<StatesLayout>,
1162+
struct ArrayAggEvaluator<'a> {
1163+
func: &'a AggregateFunctionRef,
1164+
state_layout: &'a StatesLayout,
11661165
addr: StateAddr,
11671166
need_manual_drop_state: bool,
11681167
_arena: Bump,
11691168
}
11701169

1171-
impl ArrayAggEvaluator {
1172-
fn new(func: AggregateFunctionRef, state_layout: Arc<StatesLayout>) -> Self {
1173-
let _arena = Bump::new();
1174-
let addr = _arena.alloc_layout(state_layout.layout).into();
1175-
let mut evaluator = Self {
1176-
func,
1170+
impl<'a> ArrayAggEvaluator<'a> {
1171+
fn new(func: &'a AggregateFunctionRef, state_layout: &'a StatesLayout) -> Self {
1172+
let arena = Bump::new();
1173+
let addr = arena.alloc_layout(state_layout.layout).into();
1174+
Self {
11771175
state_layout,
11781176
addr,
1179-
need_manual_drop_state: false,
1180-
_arena,
1181-
};
1182-
let state = evaluator.state();
1183-
evaluator.func.init_state(state);
1184-
evaluator.need_manual_drop_state = evaluator.func.need_manual_drop_state();
1185-
evaluator
1177+
need_manual_drop_state: func.need_manual_drop_state(),
1178+
func,
1179+
_arena: arena,
1180+
}
11861181
}
11871182

11881183
fn state(&self) -> AggrState {
11891184
AggrState::new(self.addr, &self.state_layout.states_loc[0])
11901185
}
11911186

1192-
fn reset_state(&mut self) {
1187+
fn eval_column(&mut self, column: Column, builder: &mut ColumnBuilder) -> Result<()> {
11931188
let state = self.state();
11941189
if self.need_manual_drop_state {
11951190
unsafe {
11961191
self.func.drop_state(state);
11971192
}
11981193
}
11991194
self.func.init_state(state);
1200-
}
1201-
1202-
fn eval_column(&mut self, column: Column, builder: &mut ColumnBuilder) -> Result<()> {
1203-
self.reset_state();
12041195
let rows = column.len();
1205-
let entries = [BlockEntry::Column(column)];
1206-
self.func
1207-
.accumulate(self.state(), (&entries).into(), None, rows)?;
1208-
self.func.merge_result(self.state(), false, builder)?;
1196+
let entries = &[BlockEntry::Column(column)];
1197+
self.func.accumulate(state, entries.into(), None, rows)?;
1198+
self.func.merge_result(state, false, builder)?;
12091199
Ok(())
12101200
}
12111201
}
12121202

1213-
impl Drop for ArrayAggEvaluator {
1203+
impl Drop for ArrayAggEvaluator<'_> {
12141204
fn drop(&mut self) {
1215-
let need_drop = self.need_manual_drop_state;
1216-
drop_guard(move || {
1217-
if need_drop {
1218-
unsafe {
1219-
self.func.drop_state(self.state());
1220-
}
1221-
}
1205+
if !self.need_manual_drop_state {
1206+
return;
1207+
}
1208+
drop_guard(move || unsafe {
1209+
self.func.drop_state(self.state());
12221210
})
12231211
}
12241212
}
@@ -1230,7 +1218,7 @@ struct ArrayAggDesc {
12301218
}
12311219

12321220
impl ArrayAggDesc {
1233-
fn try_create(name: &str, array_type: &DataType) -> Result<Self> {
1221+
fn new(name: &str, array_type: &DataType) -> Result<Self> {
12341222
let factory = AggregateFunctionFactory::instance();
12351223
let func = factory.get(name, vec![], vec![array_type.clone()], vec![])?;
12361224
let return_type = func.return_type()?;
@@ -1244,97 +1232,115 @@ impl ArrayAggDesc {
12441232
}
12451233

12461234
fn create_evaluator(&self) -> ArrayAggEvaluator {
1247-
ArrayAggEvaluator::new(self.func.clone(), self.state_layout.clone())
1235+
ArrayAggEvaluator::new(&self.func, &self.state_layout)
12481236
}
12491237
}
12501238

12511239
struct ArrayAggFunctionImpl {
12521240
desc: Option<ArrayAggDesc>,
1253-
is_count: bool,
1241+
return_type: DataType,
12541242
}
12551243

12561244
impl ArrayAggFunctionImpl {
1257-
fn try_create(name: &'static str, arg_type: &DataType) -> Option<Self> {
1258-
let desc = match arg_type {
1259-
DataType::Nullable(box DataType::EmptyArray) | DataType::EmptyArray => None,
1260-
DataType::Nullable(box DataType::Array(inner)) | DataType::Array(inner) => {
1261-
ArrayAggDesc::try_create(name, inner).ok()
1262-
}
1263-
DataType::Nullable(box DataType::Variant) | DataType::Variant => {
1264-
ArrayAggDesc::try_create(name, &DataType::Variant).ok()
1245+
fn new(name: &'static str, arg_type: &DataType) -> Option<Self> {
1246+
let (desc, return_type) = match arg_type {
1247+
DataType::Nullable(box DataType::EmptyArray) | DataType::EmptyArray => (
1248+
None,
1249+
if name == "count" {
1250+
UInt64Type::data_type()
1251+
} else {
1252+
DataType::Null
1253+
},
1254+
),
1255+
DataType::Nullable(box DataType::Array(box array_type))
1256+
| DataType::Array(box array_type)
1257+
| DataType::Nullable(box array_type @ DataType::Variant)
1258+
| array_type @ DataType::Variant => {
1259+
let desc = ArrayAggDesc::new(name, array_type).ok()?;
1260+
let return_type = desc.return_type.clone();
1261+
(Some(desc), return_type)
12651262
}
12661263
_ => return None,
12671264
};
12681265
Some(Self {
12691266
desc,
1270-
is_count: name == "count",
1267+
return_type: if arg_type.is_nullable() {
1268+
return_type.wrap_nullable()
1269+
} else {
1270+
return_type
1271+
},
12711272
})
12721273
}
12731274

1274-
fn return_type(&self) -> Cow<DataType> {
1275-
match &self.desc {
1276-
Some(desc) => Cow::Borrowed(&desc.return_type),
1277-
None => {
1278-
let data_type = if self.is_count {
1279-
DataType::Number(NumberDataType::UInt64)
1280-
} else {
1281-
DataType::Null
1282-
};
1283-
Cow::Owned(data_type)
1284-
}
1285-
}
1286-
}
1287-
12881275
fn eval(&self, args: &[Value<AnyType>], ctx: &mut EvalContext) -> Value<AnyType> {
1289-
match &self.desc {
1290-
None => match args {
1291-
[_] => Value::Scalar(Scalar::default_value(&self.return_type())),
1276+
let Some(desc) = &self.desc else {
1277+
return match args {
1278+
[_] => Value::Scalar(Scalar::default_value(&self.return_type)),
12921279
_ => unreachable!(),
1293-
},
1294-
Some(desc) => match args {
1295-
[Value::Scalar(scalar)] => match scalar {
1296-
Scalar::EmptyArray | Scalar::Null => {
1297-
Value::Scalar(Scalar::default_value(&self.return_type()))
1298-
}
1299-
Scalar::Array(_) | Scalar::Variant(_) => {
1300-
scalar_to_array_column(scalar.as_ref())
1301-
.and_then(|col| {
1302-
let mut evaluator = desc.create_evaluator();
1303-
let mut result_builder =
1304-
ColumnBuilder::with_capacity(&desc.return_type, 1);
1305-
1306-
evaluator.eval_column(col, &mut result_builder)?;
1307-
Ok(Value::Scalar(result_builder.build_scalar()))
1308-
})
1309-
.unwrap_or_else(|err| {
1310-
ctx.set_error(0, err.to_string());
1311-
Value::Scalar(Scalar::default_value(&desc.return_type))
1312-
})
1280+
};
1281+
};
1282+
1283+
match args {
1284+
[Value::Scalar(Scalar::Null | Scalar::EmptyArray)] => {
1285+
Value::Scalar(Scalar::default_value(&self.return_type))
1286+
}
1287+
[Value::Scalar(scalar @ Scalar::Array(_) | scalar @ Scalar::Variant(_))] => {
1288+
scalar_to_array_column(scalar.as_ref())
1289+
.and_then(|col| {
1290+
let mut evaluator = desc.create_evaluator();
1291+
let mut result_builder = ColumnBuilder::with_capacity(&desc.return_type, 1);
1292+
1293+
evaluator.eval_column(col, &mut result_builder)?;
1294+
Ok(Value::Scalar(result_builder.build_scalar()))
1295+
})
1296+
.unwrap_or_else(|err| {
1297+
ctx.set_error(0, err.to_string());
1298+
Value::Scalar(Scalar::default_value(&desc.return_type))
1299+
})
1300+
}
1301+
[Value::Scalar(_)] => unreachable!(),
1302+
[Value::Column(column)] => {
1303+
let mut builder = ColumnBuilder::with_capacity(&self.return_type, column.len());
1304+
let mut evaluator = desc.create_evaluator();
1305+
let is_wrap_nullable = desc.return_type != self.return_type;
1306+
for (row_index, scalar) in column.iter().enumerate() {
1307+
if scalar == ScalarRef::Null {
1308+
builder.push_default();
1309+
continue;
13131310
}
1314-
_ => unreachable!(),
1315-
},
1316-
[Value::Column(column)] => {
1317-
let mut builder =
1318-
ColumnBuilder::with_capacity(&self.return_type(), column.len());
1319-
let mut evaluator = desc.create_evaluator();
1320-
for (row_index, scalar) in column.iter().enumerate() {
1321-
if scalar == ScalarRef::Null {
1311+
1312+
let col = match scalar_to_array_column(scalar) {
1313+
Ok(col) => col,
1314+
Err(err) => {
1315+
ctx.set_error(row_index, err.to_string());
13221316
builder.push_default();
13231317
continue;
13241318
}
1325-
if let Err(err) = scalar_to_array_column(scalar)
1326-
.and_then(|col| evaluator.eval_column(col, &mut builder))
1327-
{
1319+
};
1320+
1321+
if !is_wrap_nullable {
1322+
if let Err(err) = evaluator.eval_column(col, &mut builder) {
13281323
ctx.set_error(row_index, err.to_string());
13291324
if builder.len() == row_index {
13301325
builder.push_default();
13311326
}
13321327
}
1328+
} else {
1329+
let ColumnBuilder::Nullable(box nullable) = &mut builder else {
1330+
unreachable!()
1331+
};
1332+
nullable.validity.push(true);
1333+
if let Err(err) = evaluator.eval_column(col, &mut nullable.builder) {
1334+
ctx.set_error(row_index, err.to_string());
1335+
if nullable.builder.len() == row_index {
1336+
nullable.builder.push_default();
1337+
}
1338+
}
13331339
}
1334-
Value::Column(builder.build())
13351340
}
1336-
_ => unreachable!(),
1337-
},
1341+
Value::Column(builder.build())
1342+
}
1343+
_ => unreachable!(),
13381344
}
13391345
}
13401346
}
@@ -1363,8 +1369,8 @@ fn register_array_aggr(registry: &mut FunctionRegistry) {
13631369
let [arg] = args_type else {
13641370
return None;
13651371
};
1366-
let impl_info = ArrayAggFunctionImpl::try_create(name, arg)?;
1367-
let return_type = impl_info.return_type().into_owned();
1372+
let impl_info = ArrayAggFunctionImpl::new(name, arg)?;
1373+
let return_type = impl_info.return_type.clone();
13681374
Some(Arc::new(Function {
13691375
signature: FunctionSignature {
13701376
name: fn_name.to_string(),

0 commit comments

Comments
 (0)