Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions encodings/datetime-parts/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,11 @@ fn compare_dtp(
#[cfg(test)]
mod test {
use rstest::rstest;
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::LEGACY_SESSION;
use vortex_array::VortexSessionExecute;
use vortex_array::aggregate_fn::fns::sum::sum;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::TemporalArray;
use vortex_array::dtype::IntegerPType;
Expand All @@ -230,27 +233,37 @@ mod test {
.expect("Failed to construct DateTimePartsArray from TemporalArray")
}

/// Count the true values in a boolean array using the provided execution context.
fn true_count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> usize {
sum(array, ctx)
.unwrap()
.as_primitive()
.as_::<usize>()
.unwrap()
}

#[rstest]
#[case(Validity::NonNullable, Validity::NonNullable)]
#[case(Validity::NonNullable, Validity::AllValid)]
#[case(Validity::AllValid, Validity::NonNullable)]
#[case(Validity::AllValid, Validity::AllValid)]
fn compare_date_time_parts_eq(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) {
let mut ctx = LEGACY_SESSION.create_execution_ctx();
let lhs = dtp_array_from_timestamp(86400i64, lhs_validity); // January 2, 1970, 00:00:00 UTC
let rhs = dtp_array_from_timestamp(86400i64, rhs_validity.clone()); // January 2, 1970, 00:00:00 UTC
let comparison = lhs
.clone()
.into_array()
.binary(rhs.into_array(), Operator::Eq)
.unwrap();
assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
assert_eq!(true_count(&comparison, &mut ctx), 1);

let rhs = dtp_array_from_timestamp(0i64, rhs_validity); // January 1, 1970, 00:00:00 UTC
let comparison = lhs
.into_array()
.binary(rhs.into_array(), Operator::Eq)
.unwrap();
assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 0);
assert_eq!(true_count(&comparison, &mut ctx), 0);
}

#[rstest]
Expand All @@ -259,21 +272,22 @@ mod test {
#[case(Validity::AllValid, Validity::NonNullable)]
#[case(Validity::AllValid, Validity::AllValid)]
fn compare_date_time_parts_ne(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) {
let mut ctx = LEGACY_SESSION.create_execution_ctx();
let lhs = dtp_array_from_timestamp(86400i64, lhs_validity); // January 2, 1970, 00:00:00 UTC
let rhs = dtp_array_from_timestamp(86401i64, rhs_validity.clone()); // January 2, 1970, 00:00:01 UTC
let comparison = lhs
.clone()
.into_array()
.binary(rhs.into_array(), Operator::NotEq)
.unwrap();
assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
assert_eq!(true_count(&comparison, &mut ctx), 1);

let rhs = dtp_array_from_timestamp(86400i64, rhs_validity); // January 2, 1970, 00:00:00 UTC
let comparison = lhs
.into_array()
.binary(rhs.into_array(), Operator::NotEq)
.unwrap();
assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 0);
assert_eq!(true_count(&comparison, &mut ctx), 0);
}

#[rstest]
Expand All @@ -282,14 +296,15 @@ mod test {
#[case(Validity::AllValid, Validity::NonNullable)]
#[case(Validity::AllValid, Validity::AllValid)]
fn compare_date_time_parts_lt(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) {
let mut ctx = LEGACY_SESSION.create_execution_ctx();
let lhs = dtp_array_from_timestamp(0i64, lhs_validity); // January 1, 1970, 01:00:00 UTC
let rhs = dtp_array_from_timestamp(86400i64, rhs_validity); // January 2, 1970, 00:00:00 UTC

let comparison = lhs
.into_array()
.binary(rhs.into_array(), Operator::Lt)
.unwrap();
assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
assert_eq!(true_count(&comparison, &mut ctx), 1);
}

#[rstest]
Expand All @@ -298,14 +313,15 @@ mod test {
#[case(Validity::AllValid, Validity::NonNullable)]
#[case(Validity::AllValid, Validity::AllValid)]
fn compare_date_time_parts_gt(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) {
let mut ctx = LEGACY_SESSION.create_execution_ctx();
let lhs = dtp_array_from_timestamp(86400i64, lhs_validity); // January 2, 1970, 02:00:00 UTC
let rhs = dtp_array_from_timestamp(0i64, rhs_validity); // January 1, 1970, 01:00:00 UTC

let comparison = lhs
.into_array()
.binary(rhs.into_array(), Operator::Gt)
.unwrap();
assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
assert_eq!(true_count(&comparison, &mut ctx), 1);
}

#[rstest]
Expand All @@ -317,6 +333,7 @@ mod test {
#[case] lhs_validity: Validity,
#[case] rhs_validity: Validity,
) {
let mut ctx = LEGACY_SESSION.create_execution_ctx();
let temporal_array = TemporalArray::new_timestamp(
PrimitiveArray::new(buffer![0i64], lhs_validity.clone()).into_array(),
TimeUnit::Seconds,
Expand All @@ -339,27 +356,27 @@ mod test {
.into_array()
.binary(rhs.clone().into_array(), Operator::Eq)
.unwrap();
assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 0);
assert_eq!(true_count(&comparison, &mut ctx), 0);

let comparison = lhs
.clone()
.into_array()
.binary(rhs.clone().into_array(), Operator::NotEq)
.unwrap();
assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
assert_eq!(true_count(&comparison, &mut ctx), 1);

let comparison = lhs
.clone()
.into_array()
.binary(rhs.clone().into_array(), Operator::Lt)
.unwrap();
assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
assert_eq!(true_count(&comparison, &mut ctx), 1);

let comparison = lhs
.into_array()
.binary(rhs.into_array(), Operator::Lte)
.unwrap();
assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
assert_eq!(true_count(&comparison, &mut ctx), 1);

// `CompareOperator::Gt` and `CompareOperator::Gte` only cover the case of all lhs values
// being larger. Therefore, these cases are not covered by unit tests.
Expand Down
28 changes: 21 additions & 7 deletions encodings/datetime-parts/src/compute/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,11 @@ fn is_constant_zero(array: &ArrayRef) -> bool {

#[cfg(test)]
mod tests {
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::LEGACY_SESSION;
use vortex_array::VortexSessionExecute;
use vortex_array::aggregate_fn::fns::sum::sum;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::TemporalArray;
use vortex_array::arrays::scalar_fn::ScalarFnFactoryExt;
Expand All @@ -200,6 +203,15 @@ mod tests {

const SECONDS_PER_DAY: i64 = 86400;

/// Count the true values in a boolean array using the provided execution context.
fn true_count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> usize {
sum(array, ctx)
.unwrap()
.as_primitive()
.as_::<usize>()
.unwrap()
}

/// Create a DTP array with the given day values (all at midnight).
fn dtp_at_midnight(days: &[i64], time_unit: TimeUnit) -> DateTimePartsArray {
let multiplier = match time_unit {
Expand Down Expand Up @@ -285,7 +297,8 @@ mod tests {
);

// Verify correctness: days [0, 1, 2] <= 1 should give [true, true, false]
assert_eq!(optimized.as_bool_typed().true_count().unwrap(), 2);
let mut ctx = LEGACY_SESSION.create_execution_ctx();
assert_eq!(true_count(&optimized, &mut ctx), 2);
}

#[test]
Expand Down Expand Up @@ -313,7 +326,8 @@ mod tests {
let optimized = between.optimize().unwrap();

// Verify correctness: days [0, 1, 2, 3, 4] between 1 and 3 should give [false, true, true, true, false]
assert_eq!(optimized.as_bool_typed().true_count().unwrap(), 3);
let mut ctx = LEGACY_SESSION.create_execution_ctx();
assert_eq!(true_count(&optimized, &mut ctx), 3);
}

#[test]
Expand All @@ -335,7 +349,8 @@ mod tests {
// (optimization doesn't apply, so we keep the original structure)
// Just verify it still computes correctly
// days [0, 1, 2] at midnight <= day 1 at noon: [true, true, false]
assert_eq!(optimized.as_bool_typed().true_count().unwrap(), 2);
let mut ctx = LEGACY_SESSION.create_execution_ctx();
assert_eq!(true_count(&optimized, &mut ctx), 2);
}

#[test]
Expand All @@ -352,9 +367,8 @@ mod tests {
TimeUnit::Seconds,
None,
);
let dtp =
DateTimeParts::try_from_temporal(temporal, &mut LEGACY_SESSION.create_execution_ctx())
.unwrap();
let mut ctx = LEGACY_SESSION.create_execution_ctx();
let dtp = DateTimeParts::try_from_temporal(temporal, &mut ctx).unwrap();
let len = dtp.len();

// Compare against midnight constant
Expand All @@ -366,6 +380,6 @@ mod tests {
// Should still compute correctly (just not optimized via pushdown)
let optimized = comparison.optimize().unwrap();
// timestamps at 1am on days [0, 1, 2] <= day 1 midnight: [true, false, false]
assert_eq!(optimized.as_bool_typed().true_count().unwrap(), 1);
assert_eq!(true_count(&optimized, &mut ctx), 1);
}
}
8 changes: 7 additions & 1 deletion encodings/experimental/onpair/tests/big_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::time::Instant;
use vortex_array::IntoArray;
use vortex_array::VortexSessionExecute;
use vortex_array::accessor::ArrayAccessor;
use vortex_array::aggregate_fn::fns::sum::sum;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::VarBinArray;
use vortex_array::arrays::VarBinViewArray;
Expand Down Expand Up @@ -121,6 +122,11 @@ fn smoke_100k_rows() {
.execute::<vortex_array::Canonical>(&mut ctx)
.unwrap()
.into_array();
assert_eq!(eq.as_bool_typed().true_count().unwrap(), want_eq);
let eq_count = sum(&eq, &mut ctx)
.unwrap()
.as_primitive()
.as_::<usize>()
.unwrap();
assert_eq!(eq_count, want_eq);
eprintln!("eq pushdown matches reference count ({})", want_eq);
}
4 changes: 2 additions & 2 deletions vortex-array/src/arrays/varbin/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::arrays::VarBin;
use crate::arrays::VarBinViewArray;
use crate::arrays::varbin::VarBinArrayExt;
use crate::arrow::Datum;
use crate::arrow::from_arrow_array_with_len;
use crate::arrow::from_arrow_columnar;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::IntegerPType;
Expand Down Expand Up @@ -125,7 +125,7 @@ impl CompareKernel for VarBin {
}
.map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;

Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
Ok(Some(from_arrow_columnar(&array, len, nullable, ctx)?))
} else if !rhs.is::<VarBin>() {
// NOTE: If the rhs is not a VarBin array it will be canonicalized to a VarBinView
// Arrow doesn't support comparing VarBin to VarBinView arrays, so we convert ourselves
Expand Down
38 changes: 38 additions & 0 deletions vortex-array/src/arrow/datum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ impl ArrowDatum for Datum {
/// # Error
///
/// The provided array must have length
#[deprecated(
note = "Relies on the hidden global `LEGACY_SESSION`; use `from_arrow_columnar` with an explicit `ExecutionCtx` instead"
)]

@AdamGS AdamGS Jun 8, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the note is wrong here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — updated in 94eae41 to point at the actual replacement: note = "Relies on the hidden global LEGACY_SESSION; use from_arrow_columnar with an explicit ExecutionCtx instead".


Generated by Claude Code

pub fn from_arrow_array_with_len<A>(array: A, len: usize, nullable: bool) -> VortexResult<ArrayRef>
where
ArrayRef: FromArrowArray<A>,
Expand All @@ -131,3 +134,38 @@ where
)
.into_array())
}

/// Convert an Arrow array to an Array with a specific length, using the provided
/// [`ExecutionCtx`].
///
/// This is useful for compute functions that delegate to Arrow using [Datum],
/// which will return a scalar (length 1 Arrow array) if the input array is constant.
///
/// # Error
///
/// The provided array must have length `len` or `1`.
pub fn from_arrow_columnar<A>(
array: A,
len: usize,
nullable: bool,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef>
where
ArrayRef: FromArrowArray<A>,
{
let array = ArrayRef::from_arrow(array, nullable)?;
if array.len() == len {
return Ok(array);
}

if array.len() != 1 {
vortex_panic!(
"Array length mismatch, expected {} got {} for encoding {}",
len,
array.len(),
array.encoding_id()
);
}

Ok(ConstantArray::new(array.execute_scalar(0, ctx)?, len).into_array())
}
4 changes: 2 additions & 2 deletions vortex-array/src/scalar_fn/fns/binary/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::arrays::scalar_fn::ScalarFnArrayExt;
use crate::arrays::scalar_fn::ScalarFnArrayView;
use crate::arrow::ArrowSessionExt;
use crate::arrow::Datum;
use crate::arrow::from_arrow_array_with_len;
use crate::arrow::from_arrow_columnar;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::kernel::ExecuteParentKernel;
Expand Down Expand Up @@ -178,7 +178,7 @@ fn arrow_compare_arrays(
}
};

from_arrow_array_with_len(&arrow_array, left.len(), nullable)
from_arrow_columnar(&arrow_array, left.len(), nullable, ctx)
}

pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> VortexResult<Scalar> {
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/scalar_fn/fns/binary/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::IntoArray;
use crate::arrays::Constant;
use crate::arrays::ConstantArray;
use crate::arrow::Datum;
use crate::arrow::from_arrow_array_with_len;
use crate::arrow::from_arrow_columnar;
use crate::executor::ExecutionCtx;
use crate::scalar::NumericOperator;

Expand Down Expand Up @@ -48,7 +48,7 @@ pub(crate) fn arrow_numeric(
NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
};

from_arrow_array_with_len(array.as_ref(), len, nullable)
from_arrow_columnar(array.as_ref(), len, nullable, ctx)
}

fn constant_numeric(
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/scalar_fn/fns/like/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use vortex_session::registry::CachedId;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::arrow::Datum;
use crate::arrow::from_arrow_array_with_len;
use crate::arrow::from_arrow_columnar;
use crate::dtype::DType;
use crate::expr::Expression;
use crate::expr::StatsCatalog;
Expand Down Expand Up @@ -237,7 +237,7 @@ pub(crate) fn arrow_like(
(true, true) => arrow_string::like::nilike(&lhs, &rhs)?,
};

from_arrow_array_with_len(&result, len, nullable)
from_arrow_columnar(&result, len, nullable, ctx)
}

/// Variants of the LIKE filter that we know how to turn into a stats pruning predicate.
Expand Down
Loading
Loading