diff --git a/Cargo.lock b/Cargo.lock index 79e0c622cae..b6f67682c4a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10348,8 +10348,11 @@ dependencies = [ name = "vortex-json" version = "0.1.0" dependencies = [ + "arrow-array", + "arrow-schema", "vortex-array", "vortex-error", + "vortex-session", ] [[package]] diff --git a/encodings/datetime-parts/src/canonical.rs b/encodings/datetime-parts/src/canonical.rs index 78dc8e689c1..3a46eb6f328 100644 --- a/encodings/datetime-parts/src/canonical.rs +++ b/encodings/datetime-parts/src/canonical.rs @@ -143,12 +143,11 @@ mod test { &mut ctx, )?; - assert!( - date_times - .as_array() - .validity()? - .mask_eq(&validity, &mut ctx)? - ); + assert!(date_times.as_array().validity()?.mask_eq( + &validity, + milliseconds.len(), + &mut ctx + )?); let dtype = date_times.dtype().clone(); let parts = DateTimePartsParts { @@ -163,7 +162,6 @@ mod test { .execute::(&mut ctx)?; assert_arrays_eq!(primitive_values, milliseconds); - assert!(primitive_values.validity()?.mask_eq(&validity, &mut ctx)?); Ok(()) } } diff --git a/encodings/datetime-parts/src/compress.rs b/encodings/datetime-parts/src/compress.rs index 676e7bbfd43..6caf3c9f643 100644 --- a/encodings/datetime-parts/src/compress.rs +++ b/encodings/datetime-parts/src/compress.rs @@ -103,7 +103,7 @@ mod tests { days_prim .validity() .vortex_expect("days validity should be derivable") - .mask_eq(&validity, &mut ctx) + .mask_eq(&validity, days_prim.len(), &mut ctx) .unwrap() ); let seconds_prim = seconds.execute::(&mut ctx).unwrap(); diff --git a/encodings/pco/src/tests.rs b/encodings/pco/src/tests.rs index 6e5b4f9841b..20fec7dd9e0 100644 --- a/encodings/pco/src/tests.rs +++ b/encodings/pco/src/tests.rs @@ -149,6 +149,7 @@ fn test_validity_and_multiple_chunks_and_pages() { .unwrap() .mask_eq( &Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array()), + primitive.len(), &mut ctx, ) .unwrap() diff --git a/encodings/zstd/src/test.rs b/encodings/zstd/src/test.rs index 7ed22886b82..094c5ef60cc 100644 --- a/encodings/zstd/src/test.rs +++ b/encodings/zstd/src/test.rs @@ -89,7 +89,7 @@ fn test_zstd_with_validity_and_multi_frame() { decompressed .validity() .unwrap() - .mask_eq(&array.validity().unwrap(), &mut ctx) + .mask_eq(&array.validity().unwrap(), decompressed.len(), &mut ctx) .unwrap() ); @@ -106,6 +106,7 @@ fn test_zstd_with_validity_and_multi_frame() { .unwrap() .mask_eq( &Validity::Array(BoolArray::from_iter(vec![false, true, false]).into_array()), + primitive.len(), &mut ctx ) .unwrap() diff --git a/vortex-array/src/arrays/dict/vtable/mod.rs b/vortex-array/src/arrays/dict/vtable/mod.rs index fa8515dd986..34aa8198cf7 100644 --- a/vortex-array/src/arrays/dict/vtable/mod.rs +++ b/vortex-array/src/arrays/dict/vtable/mod.rs @@ -46,7 +46,6 @@ use crate::executor::ExecutionResult; use crate::require_child; use crate::scalar::Scalar; use crate::serde::ArrayChildren; -use crate::validity::Validity; mod kernel; mod operations; @@ -179,7 +178,7 @@ impl VTable for Dict { let array = require_child!(array, array.codes(), DictSlots::CODES => Primitive); - if matches!(array.codes().validity()?, Validity::AllInvalid) { + if array.codes().validity()?.definitely_all_invalid() { return Ok(ExecutionResult::done(ConstantArray::new( Scalar::null(array.dtype().as_nullable()), array.codes().len(), diff --git a/vortex-array/src/arrays/masked/tests.rs b/vortex-array/src/arrays/masked/tests.rs index 2721ba9b519..b26d101eb33 100644 --- a/vortex-array/src/arrays/masked/tests.rs +++ b/vortex-array/src/arrays/masked/tests.rs @@ -134,7 +134,7 @@ fn test_masked_child_preserves_length(#[case] validity: Validity) { array .validity() .vortex_expect("masked validity should be derivable") - .mask_eq(&validity, &mut ctx) + .mask_eq(&validity, array.len(), &mut ctx) .unwrap(), ); } diff --git a/vortex-array/src/arrays/masked/vtable/mod.rs b/vortex-array/src/arrays/masked/vtable/mod.rs index 4409c8ff3c6..6fcb1a617b0 100644 --- a/vortex-array/src/arrays/masked/vtable/mod.rs +++ b/vortex-array/src/arrays/masked/vtable/mod.rs @@ -165,7 +165,7 @@ impl VTable for Masked { let validity = array.masked_validity(); // Fast path: all masked means result is all nulls. - if matches!(validity, Validity::AllInvalid) { + if validity.definitely_all_invalid() { return Ok(ExecutionResult::done( ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len()) .into_array(), diff --git a/vortex-array/src/arrays/primitive/array/top_value.rs b/vortex-array/src/arrays/primitive/array/top_value.rs index d3ee5eb5a65..e3ff2346f40 100644 --- a/vortex-array/src/arrays/primitive/array/top_value.rs +++ b/vortex-array/src/arrays/primitive/array/top_value.rs @@ -17,7 +17,6 @@ use crate::arrays::primitive::NativeValue; use crate::dtype::NativePType; use crate::match_each_native_ptype; use crate::scalar::PValue; -use crate::validity::Validity; impl PrimitiveArray { /// Compute most common present value of this array @@ -26,7 +25,7 @@ impl PrimitiveArray { return Ok(None); } - if matches!(self.validity()?, Validity::AllInvalid) { + if self.validity()?.definitely_all_invalid() { return Ok(None); } diff --git a/vortex-array/src/arrays/scalar_fn/vtable/validity.rs b/vortex-array/src/arrays/scalar_fn/vtable/validity.rs index 2ac376155e3..fd5bad28054 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/validity.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/validity.rs @@ -2,17 +2,22 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; +use vortex_error::vortex_bail; use crate::ArrayRef; use crate::IntoArray; use crate::LEGACY_SESSION; use crate::VortexSessionExecute; +use crate::array::Array; use crate::array::ArrayView; use crate::array::ValidityVTable; +use crate::array::child_to_validity; +use crate::arrays::ConstantArray; use crate::arrays::scalar_fn::ScalarFnArrayExt; use crate::arrays::scalar_fn::vtable::ArrayExpr; use crate::arrays::scalar_fn::vtable::FakeEq; use crate::arrays::scalar_fn::vtable::ScalarFn; +use crate::dtype::Nullability; use crate::expr::Expression; use crate::expr::lit; use crate::scalar_fn::TypedScalarFnInstance; @@ -21,6 +26,36 @@ use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::root::Root; use crate::validity::Validity; +/// Convert an expression tree into a lazy array DAG without executing it. +/// +/// This assumes all leaf expressions are either ArrayExpr (wrapping actual arrays) or Literals. +fn expr_to_lazy_array(expr: &Expression, row_count: usize) -> VortexResult { + // Handle Root expression - this should not happen in validity expressions + if expr.is::() { + vortex_bail!("Root expression cannot be converted in validity context"); + } + + // Handle Literal expression - create a constant array + if expr.is::() { + let scalar = expr.as_::(); + return Ok(ConstantArray::new(scalar.clone(), row_count).into_array()); + } + + // Handle ArrayExpr leaves - unwrap the array they hold + if expr.is::() { + return Ok(expr.as_::().0.clone()); + } + + // Recursively convert child expressions into lazy input arrays + let children: Vec = expr + .children() + .iter() + .map(|child| expr_to_lazy_array(child, row_count)) + .collect::>()?; + + Ok(Array::::try_new(expr.scalar_fn().clone(), children, row_count)?.into_array()) +} + /// Execute an expression tree recursively. /// /// This assumes all leaf expressions are either ArrayExpr (wrapping actual arrays) or Literals. @@ -29,13 +64,13 @@ fn execute_expr(expr: &Expression, row_count: usize) -> VortexResult { // Handle Root expression - this should not happen in validity expressions if expr.is::() { - vortex_error::vortex_bail!("Root expression cannot be executed in validity context"); + vortex_bail!("Root expression cannot be executed in validity context"); } // Handle Literal expression - create a constant array if expr.is::() { let scalar = expr.as_::(); - return Ok(crate::arrays::ConstantArray::new(scalar.clone(), row_count).into_array()); + return Ok(ConstantArray::new(scalar.clone(), row_count).into_array()); } // Recursively execute child expressions to get input arrays @@ -66,9 +101,26 @@ impl ValidityVTable for ScalarFn { .collect::>()?; let expr = Expression::try_new(array.scalar_fn().clone(), inputs)?; - let validity_expr = array.scalar_fn().validity(&expr)?; - // Execute the validity expression. All leaves are ArrayExpr nodes. - Ok(Validity::Array(execute_expr(&validity_expr, array.len())?)) + match array.scalar_fn().validity_opt(&expr)? { + Some(validity_expr) => { + // The function defines its validity as an expression over its inputs, so we can + // represent it as a lazy array DAG without executing anything. If the expression + // is already a constant it is folded back into AllValid/AllInvalid. + let validity_array = expr_to_lazy_array(&validity_expr, array.len())?; + Ok(child_to_validity( + Some(&validity_array), + Nullability::Nullable, + )) + } + None => { + // The function's validity can only be determined by executing the function + // itself (e.g. Kleene logic and/or). Representing that lazily would create a + // self-referential array (is_not_null over this very expression), so execute it + // eagerly instead. + let validity_expr = array.scalar_fn().validity(&expr)?; + Ok(Validity::Array(execute_expr(&validity_expr, array.len())?)) + } + } } } diff --git a/vortex-array/src/arrays/varbin/array.rs b/vortex-array/src/arrays/varbin/array.rs index 4d435471ce0..798336e9c4a 100644 --- a/vortex-array/src/arrays/varbin/array.rs +++ b/vortex-array/src/arrays/varbin/array.rs @@ -255,7 +255,7 @@ impl VarBinData { } _ => None, }; - let all_invalid = matches!(validity, Validity::AllInvalid); + let all_invalid = validity.definitely_all_invalid(); match_each_integer_ptype!(primitive_offsets.dtype().as_ptype(), |O| { let offsets_slice = primitive_offsets.as_slice::(); diff --git a/vortex-array/src/builders/bool.rs b/vortex-array/src/builders/bool.rs index fdae7984844..e829a58f6ec 100644 --- a/vortex-array/src/builders/bool.rs +++ b/vortex-array/src/builders/bool.rs @@ -209,11 +209,11 @@ mod tests { #[expect(deprecated)] let into_canon = chunk.to_bool(); - assert!( - canon_into - .validity()? - .mask_eq(&into_canon.validity()?, &mut ctx)? - ); + assert!(canon_into.validity()?.mask_eq( + &into_canon.validity()?, + canon_into.len(), + &mut ctx + )?); assert_eq!(canon_into.to_bit_buffer(), into_canon.to_bit_buffer()); Ok(()) } diff --git a/vortex-array/src/builders/list.rs b/vortex-array/src/builders/list.rs index c7b506f228a..ac97230daa8 100644 --- a/vortex-array/src/builders/list.rs +++ b/vortex-array/src/builders/list.rs @@ -490,6 +490,7 @@ mod tests { &expected .validity() .vortex_expect("list validity should be derivable"), + actual.len(), &mut ctx, ) .unwrap(), diff --git a/vortex-array/src/scalar_fn/erased.rs b/vortex-array/src/scalar_fn/erased.rs index 10e82d25455..5c95a4fa225 100644 --- a/vortex-array/src/scalar_fn/erased.rs +++ b/vortex-array/src/scalar_fn/erased.rs @@ -134,12 +134,22 @@ impl ScalarFnRef { /// Transforms the expression into one representing the validity of this expression. pub fn validity(&self, expr: &Expression) -> VortexResult { - Ok(self.0.validity(expr)?.unwrap_or_else(|| { + Ok(self.validity_opt(expr)?.unwrap_or_else(|| { // TODO(ngates): make validity a mandatory method on VTable to avoid this fallback. IsNotNull.new_expr(EmptyOptions, [expr.clone()]) })) } + /// Transforms the expression into one representing the validity of this expression, + /// returning `None` if the function does not define a validity expression. + /// + /// When `None` is returned, the validity can only be determined by executing the + /// expression itself (e.g. Kleene logic `and`/`or`), and [`Self::validity`] falls back to + /// `is_not_null` over the expression. + pub fn validity_opt(&self, expr: &Expression) -> VortexResult> { + self.0.validity(expr) + } + /// Execute the expression given the input arguments. pub fn execute( &self, diff --git a/vortex-array/src/scalar_fn/fns/fill_null/kernel.rs b/vortex-array/src/scalar_fn/fns/fill_null/kernel.rs index eea3dd6ef7b..9b957d86710 100644 --- a/vortex-array/src/scalar_fn/fns/fill_null/kernel.rs +++ b/vortex-array/src/scalar_fn/fns/fill_null/kernel.rs @@ -79,7 +79,7 @@ pub(super) fn precondition( } // If all values are null, replace the entire array with the fill value. - if matches!(array.validity()?, Validity::AllInvalid) { + if array.validity()?.definitely_all_invalid() { return Ok(Some( ConstantArray::new(fill_value.clone(), array.len()).into_array(), )); diff --git a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs index 978a1da1caf..4d65ecd83dd 100644 --- a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs +++ b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs @@ -415,7 +415,7 @@ fn list_is_not_empty( ctx: &mut ExecutionCtx, ) -> VortexResult { // Short-circuit for all invalid. - if matches!(list_array.validity()?, Validity::AllInvalid) { + if list_array.validity()?.definitely_all_invalid() { return Ok(ConstantArray::new( Scalar::null(DType::Bool(Nullability::Nullable)), list_array.len(), diff --git a/vortex-array/src/stats/rewrite.rs b/vortex-array/src/stats/rewrite.rs index bf342a95cdd..52d354df1a0 100644 --- a/vortex-array/src/stats/rewrite.rs +++ b/vortex-array/src/stats/rewrite.rs @@ -23,18 +23,34 @@ pub(crate) use builtins::register_builtins; /// Shared reference to a stats rewrite rule. pub(crate) type StatsRewriteRuleRef = Arc; -/// A plugin-provided rule that rewrites predicates into stats-backed proof expressions. +/// A plugin-provided rule for predicates whose root scalar function matches this rule. /// -/// A falsifier evaluates to `true` only when the original predicate is definitely false for the -/// current stats scope. A satisfier evaluates to `true` only when the original predicate is -/// definitely true for the current stats scope. Returning `None` means the rule cannot prove -/// anything for the expression. -#[allow(dead_code)] +/// Rules do not produce expressions equivalent to `expr`. They produce optional sufficient +/// conditions over stats for the current scope: +/// +/// - a falsifier evaluating to `true` proves that `expr` is false for every row in the scope; +/// - a satisfier evaluating to `true` proves that `expr` is true for every row in the scope. +/// +/// Returning `None` means this rule cannot prove anything for the expression. A returned proof +/// expression that evaluates to `false` or `null` is also inconclusive. +/// +/// Multiple rules may be registered for the same scalar function. Their proofs are combined with +/// `OR`, so every proof returned by an individual rule must be sound on its own. +/// +/// `expr` is the full predicate expression whose root scalar function id is +/// [`Self::scalar_fn_id`]. Use [`StatsRewriteCtx`] to resolve dtypes and recursively rewrite child +/// predicates. pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static { - /// The scalar function ID this rule applies to. + /// Returns the scalar function id handled by this rule. fn scalar_fn_id(&self) -> ScalarFnId; - /// Rewrite an expression into a stats-backed falsifier. + /// Returns a stats-backed proof that `expr` is false for the current scope. + /// + /// If the returned expression evaluates to `true` against the scope's stats, then `expr` is + /// guaranteed to be false for every row in that scope. A returned proof expression that + /// evaluates to `false` or `null` is inconclusive. + /// + /// Returns `Ok(None)` when this rule cannot construct a sound falsity proof for `expr`. fn falsify( &self, expr: &Expression, @@ -45,7 +61,16 @@ pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static { Ok(None) } - /// Rewrite an expression into a stats-backed satisfier. + /// Returns a stats-backed proof that `expr` is true for the current scope. + /// + /// If the returned expression evaluates to `true` against the scope's stats, then `expr` is + /// guaranteed to be true for every row in that scope. A returned proof expression that + /// evaluates to `false` or `null` is inconclusive. + /// + /// This is not the complement of [`Self::falsify`]; both methods are one-way proofs and may be + /// implemented independently. + /// + /// Returns `Ok(None)` when this rule cannot construct a sound truth proof for `expr`. fn satisfy( &self, expr: &Expression, diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index f3a77b4759e..431ac805151 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -123,6 +123,17 @@ impl Validity { matches!(self, Self::NonNullable | Self::AllValid) } + /// Returns `true` if this validity is *definitely* all-invalid, i.e. it is + /// [`Validity::AllInvalid`]. + /// + /// Returning `false` does not prove the presence of valid values: a [`Validity::Array`] may + /// still resolve to all-invalid once executed. Callers must treat `false` as "unknown + /// without compute". This is the all-invalid counterpart to [`Self::definitely_no_nulls`]. + #[inline] + pub fn definitely_all_invalid(&self) -> bool { + matches!(self, Self::AllInvalid) + } + /// Returns whether this validity contains no null values, executing the validity array if /// necessary. /// @@ -244,6 +255,7 @@ impl Validity { } } + #[inline] pub fn execute_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult { match self { Self::NonNullable | Self::AllValid => Ok(Mask::AllTrue(length)), @@ -263,18 +275,47 @@ impl Validity { } } - /// Compare two Validity values of the same length by executing them into masks if necessary. - pub fn mask_eq(&self, other: &Validity, ctx: &mut ExecutionCtx) -> VortexResult { + /// Compare the logical masks of two Validity values of the given length, executing them + /// into [`Mask`]s if necessary. + /// + /// Mixed `Array`-vs-constant pairings are answered from statistics where possible (the + /// minimum/maximum of the validity array decides all-valid/all-invalid exactly), only + /// falling back to executing the validity array when statistics are unavailable. + pub fn mask_eq( + &self, + other: &Validity, + length: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult { match (self, other) { - (Validity::NonNullable, Validity::NonNullable) => Ok(true), - (Validity::AllValid, Validity::AllValid) => Ok(true), - (Validity::AllInvalid, Validity::AllInvalid) => Ok(true), - (Validity::Array(a), Validity::Array(b)) => { - let a = a.clone().execute::(ctx)?; - let b = b.clone().execute::(ctx)?; - Ok(a == b) + // Fast paths that avoid executing: constant variants with known-equal masks. + ( + Validity::NonNullable | Validity::AllValid, + Validity::NonNullable | Validity::AllValid, + ) + | (Validity::AllInvalid, Validity::AllInvalid) => Ok(true), + // Constant variants with opposite masks: only equal when empty. + (Validity::NonNullable | Validity::AllValid, Validity::AllInvalid) + | (Validity::AllInvalid, Validity::NonNullable | Validity::AllValid) => Ok(length == 0), + // Array vs all-valid: equal iff the array's minimum is true. + (Validity::Array(a), Validity::NonNullable | Validity::AllValid) + | (Validity::NonNullable | Validity::AllValid, Validity::Array(a)) => { + match a.statistics().compute_min::(ctx) { + Some(min) => Ok(min), + None => Ok(a.clone().execute::(ctx)?.all_true()), + } + } + // Array vs all-invalid: equal iff the array's maximum is false. + (Validity::Array(a), Validity::AllInvalid) + | (Validity::AllInvalid, Validity::Array(a)) => { + match a.statistics().compute_max::(ctx) { + Some(max) => Ok(!max), + None => Ok(a.clone().execute::(ctx)?.all_false()), + } + } + (Validity::Array(_), Validity::Array(_)) => { + Ok(self.execute_mask(length, ctx)? == other.execute_mask(length, ctx)?) } - _ => Ok(false), } } @@ -703,7 +744,7 @@ mod tests { validity .patch(len, 0, &indices, &patches, &mut ctx,) .unwrap() - .mask_eq(&expected, &mut ctx) + .mask_eq(&expected, len, &mut ctx) .unwrap() ); } @@ -768,8 +809,50 @@ mod tests { validity .take(&indices) .unwrap() - .mask_eq(&expected, &mut ctx) + .mask_eq(&expected, indices.len(), &mut ctx) .unwrap() ); } + + #[rstest] + // Mixed constant variants with equal masks. + #[case(Validity::NonNullable, Validity::AllValid, true)] + #[case(Validity::AllValid, Validity::NonNullable, true)] + #[case(Validity::AllValid, Validity::AllInvalid, false)] + #[case(Validity::NonNullable, Validity::AllInvalid, false)] + // An array that resolves to a constant mask must equal the constant variant. + #[case( + Validity::Array(BoolArray::from_iter([true, true, true]).into_array()), + Validity::AllValid, + true + )] + #[case( + Validity::NonNullable, + Validity::Array(BoolArray::from_iter([true, true, true]).into_array()), + true + )] + #[case( + Validity::Array(BoolArray::from_iter([false, false, false]).into_array()), + Validity::AllInvalid, + true + )] + #[case( + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + Validity::AllValid, + false + )] + #[case( + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + Validity::AllInvalid, + false + )] + fn mask_eq_mixed_variants( + #[case] lhs: Validity, + #[case] rhs: Validity, + #[case] expected: bool, + ) -> vortex_error::VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(lhs.mask_eq(&rhs, 3, &mut ctx)?, expected); + Ok(()) + } } diff --git a/vortex-cuda/src/dynamic_dispatch/mod.rs b/vortex-cuda/src/dynamic_dispatch/mod.rs index 186f2297db1..66dd8de7371 100644 --- a/vortex-cuda/src/dynamic_dispatch/mod.rs +++ b/vortex-cuda/src/dynamic_dispatch/mod.rs @@ -32,7 +32,6 @@ use vortex::array::buffer::BufferHandle; use vortex::array::buffer::DeviceBufferExt; use vortex::array::match_each_unsigned_integer_ptype; use vortex::array::scalar::Scalar; -use vortex::array::validity::Validity; use vortex::buffer::Alignment; use vortex::buffer::ByteBuffer; use vortex::buffer::ByteBufferMut; @@ -434,7 +433,7 @@ impl MaterializedPlan { let output_ptype = self.dispatch_plan.output_ptype(); // All values are null — no need to touch the GPU. - if matches!(self.validity, Validity::AllInvalid) { + if self.validity.definitely_all_invalid() { let dtype = DType::Primitive(output_ptype, Nullability::Nullable); return ConstantArray::new(Scalar::null(dtype), len) .into_array() diff --git a/vortex-cuda/src/kernel/encodings/date_time_parts.rs b/vortex-cuda/src/kernel/encodings/date_time_parts.rs index d8c655c2558..3334ee34b9b 100644 --- a/vortex-cuda/src/kernel/encodings/date_time_parts.rs +++ b/vortex-cuda/src/kernel/encodings/date_time_parts.rs @@ -72,7 +72,7 @@ impl CudaExecute for DateTimePartsExecutor { return Ok(Canonical::empty(array.dtype())); } - if matches!(validity, Validity::AllInvalid) { + if validity.definitely_all_invalid() { let storage_ptype = ext.storage_dtype().as_ptype(); return Ok(Canonical::Extension( TemporalArray::new_timestamp( diff --git a/vortex-cuda/src/kernel/encodings/fsst.rs b/vortex-cuda/src/kernel/encodings/fsst.rs index 5d3d66eaf04..f511e6aabba 100644 --- a/vortex-cuda/src/kernel/encodings/fsst.rs +++ b/vortex-cuda/src/kernel/encodings/fsst.rs @@ -23,7 +23,6 @@ use vortex::array::arrays::varbinview::build_views::build_views; use vortex::array::buffer::DeviceBuffer; use vortex::array::match_each_integer_ptype; use vortex::array::match_each_unsigned_integer_ptype; -use vortex::array::validity::Validity; use vortex::buffer::Alignment; use vortex::buffer::Buffer; use vortex::dtype::NativePType; @@ -62,7 +61,7 @@ impl CudaExecute for FSSTExecutor { let dtype = fsst.dtype().clone(); let validity = fsst.codes().validity()?; - if fsst.is_empty() || matches!(validity, Validity::AllInvalid) { + if fsst.is_empty() || validity.definitely_all_invalid() { let empty = unsafe { VarBinViewArray::new_unchecked( Buffer::::zeroed(fsst.len()), diff --git a/vortex-cuda/src/kernel/encodings/runend.rs b/vortex-cuda/src/kernel/encodings/runend.rs index fca435478d8..e86e8917edc 100644 --- a/vortex-cuda/src/kernel/encodings/runend.rs +++ b/vortex-cuda/src/kernel/encodings/runend.rs @@ -75,7 +75,7 @@ impl CudaExecute for RunEndExecutor { ))); } - if matches!(values.validity()?, Validity::AllInvalid) { + if values.validity()?.definitely_all_invalid() { return ConstantArray::new(Scalar::null(values.dtype().clone()), output_len) .into_array() .execute::(ctx.execution_ctx()); diff --git a/vortex-duckdb/src/exporter/bool.rs b/vortex-duckdb/src/exporter/bool.rs index 84fd17f0789..5b977ec7375 100644 --- a/vortex-duckdb/src/exporter/bool.rs +++ b/vortex-duckdb/src/exporter/bool.rs @@ -4,7 +4,6 @@ use vortex::array::ExecutionCtx; use vortex::array::arrays::BoolArray; use vortex::array::arrays::bool::BoolArrayExt; -use vortex::array::validity::Validity; use vortex::buffer::BitBuffer; use vortex::error::VortexResult; use vortex::mask::Mask; @@ -26,7 +25,7 @@ pub(crate) fn new_exporter( let bits = array.to_bit_buffer(); let validity = array.validity()?; - if matches!(validity, Validity::AllInvalid) { + if validity.definitely_all_invalid() { return Ok(all_invalid::new_exporter()); } let validity = validity.to_array(len).execute::(ctx)?; diff --git a/vortex-duckdb/src/exporter/decimal.rs b/vortex-duckdb/src/exporter/decimal.rs index f6b8d9607e7..674e4ce11ff 100644 --- a/vortex-duckdb/src/exporter/decimal.rs +++ b/vortex-duckdb/src/exporter/decimal.rs @@ -8,7 +8,6 @@ use vortex::array::ExecutionCtx; use vortex::array::arrays::DecimalArray; use vortex::array::arrays::decimal::DecimalDataParts; use vortex::array::match_each_decimal_value_type; -use vortex::array::validity::Validity; use vortex::buffer::Buffer; use vortex::dtype::BigCast; use vortex::dtype::DecimalDType; @@ -49,7 +48,7 @@ pub(crate) fn new_exporter( } = array.into_data_parts(); let dest_values_type = precision_to_duckdb_storage_size(&decimal_dtype)?; - if matches!(validity, Validity::AllInvalid) { + if validity.definitely_all_invalid() { return Ok(all_invalid::new_exporter()); } let validity = validity.to_array(len).execute::(ctx)?; diff --git a/vortex-duckdb/src/exporter/fixed_size_list.rs b/vortex-duckdb/src/exporter/fixed_size_list.rs index ed93ad2b9c1..d9be4978223 100644 --- a/vortex-duckdb/src/exporter/fixed_size_list.rs +++ b/vortex-duckdb/src/exporter/fixed_size_list.rs @@ -11,7 +11,6 @@ use vortex::array::ExecutionCtx; use vortex::array::arrays::FixedSizeListArray; use vortex::array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex::array::validity::Validity; use vortex::error::VortexResult; use vortex::mask::Mask; @@ -43,7 +42,7 @@ pub(crate) fn new_exporter( let elements = parts.elements; let validity = parts.validity; - if matches!(validity, Validity::AllInvalid) { + if validity.definitely_all_invalid() { return Ok(all_invalid::new_exporter()); } diff --git a/vortex-duckdb/src/exporter/list.rs b/vortex-duckdb/src/exporter/list.rs index a0eb65ffd64..dbc2985d560 100644 --- a/vortex-duckdb/src/exporter/list.rs +++ b/vortex-duckdb/src/exporter/list.rs @@ -11,7 +11,6 @@ use vortex::array::arrays::ListArray; use vortex::array::arrays::PrimitiveArray; use vortex::array::arrays::list::ListDataParts; use vortex::array::match_each_integer_ptype; -use vortex::array::validity::Validity; use vortex::dtype::IntegerPType; use vortex::error::VortexResult; use vortex::error::vortex_ensure; @@ -55,7 +54,7 @@ pub(crate) fn new_exporter( } = array.into_data_parts(); let num_elements = elements.len(); - if matches!(validity, Validity::AllInvalid) { + if validity.definitely_all_invalid() { return Ok(all_invalid::new_exporter()); } let validity = validity.to_array(array_len).execute::(ctx)?; diff --git a/vortex-duckdb/src/exporter/list_view.rs b/vortex-duckdb/src/exporter/list_view.rs index a4cb61895f3..8c667c86f85 100644 --- a/vortex-duckdb/src/exporter/list_view.rs +++ b/vortex-duckdb/src/exporter/list_view.rs @@ -16,7 +16,6 @@ use vortex::array::arrays::listview::ListViewArrayExt; use vortex::array::arrays::listview::ListViewDataParts; use vortex::array::arrays::listview::ListViewRebuildMode; use vortex::array::match_each_integer_ptype; -use vortex::array::validity::Validity; use vortex::dtype::IntegerPType; use vortex::error::VortexExpect; use vortex::error::VortexResult; @@ -92,7 +91,7 @@ pub(crate) fn new_exporter( // Cache an `elements` vector up front so that future exports can reference it. let num_elements = elements.len(); - if matches!(validity, Validity::AllInvalid) { + if validity.definitely_all_invalid() { return Ok(all_invalid::new_exporter()); } let validity = validity.to_array(len).execute::(ctx)?; diff --git a/vortex-duckdb/src/exporter/primitive.rs b/vortex-duckdb/src/exporter/primitive.rs index a0b08c80e4b..803ad2fca9f 100644 --- a/vortex-duckdb/src/exporter/primitive.rs +++ b/vortex-duckdb/src/exporter/primitive.rs @@ -6,7 +6,6 @@ use std::marker::PhantomData; use vortex::array::ExecutionCtx; use vortex::array::arrays::PrimitiveArray; use vortex::array::match_each_native_ptype; -use vortex::array::validity::Validity; use vortex::dtype::NativePType; use vortex::error::VortexResult; use vortex::mask::Mask; @@ -29,7 +28,7 @@ pub fn new_exporter( ctx: &mut ExecutionCtx, ) -> VortexResult> { let validity = array.validity()?; - if matches!(validity, Validity::AllInvalid) { + if validity.definitely_all_invalid() { return Ok(all_invalid::new_exporter()); }; let validity = validity.to_array(array.len()).execute::(ctx)?; diff --git a/vortex-duckdb/src/exporter/struct_.rs b/vortex-duckdb/src/exporter/struct_.rs index 76c07d672a3..f1df6899a00 100644 --- a/vortex-duckdb/src/exporter/struct_.rs +++ b/vortex-duckdb/src/exporter/struct_.rs @@ -8,7 +8,6 @@ use vortex::array::arrays::StructArray; use vortex::array::arrays::bool::BoolArrayExt; use vortex::array::arrays::struct_::StructDataParts; use vortex::array::builtins::ArrayBuiltins; -use vortex::array::validity::Validity; use vortex::error::VortexResult; use crate::duckdb::VectorRef; @@ -35,7 +34,7 @@ pub(crate) fn new_exporter( .. } = array.into_data_parts(); - if matches!(validity, Validity::AllInvalid) { + if validity.definitely_all_invalid() { return Ok(all_invalid::new_exporter()); }; let validity = validity.to_array(len).execute::(ctx)?; diff --git a/vortex-duckdb/src/exporter/varbinview.rs b/vortex-duckdb/src/exporter/varbinview.rs index 557795f45f3..cc99daafeb7 100644 --- a/vortex-duckdb/src/exporter/varbinview.rs +++ b/vortex-duckdb/src/exporter/varbinview.rs @@ -9,7 +9,6 @@ use vortex::array::arrays::VarBinViewArray; use vortex::array::arrays::varbinview::BinaryView; use vortex::array::arrays::varbinview::Inlined; use vortex::array::arrays::varbinview::VarBinViewDataParts; -use vortex::array::validity::Validity; use vortex::buffer::Buffer; use vortex::buffer::ByteBuffer; use vortex::error::VortexResult; @@ -39,7 +38,7 @@ pub(crate) fn new_exporter( buffers, } = array.into_data_parts(); - if matches!(validity, Validity::AllInvalid) { + if validity.definitely_all_invalid() { return Ok(all_invalid::new_exporter()); } let validity = validity.to_array(len).execute::(ctx)?; diff --git a/vortex-json/Cargo.toml b/vortex-json/Cargo.toml index 3b693b96f29..4afd224f6aa 100644 --- a/vortex-json/Cargo.toml +++ b/vortex-json/Cargo.toml @@ -17,5 +17,8 @@ version = { workspace = true } workspace = true [dependencies] +arrow-array = { workspace = true } +arrow-schema = { workspace = true, features = ["canonical_extension_types"] } vortex-array = { workspace = true, default-features = false } vortex-error = { workspace = true, default-features = false } +vortex-session = { workspace = true } diff --git a/vortex-json/src/arrow.rs b/vortex-json/src/arrow.rs new file mode 100644 index 00000000000..499a4144131 --- /dev/null +++ b/vortex-json/src/arrow.rs @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Arrow import and export support for the JSON extension dtype. + +use arrow_array::ArrayRef as ArrowArrayRef; +use arrow_schema::Field; +use arrow_schema::extension::ExtensionType; +use arrow_schema::extension::Json as ArrowJson; +use vortex_array::ArrayRef; +use vortex_array::EmptyMetadata; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrow::ArrowExport; +use vortex_array::arrow::ArrowExportVTable; +use vortex_array::arrow::ArrowImport; +use vortex_array::arrow::ArrowImportVTable; +use vortex_array::arrow::ArrowSession; +use vortex_array::arrow::ArrowSessionExt; +use vortex_array::arrow::FromArrowArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::extension::ExtVTable; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_session::registry::CachedId; +use vortex_session::registry::Id; + +use crate::Json; + +/// Arrow's canonical JSON extension name cached as a registry id. +static ARROW_JSON: CachedId = CachedId::new(ArrowJson::NAME); + +/// Returns whether an Arrow field contains valid canonical JSON extension metadata. +fn has_valid_json_extension(field: &Field) -> bool { + field.extension_type_name() == Some(ArrowJson::NAME) + && ArrowJson::try_new_from_field_metadata(field.data_type(), field.metadata()).is_ok() +} + +impl ArrowExportVTable for Json { + fn arrow_ext_id(&self) -> Id { + *ARROW_JSON + } + + fn vortex_id(&self) -> Id { + Json.id() + } + + fn to_arrow_field( + &self, + name: &str, + dtype: &DType, + session: &ArrowSession, + ) -> VortexResult> { + let DType::Extension(ext_dtype) = dtype else { + return Ok(None); + }; + if !ext_dtype.is::() { + return Ok(None); + } + + let mut field = session.to_arrow_field(name, ext_dtype.storage_dtype())?; + field + .try_with_extension_type(ArrowJson::default()) + .vortex_expect("Utf8 is a valid storage type for Arrow JSON"); + Ok(Some(field)) + } + + fn execute_arrow( + &self, + array: ArrayRef, + target: &Field, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let is_json = array + .dtype() + .as_extension_opt() + .map(|ext_dtype| ext_dtype.is::()) + .unwrap_or(false); + if !is_json { + return Ok(ArrowExport::Unsupported(array)); + } + + ArrowJson::try_new_from_field_metadata(target.data_type(), target.metadata())?; + + let executed = array.execute::(ctx)?; + let storage = executed.storage_array().clone(); + let storage_field = Field::new( + String::new(), + target.data_type().clone(), + target.is_nullable(), + ); + let session = ctx.session().clone(); + + let storage = session + .arrow() + .execute_arrow(storage, Some(&storage_field), ctx)?; + + Ok(ArrowExport::Exported(storage)) + } +} + +impl ArrowImportVTable for Json { + fn arrow_ext_id(&self) -> Id { + *ARROW_JSON + } + + fn from_arrow_field(&self, field: &Field) -> VortexResult> { + if !has_valid_json_extension(field) { + return Ok(None); + } + + Ok(Some(DType::Extension( + ExtDType::::try_new(EmptyMetadata, DType::Utf8(field.is_nullable().into()))? + .erased(), + ))) + } + + fn from_arrow_array( + &self, + array: ArrowArrayRef, + field: &Field, + dtype: &DType, + ) -> VortexResult { + let DType::Extension(ext_dtype) = dtype else { + return Ok(ArrowImport::Unsupported(array)); + }; + if !ext_dtype.is::() || !has_valid_json_extension(field) { + return Ok(ArrowImport::Unsupported(array)); + } + + let storage = ArrayRef::from_arrow(array.as_ref(), field.is_nullable())?; + Ok(ArrowImport::Imported( + ExtensionArray::new(ext_dtype.clone(), storage).into_array(), + )) + } +} + +#[cfg(test)] +mod tests { + + use std::sync::Arc; + + use arrow_array::Array; + use arrow_array::ArrayRef as ArrowArrayRef; + use arrow_array::StringArray; + use arrow_array::cast::AsArray; + use arrow_schema::DataType; + use arrow_schema::Field; + use arrow_schema::extension::ExtensionType; + use arrow_schema::extension::Json as ArrowJson; + use vortex_array::EmptyMetadata; + use vortex_array::IntoArray; + use vortex_array::VortexSessionExecute; + use vortex_array::arrays::ExtensionArray; + use vortex_array::arrays::VarBinArray; + use vortex_array::arrow::ArrowSessionExt; + use vortex_array::dtype::Nullability; + use vortex_array::dtype::extension::ExtDType; + use vortex_error::VortexExpect; + use vortex_error::VortexResult; + use vortex_session::VortexSession; + + use crate::Json; + use crate::initialize; + + /// Export a JSON extension array to Arrow's canonical JSON extension. + #[test] + fn exports_json_extension_array_as_arrow_json() -> VortexResult<()> { + let session = VortexSession::empty(); + initialize(&session); + + let storage = VarBinArray::from_iter( + [Some("{\"id\":1}"), Some("{\"id\":2}")], + vortex_array::dtype::DType::Utf8(Nullability::NonNullable), + ) + .into_array(); + let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); + + dbg!(&ext_dtype); + let array = ExtensionArray::new(ext_dtype, storage).into_array(); + + let field = session.arrow().to_arrow_field("data", array.dtype())?; + assert_eq!(field.extension_type_name(), Some(ArrowJson::NAME)); + ArrowJson::try_new_from_field_metadata(field.data_type(), field.metadata())?; + + dbg!(&field); + + let exported = session.arrow().execute_arrow( + array, + Some(&field), + &mut session.create_execution_ctx(), + )?; + + assert!(exported.data_type().is_string()); + + dbg!(exported.data_type()); + + let strings = exported.as_string_view(); + assert_eq!(strings.value(0), "{\"id\":1}"); + assert_eq!(strings.value(1), "{\"id\":2}"); + Ok(()) + } + + /// Import Arrow's canonical JSON extension as a Vortex JSON extension array. + #[test] + fn imports_arrow_json_extension_array_as_vortex_json() -> VortexResult<()> { + let session = VortexSession::empty(); + initialize(&session); + + let mut field = Field::new("data", DataType::Utf8, false); + field.try_with_extension_type(ArrowJson::default())?; + let array = Arc::new(StringArray::from(vec!["{\"id\":1}", "{\"id\":2}"])) as ArrowArrayRef; + + let imported = session.arrow().from_arrow_array(array, &field)?; + let ext_dtype = imported + .dtype() + .as_extension_opt() + .vortex_expect("expected JSON extension dtype"); + assert!(ext_dtype.is::()); + + let exported = session.arrow().execute_arrow( + imported, + Some(&field), + &mut session.create_execution_ctx(), + )?; + let strings = exported.as_string::(); + assert_eq!(strings.value(0), "{\"id\":1}"); + assert_eq!(strings.value(1), "{\"id\":2}"); + Ok(()) + } +} diff --git a/vortex-json/src/lib.rs b/vortex-json/src/lib.rs index 609ac44b861..fc7b0b1c9b3 100644 --- a/vortex-json/src/lib.rs +++ b/vortex-json/src/lib.rs @@ -9,6 +9,19 @@ //! Extension type and related functionality for a JSON extension type for Vortex. +mod arrow; mod dtype; +use std::sync::Arc; + pub use dtype::Json; +use vortex_array::arrow::ArrowSessionExt; +use vortex_array::dtype::session::DTypeSessionExt; +use vortex_session::VortexSession; + +/// Register JSON extension support with a session. +pub fn initialize(session: &VortexSession) { + session.dtypes().register(Json); + session.arrow().register_exporter(Arc::new(Json)); + session.arrow().register_importer(Arc::new(Json)); +} diff --git a/vortex-mask/src/eq.rs b/vortex-mask/src/eq.rs index 5cf1ff640df..6e43df8388f 100644 --- a/vortex-mask/src/eq.rs +++ b/vortex-mask/src/eq.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::mem; + use crate::Mask; impl PartialEq for Mask { @@ -9,6 +11,9 @@ impl PartialEq for Mask { if self.len() != other.len() { return false; } + if mem::discriminant(self) == mem::discriminant(other) && !matches!(self, Mask::Values(_)) { + return true; + } if self.true_count() != other.true_count() { return false; } diff --git a/vortex/src/lib.rs b/vortex/src/lib.rs index 3f7016be38f..baf0d0ae761 100644 --- a/vortex/src/lib.rs +++ b/vortex/src/lib.rs @@ -345,11 +345,11 @@ mod test { let mut ctx = LEGACY_SESSION.create_execution_ctx(); let recovered_primitive = recovered_array.execute::(&mut ctx)?; - assert!( - recovered_primitive - .validity()? - .mask_eq(&array.validity()?, &mut ctx)? - ); + assert!(recovered_primitive.validity()?.mask_eq( + &array.validity()?, + array.len(), + &mut ctx + )?); assert_eq!( recovered_primitive.to_buffer::(), array.to_buffer::()