Skip to content

Commit 7b2f9c4

Browse files
committed
fix
1 parent 9d84833 commit 7b2f9c4

File tree

3 files changed

+204
-103
lines changed

3 files changed

+204
-103
lines changed

src/query/functions/src/scalars/array.rs

Lines changed: 128 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
use std::borrow::Cow;
1615
use std::hash::Hash;
1716
use std::ops::Range;
1817
use std::sync::Arc;
@@ -1160,65 +1159,55 @@ pub fn register(registry: &mut FunctionRegistry) {
11601159
);
11611160
}
11621161

1163-
struct ArrayAggEvaluator {
1164-
func: AggregateFunctionRef,
1165-
state_layout: Arc<StatesLayout>,
1162+
struct ArrayAggEvaluator<'a> {
1163+
func: &'a AggregateFunctionRef,
1164+
state_layout: &'a StatesLayout,
11661165
addr: StateAddr,
11671166
need_manual_drop_state: bool,
11681167
_arena: Bump,
11691168
}
11701169

1171-
impl ArrayAggEvaluator {
1172-
fn new(func: AggregateFunctionRef, state_layout: Arc<StatesLayout>) -> Self {
1173-
let _arena = Bump::new();
1174-
let addr = _arena.alloc_layout(state_layout.layout).into();
1175-
let mut evaluator = Self {
1176-
func,
1170+
impl<'a> ArrayAggEvaluator<'a> {
1171+
fn new(func: &'a AggregateFunctionRef, state_layout: &'a StatesLayout) -> Self {
1172+
let arena = Bump::new();
1173+
let addr = arena.alloc_layout(state_layout.layout).into();
1174+
func.init_state(AggrState::new(addr, &state_layout.states_loc[0]));
1175+
Self {
11771176
state_layout,
11781177
addr,
1179-
need_manual_drop_state: false,
1180-
_arena,
1181-
};
1182-
let state = evaluator.state();
1183-
evaluator.func.init_state(state);
1184-
evaluator.need_manual_drop_state = evaluator.func.need_manual_drop_state();
1185-
evaluator
1178+
need_manual_drop_state: func.need_manual_drop_state(),
1179+
func,
1180+
_arena: arena,
1181+
}
11861182
}
11871183

11881184
fn state(&self) -> AggrState {
11891185
AggrState::new(self.addr, &self.state_layout.states_loc[0])
11901186
}
11911187

1192-
fn reset_state(&mut self) {
1188+
fn eval(&mut self, entry: BlockEntry, builder: &mut ColumnBuilder) -> Result<()> {
11931189
let state = self.state();
11941190
if self.need_manual_drop_state {
11951191
unsafe {
11961192
self.func.drop_state(state);
11971193
}
11981194
}
11991195
self.func.init_state(state);
1200-
}
1201-
1202-
fn eval_column(&mut self, column: Column, builder: &mut ColumnBuilder) -> Result<()> {
1203-
self.reset_state();
1204-
let rows = column.len();
1205-
let entries = [BlockEntry::Column(column)];
1206-
self.func
1207-
.accumulate(self.state(), (&entries).into(), None, rows)?;
1208-
self.func.merge_result(self.state(), false, builder)?;
1196+
let rows = entry.len();
1197+
let entries = &[entry];
1198+
self.func.accumulate(state, entries.into(), None, rows)?;
1199+
self.func.merge_result(state, false, builder)?;
12091200
Ok(())
12101201
}
12111202
}
12121203

1213-
impl Drop for ArrayAggEvaluator {
1204+
impl Drop for ArrayAggEvaluator<'_> {
12141205
fn drop(&mut self) {
1215-
let need_drop = self.need_manual_drop_state;
1216-
drop_guard(move || {
1217-
if need_drop {
1218-
unsafe {
1219-
self.func.drop_state(self.state());
1220-
}
1221-
}
1206+
if !self.need_manual_drop_state {
1207+
return;
1208+
}
1209+
drop_guard(move || unsafe {
1210+
self.func.drop_state(self.state());
12221211
})
12231212
}
12241213
}
@@ -1230,7 +1219,7 @@ struct ArrayAggDesc {
12301219
}
12311220

12321221
impl ArrayAggDesc {
1233-
fn try_create(name: &str, array_type: &DataType) -> Result<Self> {
1222+
fn new(name: &str, array_type: &DataType) -> Result<Self> {
12341223
let factory = AggregateFunctionFactory::instance();
12351224
let func = factory.get(name, vec![], vec![array_type.clone()], vec![])?;
12361225
let return_type = func.return_type()?;
@@ -1244,97 +1233,133 @@ impl ArrayAggDesc {
12441233
}
12451234

12461235
fn create_evaluator(&self) -> ArrayAggEvaluator {
1247-
ArrayAggEvaluator::new(self.func.clone(), self.state_layout.clone())
1236+
ArrayAggEvaluator::new(&self.func, &self.state_layout)
12481237
}
12491238
}
12501239

12511240
struct ArrayAggFunctionImpl {
12521241
desc: Option<ArrayAggDesc>,
1253-
is_count: bool,
1242+
return_type: DataType,
12541243
}
12551244

12561245
impl ArrayAggFunctionImpl {
1257-
fn try_create(name: &'static str, arg_type: &DataType) -> Option<Self> {
1258-
let desc = match arg_type {
1259-
DataType::Nullable(box DataType::EmptyArray) | DataType::EmptyArray => None,
1260-
DataType::Nullable(box DataType::Array(inner)) | DataType::Array(inner) => {
1261-
ArrayAggDesc::try_create(name, inner).ok()
1262-
}
1263-
DataType::Nullable(box DataType::Variant) | DataType::Variant => {
1264-
ArrayAggDesc::try_create(name, &DataType::Variant).ok()
1246+
fn new(name: &'static str, arg_type: &DataType) -> Option<Self> {
1247+
let (desc, return_type) = match arg_type {
1248+
DataType::Nullable(box DataType::EmptyArray) | DataType::EmptyArray => (
1249+
None,
1250+
if name == "count" {
1251+
UInt64Type::data_type()
1252+
} else {
1253+
DataType::Null
1254+
},
1255+
),
1256+
DataType::Nullable(box DataType::Array(box array_type))
1257+
| DataType::Array(box array_type)
1258+
| DataType::Nullable(box array_type @ DataType::Variant)
1259+
| array_type @ DataType::Variant => {
1260+
let desc = ArrayAggDesc::new(name, array_type).ok()?;
1261+
let return_type = desc.return_type.clone();
1262+
(Some(desc), return_type)
12651263
}
12661264
_ => return None,
12671265
};
12681266
Some(Self {
12691267
desc,
1270-
is_count: name == "count",
1268+
return_type: if arg_type.is_nullable() {
1269+
return_type.wrap_nullable()
1270+
} else {
1271+
return_type
1272+
},
12711273
})
12721274
}
12731275

1274-
fn return_type(&self) -> Cow<DataType> {
1275-
match &self.desc {
1276-
Some(desc) => Cow::Borrowed(&desc.return_type),
1277-
None => {
1278-
let data_type = if self.is_count {
1279-
DataType::Number(NumberDataType::UInt64)
1280-
} else {
1281-
DataType::Null
1282-
};
1283-
Cow::Owned(data_type)
1284-
}
1285-
}
1286-
}
1287-
12881276
fn eval(&self, args: &[Value<AnyType>], ctx: &mut EvalContext) -> Value<AnyType> {
1289-
match &self.desc {
1290-
None => match args {
1291-
[_] => Value::Scalar(Scalar::default_value(&self.return_type())),
1277+
let Some(desc) = &self.desc else {
1278+
return match args {
1279+
[_] => Value::Scalar(Scalar::default_value(&self.return_type)),
12921280
_ => unreachable!(),
1293-
},
1294-
Some(desc) => match args {
1295-
[Value::Scalar(scalar)] => match scalar {
1296-
Scalar::EmptyArray | Scalar::Null => {
1297-
Value::Scalar(Scalar::default_value(&self.return_type()))
1281+
};
1282+
};
1283+
1284+
match args {
1285+
[Value::Scalar(Scalar::Null | Scalar::EmptyArray)] => {
1286+
Value::Scalar(Scalar::default_value(&self.return_type))
1287+
}
1288+
[Value::Scalar(scalar @ Scalar::Array(_) | scalar @ Scalar::Variant(_))] => {
1289+
scalar_to_array_column(scalar.as_ref())
1290+
.and_then(|col| {
1291+
let mut evaluator = desc.create_evaluator();
1292+
let mut builder = ColumnBuilder::with_capacity(&desc.return_type, 1);
1293+
evaluator.eval(col.into(), &mut builder)?;
1294+
Ok(Value::Scalar(builder.build_scalar()))
1295+
})
1296+
.unwrap_or_else(|err| {
1297+
ctx.set_error(0, err.to_string());
1298+
Value::Scalar(Scalar::default_value(&self.return_type))
1299+
})
1300+
}
1301+
[Value::Scalar(_)] => unreachable!(),
1302+
[Value::Column(Column::Nullable(box column))]
1303+
if desc.return_type != self.return_type =>
1304+
{
1305+
let mut builder = ColumnBuilder::with_capacity(&self.return_type, column.len());
1306+
let mut evaluator = desc.create_evaluator();
1307+
let ColumnBuilder::Nullable(box nullable) = &mut builder else {
1308+
unreachable!()
1309+
};
1310+
for (row_index, scalar) in column.iter().enumerate() {
1311+
let Some(scalar) = scalar else {
1312+
nullable.push_null();
1313+
continue;
1314+
};
1315+
1316+
let col = match scalar_to_array_column(scalar) {
1317+
Ok(col) => col,
1318+
Err(err) => {
1319+
ctx.set_error(row_index, err.to_string());
1320+
nullable.push_null();
1321+
continue;
1322+
}
1323+
};
1324+
1325+
nullable.validity.push(true);
1326+
if let Err(err) = evaluator.eval(col.into(), &mut nullable.builder) {
1327+
ctx.set_error(row_index, err.to_string());
1328+
if nullable.builder.len() == row_index {
1329+
nullable.builder.push_default();
1330+
}
12981331
}
1299-
Scalar::Array(_) | Scalar::Variant(_) => {
1300-
scalar_to_array_column(scalar.as_ref())
1301-
.and_then(|col| {
1302-
let mut evaluator = desc.create_evaluator();
1303-
let mut result_builder =
1304-
ColumnBuilder::with_capacity(&desc.return_type, 1);
1305-
1306-
evaluator.eval_column(col, &mut result_builder)?;
1307-
Ok(Value::Scalar(result_builder.build_scalar()))
1308-
})
1309-
.unwrap_or_else(|err| {
1310-
ctx.set_error(0, err.to_string());
1311-
Value::Scalar(Scalar::default_value(&desc.return_type))
1312-
})
1332+
}
1333+
Value::Column(builder.build())
1334+
}
1335+
[Value::Column(column)] => {
1336+
let mut builder = ColumnBuilder::with_capacity(&self.return_type, column.len());
1337+
let mut evaluator = desc.create_evaluator();
1338+
for (row_index, scalar) in column.iter().enumerate() {
1339+
if scalar == ScalarRef::Null {
1340+
builder.push_default();
1341+
continue;
13131342
}
1314-
_ => unreachable!(),
1315-
},
1316-
[Value::Column(column)] => {
1317-
let mut builder =
1318-
ColumnBuilder::with_capacity(&self.return_type(), column.len());
1319-
let mut evaluator = desc.create_evaluator();
1320-
for (row_index, scalar) in column.iter().enumerate() {
1321-
if scalar == ScalarRef::Null {
1343+
1344+
let col = match scalar_to_array_column(scalar) {
1345+
Ok(col) => col,
1346+
Err(err) => {
1347+
ctx.set_error(row_index, err.to_string());
13221348
builder.push_default();
13231349
continue;
13241350
}
1325-
if let Err(err) = scalar_to_array_column(scalar)
1326-
.and_then(|col| evaluator.eval_column(col, &mut builder))
1327-
{
1328-
ctx.set_error(row_index, err.to_string());
1329-
if builder.len() == row_index {
1330-
builder.push_default();
1331-
}
1351+
};
1352+
1353+
if let Err(err) = evaluator.eval(col.into(), &mut builder) {
1354+
ctx.set_error(row_index, err.to_string());
1355+
if builder.len() == row_index {
1356+
builder.push_default();
13321357
}
13331358
}
1334-
Value::Column(builder.build())
13351359
}
1336-
_ => unreachable!(),
1337-
},
1360+
Value::Column(builder.build())
1361+
}
1362+
_ => unreachable!(),
13381363
}
13391364
}
13401365
}
@@ -1363,8 +1388,8 @@ fn register_array_aggr(registry: &mut FunctionRegistry) {
13631388
let [arg] = args_type else {
13641389
return None;
13651390
};
1366-
let impl_info = ArrayAggFunctionImpl::try_create(name, arg)?;
1367-
let return_type = impl_info.return_type().into_owned();
1391+
let impl_info = ArrayAggFunctionImpl::new(name, arg)?;
1392+
let return_type = impl_info.return_type.clone();
13681393
Some(Arc::new(Function {
13691394
signature: FunctionSignature {
13701395
name: fn_name.to_string(),

src/query/functions/tests/it/scalars/array.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
use std::io::Write;
1616

1717
use databend_common_expression::types::*;
18+
use databend_common_expression::ColumnBuilder;
1819
use databend_common_expression::FromData;
20+
use databend_common_expression::Scalar;
1921
use goldenfile::Mint;
2022

2123
use super::run_ast;
@@ -430,6 +432,20 @@ fn test_array_count(file: &mut impl Write) {
430432
("d", Int16Type::from_data(vec![4i16, 8, 1, 9])),
431433
]);
432434

435+
{
436+
let data_type = DataType::Array(Box::new(Int16Type::data_type())).wrap_nullable();
437+
let mut builder = ColumnBuilder::with_capacity(&data_type, 4);
438+
439+
builder.push_default();
440+
builder.push(Scalar::Array(Int16Type::from_data(vec![1, 5, 8, 3])).as_ref());
441+
builder.push(Scalar::Array(Int16Type::from_data(vec![1, 5])).as_ref());
442+
builder.push_default();
443+
444+
let column = builder.build();
445+
446+
run_ast(file, "array_count(a)", &[("a", column)]);
447+
}
448+
433449
run_ast(file, "array_count([a, b, c, d])", &[
434450
(
435451
"a",
@@ -464,6 +480,20 @@ fn test_array_max(file: &mut impl Write) {
464480
run_ast(file, "array_max(['a', 'b', 'c', 'd', 'e'])", &[]);
465481
run_ast(file, "array_max(['a', 'b', NULL, 'c', 'd', NULL])", &[]);
466482

483+
{
484+
let data_type = DataType::Array(Box::new(Int16Type::data_type())).wrap_nullable();
485+
let mut builder = ColumnBuilder::with_capacity(&data_type, 4);
486+
487+
builder.push_default();
488+
builder.push(Scalar::Array(Int16Type::from_data(vec![1, 5, 8, 3])).as_ref());
489+
builder.push(Scalar::Array(Int16Type::from_data(vec![1, 5])).as_ref());
490+
builder.push_default();
491+
492+
let column = builder.build();
493+
494+
run_ast(file, "array_max(a)", &[("a", column)]);
495+
}
496+
467497
run_ast(file, "array_max([a, b, c, d])", &[
468498
("a", Int16Type::from_data(vec![1i16, 5, 8, 3])),
469499
("b", Int16Type::from_data(vec![2i16, 6, 1, 2])),

0 commit comments

Comments
 (0)