Skip to content

Commit 49004bb

Browse files
committed
fix grouped sum index
1 parent 4e6e9ed commit 49004bb

2 files changed

Lines changed: 23 additions & 2 deletions

File tree

vortex-array/src/aggregate_fn/accumulator_grouped.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,12 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
229229
)?;
230230
let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());
231231

232-
for (offset, size) in offsets.iter().zip(sizes.iter()) {
232+
for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() {
233233
let offset = offset.to_usize().vortex_expect("Offset value is not usize");
234234
let size = size.to_usize().vortex_expect("Size value is not usize");
235235

236-
if validity.value(offset) {
236+
// validity is for the outer list view, so it must be indexed with `i`
237+
if validity.value(i) {
237238
let group = elements.slice(offset..offset + size)?;
238239
accumulator.accumulate(&group, ctx)?;
239240
states.append_scalar(&accumulator.flush()?)?;

vortex-array/src/aggregate_fn/fns/sum/mod.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ mod tests {
347347
use crate::arrays::ConstantArray;
348348
use crate::arrays::DecimalArray;
349349
use crate::arrays::FixedSizeListArray;
350+
use crate::arrays::ListViewArray;
350351
use crate::arrays::PrimitiveArray;
351352
use crate::assert_arrays_eq;
352353
use crate::dtype::DType;
@@ -616,6 +617,25 @@ mod tests {
616617
Ok(())
617618
}
618619

620+
#[test]
621+
fn grouped_sum_listview_out_of_order_offsets_with_null_group() -> VortexResult<()> {
622+
let elements =
623+
PrimitiveArray::new(buffer![100i32, 200, 300], Validity::NonNullable).into_array();
624+
let offsets = PrimitiveArray::new(buffer![2i32, 0, 1], Validity::NonNullable).into_array();
625+
let sizes = PrimitiveArray::new(buffer![1i32, 1, 1], Validity::NonNullable).into_array();
626+
let validity = Validity::from_iter([true, false, true]);
627+
let groups = ListViewArray::try_new(elements, offsets, sizes, validity)?.into_array();
628+
629+
let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
630+
let result = run_grouped_sum(&groups, &elem_dtype)?;
631+
632+
// group 0 -> elements[2..3] = 300; group 1 -> null; group 2 -> elements[1..2] = 200.
633+
let expected =
634+
PrimitiveArray::from_option_iter([Some(300i64), None, Some(200i64)]).into_array();
635+
assert_arrays_eq!(&result, &expected);
636+
Ok(())
637+
}
638+
619639
// Chunked array tests
620640

621641
#[test]

0 commit comments

Comments
 (0)