diff --git a/vortex-array/src/arrays/listview/compute/mod.rs b/vortex-array/src/arrays/listview/compute/mod.rs index 9a43503c4b5..87587495f7a 100644 --- a/vortex-array/src/arrays/listview/compute/mod.rs +++ b/vortex-array/src/arrays/listview/compute/mod.rs @@ -6,3 +6,4 @@ mod mask; pub(crate) mod rules; mod slice; mod take; +mod zip; diff --git a/vortex-array/src/arrays/listview/compute/zip.rs b/vortex-array/src/arrays/listview/compute/zip.rs new file mode 100644 index 00000000000..7093c3e0daa --- /dev/null +++ b/vortex-array/src/arrays/listview/compute/zip.rs @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::ops::BitAnd; +use std::ops::BitOr; +use std::ops::Not; + +use vortex_buffer::Buffer; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::array::ArrayView; +use crate::arrays::Chunked; +use crate::arrays::ChunkedArray; +use crate::arrays::ListView; +use crate::arrays::ListViewArray; +use crate::arrays::chunked::ChunkedArrayExt; +use crate::arrays::listview::ListViewArrayExt; +use crate::builtins::ArrayBuiltins; +use crate::dtype::DType; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::scalar_fn::fns::zip::ZipKernel; +use crate::validity::Validity; + +/// Zip two [`ListViewArray`]s by selecting whole list views per row. +/// +/// A [`ListViewArray`] addresses each list by an `(offset, size)` pair into a shared `elements` +/// array, and unlike [`ListArray`](crate::arrays::ListArray) it does not require lists to be stored +/// contiguously or in order. Zipping two list views is therefore a metadata-only operation over the +/// `offsets`, `sizes` and `validity` child arrays: we concatenate the two `elements` arrays +/// (without rewriting them) and, for each row, select the `(offset, size)` pair from `if_true` or +/// `if_false` per the mask. `if_false` views are shifted past the end of `if_true`'s elements so +/// they continue to address the correct half of the concatenated elements array. +impl ZipKernel for ListView { + fn zip( + if_true: ArrayView<'_, ListView>, + if_false: &ArrayRef, + mask: &ArrayRef, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let Some(if_false) = if_false.as_opt::() else { + return Ok(None); + }; + + // Null mask entries select `if_false`, matching `Zip`'s SQL ELSE semantics. + let mask = mask.try_to_mask_fill_null_false(ctx)?; + match &mask { + // Defer the trivial masks to the generic zip, which just casts one side. + Mask::AllTrue(_) | Mask::AllFalse(_) => return Ok(None), + Mask::Values(_) => {} + } + + let len = if_true.len(); + + let result_elements_dtype = if_true + .elements() + .dtype() + .union_nullability(if_false.elements().dtype().nullability()); + + // `if_false`'s elements share the element dtype up to nullability; normalize so both chunks + // of the concatenated elements array have an identical dtype. + let true_elements = if_true.elements().cast(result_elements_dtype.clone())?; + let false_elements = if_false.elements().cast(result_elements_dtype.clone())?; + + // `if_false` views index into the second half of the concatenated elements. + let false_shift = true_elements.len() as u64; + + // Concatenate the two `elements` arrays without copying. If either side is already a + // `ChunkedArray` (e.g. the result of a previous list-view zip), splice its chunks in + // directly rather than nesting chunked arrays. + let mut chunks = Vec::with_capacity(2); + push_element_chunks(true_elements, &mut chunks); + push_element_chunks(false_elements, &mut chunks); + let elements = ChunkedArray::try_new(chunks, result_elements_dtype)?.into_array(); + + let true_offsets = to_u64(if_true.offsets(), ctx)?; + let true_sizes = to_u64(if_true.sizes(), ctx)?; + let false_offsets = to_u64(if_false.offsets(), ctx)?; + let false_sizes = to_u64(if_false.sizes(), ctx)?; + + let mut offsets = BufferMut::::with_capacity(len); + let mut sizes = BufferMut::::with_capacity(len); + for (idx, (out_offsets, out_sizes)) in offsets + .spare_capacity_mut() + .iter_mut() + .zip(sizes.spare_capacity_mut().iter_mut()) + .take(len) + .enumerate() + { + if mask.value(idx) { + out_offsets.write(true_offsets[idx]); + out_sizes.write(true_sizes[idx]); + } else { + out_offsets.write(false_offsets[idx] + false_shift); + out_sizes.write(false_sizes[idx]); + } + } + + // SAFETY: the loop above initialized exactly `len` slots in both buffers. + unsafe { + offsets.set_len(len); + sizes.set_len(len); + } + + let validity = zip_validity(if_true.validity()?, if_false.validity()?, &mask, ctx)?; + + Ok(Some( + ListViewArray::try_new( + elements, + offsets.freeze().into_array(), + sizes.freeze().into_array(), + validity, + )? + .into_array(), + )) + } +} + +/// Appends `array`'s element chunks to `chunks`, flattening a top-level [`ChunkedArray`] so the +/// concatenated elements never nest chunked arrays. +fn push_element_chunks(array: ArrayRef, chunks: &mut Vec) { + match array.as_opt::() { + Some(chunked) => chunks.extend(chunked.iter_chunks().cloned()), + None => chunks.push(array), + } +} + +/// Read a non-nullable integer array into a `u64` buffer. +fn to_u64(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult> { + array + .clone() + .cast(DType::Primitive(PType::U64, Nullability::NonNullable))? + .execute::>(ctx) +} + +/// Combine the two list-level validities, taking `if_true`'s validity where `mask` is set and +/// `if_false`'s where it is not. +fn zip_validity( + if_true: Validity, + if_false: Validity, + mask: &Mask, + ctx: &mut ExecutionCtx, +) -> VortexResult { + Ok(match (&if_true, &if_false) { + (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable, + (Validity::AllValid, Validity::AllValid) => Validity::AllValid, + (Validity::AllInvalid, Validity::AllInvalid) => Validity::AllInvalid, + _ => { + let true_mask = if_true.execute_mask(mask.len(), ctx)?; + let false_mask = if_false.execute_mask(mask.len(), ctx)?; + let combined = true_mask + .bitand(mask) + .bitor(&false_mask.bitand(&mask.not())); + Validity::from_mask(combined, if_true.nullability() | if_false.nullability()) + } + }) +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + use vortex_mask::Mask; + + use crate::ArrayRef; + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::arrays::BoolArray; + use crate::arrays::Chunked; + use crate::arrays::ChunkedArray; + use crate::arrays::ListView; + use crate::arrays::ListViewArray; + use crate::arrays::chunked::ChunkedArrayExt; + use crate::arrays::listview::ListViewArrayExt; + use crate::assert_arrays_eq; + use crate::builtins::ArrayBuiltins; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::validity::Validity; + + fn list_view( + elements: ArrayRef, + offsets: ArrayRef, + sizes: ArrayRef, + validity: Validity, + ) -> ArrayRef { + ListViewArray::try_new(elements, offsets, sizes, validity) + .unwrap() + .into_array() + } + + /// `zip` of two list views selects whole lists per the mask and keeps the list encoding. + #[test] + fn zip_selects_lists() -> VortexResult<()> { + // [[1, 2], [3], [4, 5, 6]] + let if_true = list_view( + buffer![1i32, 2, 3, 4, 5, 6].into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 3].into_array(), + Validity::NonNullable, + ); + // [[10], [20, 21], [30]] + let if_false = list_view( + buffer![10i32, 20, 21, 30].into_array(), + buffer![0u32, 1, 3].into_array(), + buffer![1u32, 2, 1].into_array(), + Validity::NonNullable, + ); + let mask = Mask::from_iter([true, false, true]); + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mask + .into_array() + .zip(if_true, if_false)? + .execute::(&mut ctx)?; + + // The kernel should keep the list-view encoding rather than canonicalizing. + assert!(result.is::()); + + // Expected: [[1, 2], [20, 21], [4, 5, 6]] + let expected = list_view( + buffer![1i32, 2, 20, 21, 4, 5, 6].into_array(), + buffer![0u32, 2, 4].into_array(), + buffer![2u32, 2, 3].into_array(), + Validity::NonNullable, + ); + assert_arrays_eq!(result, expected); + Ok(()) + } + + /// `zip` selects list-level validity from the chosen side and widens nullability. + #[test] + fn zip_selects_validity() -> VortexResult<()> { + // [[1], null, [2]] (list-level nulls) + let if_true = list_view( + buffer![1i32, 2].into_array(), + buffer![0u32, 1, 1].into_array(), + buffer![1u32, 0, 1].into_array(), + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + ); + // [[10], [20], null] + let if_false = list_view( + buffer![10i32, 20].into_array(), + buffer![0u32, 1, 2].into_array(), + buffer![1u32, 1, 0].into_array(), + Validity::Array(BoolArray::from_iter([true, true, false]).into_array()), + ); + // true -> if_true, false -> if_false + let mask = Mask::from_iter([false, true, true]); + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mask + .into_array() + .zip(if_true, if_false)? + .execute::(&mut ctx)?; + + // Row 0 -> if_false[0] = [10]; row 1 -> if_true[1] = null; row 2 -> if_true[2] = [2] + let expected = list_view( + buffer![10i32, 2].into_array(), + buffer![0u32, 1, 1].into_array(), + buffer![1u32, 0, 1].into_array(), + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + ); + assert_arrays_eq!(result, expected); + Ok(()) + } + + /// `zip` handles out-of-order/non-contiguous offsets and widens nullability when only one side + /// is nullable. + #[test] + fn zip_out_of_order_offsets_and_widening() -> VortexResult<()> { + // [[5, 6], [7], [8, 9]] expressed with out-of-order offsets. + let if_true = list_view( + buffer![7i32, 8, 9, 5, 6].into_array(), + buffer![3u32, 0, 1].into_array(), + buffer![2u32, 1, 2].into_array(), + Validity::NonNullable, + ); + // [[100], null, [200, 201]] + let if_false = list_view( + buffer![100i32, 200, 201].into_array(), + buffer![0u32, 1, 1].into_array(), + buffer![1u32, 0, 2].into_array(), + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + ); + let mask = Mask::from_iter([true, true, false]); + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mask + .into_array() + .zip(if_true, if_false)? + .execute::(&mut ctx)?; + assert!(result.is::()); + + // [[5, 6], [7], [200, 201]], all valid but nullable (widened by if_false). + let expected = list_view( + buffer![5i32, 6, 7, 200, 201].into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 2].into_array(), + Validity::AllValid, + ); + assert_arrays_eq!(result, expected); + Ok(()) + } + + /// When an input's `elements` is already a [`ChunkedArray`], its chunks are spliced in rather + /// than nesting a chunked array inside the concatenated elements. + #[test] + fn zip_flattens_chunked_elements() -> VortexResult<()> { + // elements [1, 2, 3] stored as two chunks; lists [[1, 2], [3]]. + let chunked_elements = ChunkedArray::try_new( + vec![buffer![1i32, 2].into_array(), buffer![3i32].into_array()], + DType::Primitive(PType::I32, Nullability::NonNullable), + )? + .into_array(); + let if_true = list_view( + chunked_elements, + buffer![0u32, 2].into_array(), + buffer![2u32, 1].into_array(), + Validity::NonNullable, + ); + // [[10], [20]] + let if_false = list_view( + buffer![10i32, 20].into_array(), + buffer![0u32, 1].into_array(), + buffer![1u32, 1].into_array(), + Validity::NonNullable, + ); + let mask = Mask::from_iter([true, false]); + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mask + .into_array() + .zip(if_true, if_false)? + .execute::(&mut ctx)?; + + // The concatenated elements are chunked, but no chunk is itself a `ChunkedArray`. + let result_lv = result + .as_opt::() + .expect("zip keeps the list-view encoding"); + let chunked = result_lv + .elements() + .as_opt::() + .expect("zip concatenates elements into a chunked array"); + assert!( + chunked.iter_chunks().all(|chunk| !chunk.is::()), + "chunked elements must be flattened, not nested", + ); + + // [[1, 2], [20]] + let expected = list_view( + buffer![1i32, 2, 20].into_array(), + buffer![0u32, 2].into_array(), + buffer![2u32, 1].into_array(), + Validity::NonNullable, + ); + assert_arrays_eq!(result, expected); + Ok(()) + } +} diff --git a/vortex-array/src/arrays/listview/tests/operations.rs b/vortex-array/src/arrays/listview/tests/operations.rs index 235caf53caa..c911b9ba7a5 100644 --- a/vortex-array/src/arrays/listview/tests/operations.rs +++ b/vortex-array/src/arrays/listview/tests/operations.rs @@ -5,11 +5,13 @@ use std::sync::Arc; use rstest::rstest; use vortex_buffer::buffer; +use vortex_error::VortexResult; use vortex_mask::Mask; use super::common::create_basic_listview; use super::common::create_large_listview; use super::common::create_nullable_listview; +use crate::ArrayRef; use crate::IntoArray; use crate::LEGACY_SESSION; #[expect(deprecated)] @@ -382,6 +384,100 @@ fn test_cast_large_dataset() { } } +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Zip tests +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#[test] +fn test_zip_widens_false_element_nullability() -> VortexResult<()> { + // [[1, 2], [3], [4]] + let if_true = ListViewArray::new( + buffer![1i32, 2, 3, 4].into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 1].into_array(), + Validity::NonNullable, + ) + .into_array(); + // [[10, null], [30], [40]] + let if_false = ListViewArray::new( + PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), Some(40)]).into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 1].into_array(), + Validity::NonNullable, + ) + .into_array(); + let mask = Mask::from_iter([false, true, false]); + + let result = mask + .into_array() + .zip(if_true, if_false)? + .execute::(&mut LEGACY_SESSION.create_execution_ctx())?; + assert!(result.is::()); + assert_eq!( + result.dtype(), + &DType::List( + Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)), + Nullability::NonNullable, + ) + ); + + // [[10, null], [3], [40]] + let expected = ListViewArray::new( + PrimitiveArray::from_option_iter([Some(10i32), None, Some(3), Some(40)]).into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 1].into_array(), + Validity::NonNullable, + ) + .into_array(); + assert_arrays_eq!(result, expected); + Ok(()) +} + +#[test] +fn test_zip_widens_true_element_nullability() -> VortexResult<()> { + // [[1, null], [3], [4]] + let if_true = ListViewArray::new( + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4)]).into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 1].into_array(), + Validity::NonNullable, + ) + .into_array(); + // [[10], [20], [30]] + let if_false = ListViewArray::new( + buffer![10i32, 20, 30].into_array(), + buffer![0u32, 1, 2].into_array(), + buffer![1u32, 1, 1].into_array(), + Validity::NonNullable, + ) + .into_array(); + let mask = Mask::from_iter([true, false, true]); + + let result = mask + .into_array() + .zip(if_true, if_false)? + .execute::(&mut LEGACY_SESSION.create_execution_ctx())?; + assert!(result.is::()); + assert_eq!( + result.dtype(), + &DType::List( + Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)), + Nullability::NonNullable, + ) + ); + + // [[1, null], [20], [4]] + let expected = ListViewArray::new( + PrimitiveArray::from_option_iter([Some(1i32), None, Some(20), Some(4)]).into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 1].into_array(), + Validity::NonNullable, + ) + .into_array(); + assert_arrays_eq!(result, expected); + Ok(()) +} + //////////////////////////////////////////////////////////////////////////////////////////////////// // Constant tests //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/vortex-array/src/arrays/listview/vtable/kernel.rs b/vortex-array/src/arrays/listview/vtable/kernel.rs index f6ceca284bf..1ad98f62a33 100644 --- a/vortex-array/src/arrays/listview/vtable/kernel.rs +++ b/vortex-array/src/arrays/listview/vtable/kernel.rs @@ -4,6 +4,9 @@ use crate::arrays::ListView; use crate::kernel::ParentKernelSet; use crate::scalar_fn::fns::cast::CastExecuteAdaptor; +use crate::scalar_fn::fns::zip::ZipExecuteAdaptor; -pub(super) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&CastExecuteAdaptor(ListView))]); +pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&CastExecuteAdaptor(ListView)), + ParentKernelSet::lift(&ZipExecuteAdaptor(ListView)), +]); diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 86b5c4f7dc1..b1b17e40bb6 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -90,20 +90,12 @@ impl ScalarFnVTable for Zip { } fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { - vortex_ensure!( - arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]), - "zip requires if_true and if_false to have the same base type, got {} and {}", - arg_dtypes[0], - arg_dtypes[1] - ); vortex_ensure!( matches!(arg_dtypes[2], DType::Bool(_)), "zip requires mask to be a boolean type, got {}", arg_dtypes[2] ); - Ok(arg_dtypes[0] - .clone() - .union_nullability(arg_dtypes[1].nullability())) + zip_return_dtype(&arg_dtypes[0], &arg_dtypes[1]) } fn execute( @@ -120,10 +112,7 @@ impl ScalarFnVTable for Zip { .execute::(ctx)? .to_mask_fill_null_false(ctx); - let return_dtype = if_true - .dtype() - .clone() - .union_nullability(if_false.dtype().nullability()); + let return_dtype = zip_return_dtype(if_true.dtype(), if_false.dtype())?; if mask.all_true() { return if_true.cast(return_dtype)?.execute(ctx); @@ -184,10 +173,7 @@ pub(crate) fn zip_impl( "zip requires arrays to have the same size" ); - let return_type = if_true - .dtype() - .clone() - .union_nullability(if_false.dtype().nullability()); + let return_type = zip_return_dtype(if_true.dtype(), if_false.dtype())?; if mask.all_true() { return if_true.cast(return_type); @@ -211,6 +197,18 @@ pub(crate) fn zip_impl( ) } +fn zip_return_dtype(if_true: &DType, if_false: &DType) -> VortexResult { + vortex_ensure!( + if_true.eq_ignore_nullability(if_false), + "zip requires if_true and if_false to have the same base type, got {} and {}", + if_true, + if_false + ); + Ok(if_true + .least_supertype(if_false) + .vortex_expect("zip inputs with the same base type must have a common dtype")) +} + fn zip_impl_with_builder( if_true: &ArrayRef, if_false: &ArrayRef,