From dd98964bf67043a324205c7da51f71ff277fbc6c Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 27 Oct 2025 23:56:57 +0200 Subject: [PATCH 1/6] fix: `ArrayIter` does not report size hint correctly after advancing from the iterator back this also adds a LOT of tests extracted from (which is how I found that bug): - #8697 --- arrow-array/src/iterator.rs | 930 +++++++++++++++++++++++++++++++++++- 1 file changed, 925 insertions(+), 5 deletions(-) diff --git a/arrow-array/src/iterator.rs b/arrow-array/src/iterator.rs index 6708da3d5dd6..c1026c3ad561 100644 --- a/arrow-array/src/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -44,7 +44,7 @@ use arrow_buffer::NullBuffer; /// [`PrimitiveArray`]: crate::PrimitiveArray /// [`compute::unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.unary.html /// [`compute::try_unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.try_unary.html -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ArrayIter { array: T, logical_nulls: Option, @@ -98,8 +98,8 @@ impl Iterator for ArrayIter { fn size_hint(&self) -> (usize, Option) { ( - self.array.len() - self.current, - Some(self.array.len() - self.current), + self.current_end - self.current, + Some(self.current_end - self.current), ) } } @@ -147,9 +147,14 @@ pub type MapArrayIter<'a> = ArrayIter<&'a MapArray>; pub type GenericListViewArrayIter<'a, O> = ArrayIter<&'a GenericListViewArray>; #[cfg(test)] mod tests { - use std::sync::Arc; - use crate::array::{ArrayRef, BinaryArray, BooleanArray, Int32Array, StringArray}; + use crate::iterator::ArrayIter; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::fmt::Debug; + use std::iter::Copied; + use std::slice::Iter; + use std::sync::Arc; #[test] fn test_primitive_array_iter_round_trip() { @@ -264,4 +269,919 @@ mod tests { // check if ExactSizeIterator is implemented let _ = array.iter().rposition(|opt_b| opt_b == Some(true)); } + + trait SharedBetweenArrayIterAndSliceIter: + 
ExactSizeIterator> + DoubleEndedIterator> + Clone + { + } + impl> + DoubleEndedIterator>> + SharedBetweenArrayIterAndSliceIter for T + { + } + + fn get_int32_iterator_cases() -> impl Iterator>)> { + let mut rng = StdRng::seed_from_u64(42); + + let no_nulls_and_no_duplicates = (0..10).map(Some).collect::>>(); + let no_nulls_random_values = (0..10) + .map(|_| rng.random::()) + .map(Some) + .collect::>>(); + + let all_nulls = (0..10).map(|_| None).collect::>>(); + let only_start_nulls = (0..10) + .map(|item| if item < 4 { None } else { Some(item) }) + .collect::>>(); + let only_end_nulls = (0..10) + .map(|item| if item > 8 { None } else { Some(item) }) + .collect::>>(); + let only_middle_nulls = (0..10) + .map(|item| { + if (4..=8).contains(&item) && rng.random_bool(0.9) { + None + } else { + Some(item) + } + }) + .collect::>>(); + let random_values_with_random_nulls = (0..10) + .map(|_| { + if rng.random_bool(0.3) { + None + } else { + Some(rng.random::()) + } + }) + .collect::>>(); + + let no_nulls_and_some_duplicates = (0..10) + .map(|item| item % 3) + .map(Some) + .collect::>>(); + let no_nulls_and_all_same_value = + (0..10).map(|_| 9).map(Some).collect::>>(); + let no_nulls_and_continues_duplicates = [0, 0, 0, 1, 1, 2, 2, 2, 2, 3] + .map(Some) + .into_iter() + .collect::>>(); + + let single_null_and_no_duplicates = (0..10) + .map(|item| if item == 4 { None } else { Some(item) }) + .collect::>>(); + let multiple_nulls_and_no_duplicates = (0..10) + .map(|item| if item % 3 == 2 { None } else { Some(item) }) + .collect::>>(); + let continues_nulls_and_no_duplicates = [ + Some(0), + Some(1), + None, + None, + Some(2), + Some(3), + None, + Some(4), + Some(5), + None, + ] + .into_iter() + .collect::>>(); + + [ + no_nulls_and_no_duplicates, + no_nulls_random_values, + no_nulls_and_some_duplicates, + no_nulls_and_all_same_value, + no_nulls_and_continues_duplicates, + all_nulls, + only_start_nulls, + only_end_nulls, + only_middle_nulls, + random_values_with_random_nulls, + 
single_null_and_no_duplicates, + multiple_nulls_and_no_duplicates, + continues_nulls_and_no_duplicates, + ] + .map(|case| (Int32Array::from(case.clone()), case)) + .into_iter() + } + + trait SetupIter { + fn setup(&self, iter: &mut I); + } + + struct NoSetup; + impl SetupIter for NoSetup { + fn setup(&self, _iter: &mut I) { + // none + } + } + + fn setup_and_assert_cases( + setup_iterator: impl SetupIter, + assert_fn: impl Fn(ArrayIter<&Int32Array>, Copied>>), + ) { + for (array, source) in get_int32_iterator_cases() { + let mut actual = ArrayIter::new(&array); + let mut expected = source.iter().copied(); + + setup_iterator.setup(&mut actual); + setup_iterator.setup(&mut expected); + + assert_fn(actual, expected); + } + } + + /// Trait representing an operation on a BitIterator + /// that can be compared against a slice iterator + trait ArrayIteratorOp { + /// What the operation returns (e.g. Option for last/max, usize for count, etc) + type Output: PartialEq + Debug; + + /// The name of the operation, used for error messages + fn name(&self) -> String; + + /// Get the value of the operation for the provided iterator + /// This will be either a BitIterator or a slice iterator to make sure they produce the same result + fn get_value(&self, iter: T) -> Self::Output; + } + + /// Trait representing an operation on a BitIterator + /// that can be compared against a slice iterator + trait ArrayIteratorMutateOp { + /// What the operation returns (e.g. 
Option for last/max, usize for count, etc) + type Output: PartialEq + Debug; + + /// The name of the operation, used for error messages + fn name(&self) -> String; + + /// Get the value of the operation for the provided iterator + /// This will be either a BitIterator or a slice iterator to make sure they produce the same result + fn get_value(&self, iter: &mut T) -> Self::Output; + } + + /// Helper function that will assert that the provided operation + /// produces the same result for both BitIterator and slice iterator + /// under various consumption patterns (e.g. some calls to next/next_back/consume_all/etc) + fn assert_array_iterator_cases(o: O) { + setup_and_assert_cases(NoSetup, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct Next; + impl SetupIter for Next { + fn setup(&self, iter: &mut I) { + iter.next(); + } + } + setup_and_assert_cases(Next, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming 1 element from the start (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextBack; + impl SetupIter for NextBack { + fn setup(&self, iter: &mut I) { + iter.next_back(); + } + } + + setup_and_assert_cases(NextBack, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming 1 element from the end (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextAndBack; + impl SetupIter for NextAndBack { + fn setup(&self, iter: &mut I) { + iter.next(); + iter.next_back(); + } + } + + 
setup_and_assert_cases(NextAndBack, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming 1 element from start and end (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextUntilLast; + impl SetupIter for NextUntilLast { + fn setup(&self, iter: &mut I) { + let len = iter.len(); + if len > 1 { + iter.nth(len - 2); + } + } + } + setup_and_assert_cases(NextUntilLast, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming all from the start but 1 (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextBackUntilFirst; + impl SetupIter for NextBackUntilFirst { + fn setup(&self, iter: &mut I) { + let len = iter.len(); + if len > 1 { + iter.nth_back(len - 2); + } + } + } + setup_and_assert_cases(NextBackUntilFirst, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming all from the end but 1 (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextFinish; + impl SetupIter for NextFinish { + fn setup(&self, iter: &mut I) { + iter.nth(iter.len()); + } + } + setup_and_assert_cases(NextFinish, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming all from the start (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextBackFinish; + impl SetupIter for NextBackFinish { + fn setup(&self, iter: &mut I) { + iter.nth_back(iter.len()); + } + } + 
setup_and_assert_cases(NextBackFinish, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming all from the end (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextUntilLastNone; + impl SetupIter for NextUntilLastNone { + fn setup(&self, iter: &mut I) { + let last_null_position = iter.clone().rposition(|item| item.is_none()); + + // move the iterator to the location where there are no nulls anymore + if let Some(last_null_position) = last_null_position { + iter.nth(last_null_position); + } + } + } + setup_and_assert_cases(NextUntilLastNone, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for iter that have no nulls left (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextUntilLastSome; + impl SetupIter for NextUntilLastSome { + fn setup(&self, iter: &mut I) { + let last_some_position = iter.clone().rposition(|item| item.is_some()); + + // move the iterator to the location where there are only nulls + if let Some(last_some_position) = last_some_position { + iter.nth(last_some_position); + } + } + } + setup_and_assert_cases(NextUntilLastSome, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for iter that only have nulls left (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + } + + /// Helper function that will assert that the provided operation + /// produces the same result for both BitIterator and slice iterator + /// under various consumption patterns (e.g. 
some calls to next/next_back/consume_all/etc) + fn assert_array_iterator_cases_mutate(o: O) { + for (array, source) in get_int32_iterator_cases() { + for i in 0..source.len() { + let mut actual = ArrayIter::new(&array); + let mut expected = source.iter().copied(); + + // calling nth(0) is the same as calling next() + // but we want to get to the ith position so we call nth(i - 1) + if i > 0 { + actual.nth(i - 1); + expected.nth(i - 1); + } + + let current_iterator_values: Vec> = expected.clone().collect(); + + let actual_value = o.get_value(&mut actual); + let expected_value = o.get_value(&mut expected); + + assert_eq!( + actual_value, + expected_value, + "Failed on op {} for iter that advanced to i {i} (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + + let left_over_actual: Vec<_> = actual.clone().collect(); + let left_over_expected: Vec<_> = expected.clone().collect(); + + assert_eq!( + left_over_actual, left_over_expected, + "state after mutable should be the same" + ); + } + } + } + + #[derive(Debug, PartialEq)] + struct CallTrackingAndResult { + result: Result, + calls: Vec, + } + type CallTrackingWithInputType = CallTrackingAndResult>; + type CallTrackingOnly = CallTrackingWithInputType<()>; + + #[test] + fn assert_position() { + struct PositionOp { + reverse: bool, + number_of_false: usize, + } + + impl ArrayIteratorMutateOp for PositionOp { + type Output = CallTrackingWithInputType>; + fn name(&self) -> String { + if self.reverse { + format!("rposition with {} false returned", self.number_of_false) + } else { + format!("position with {} false returned", self.number_of_false) + } + } + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = vec![]; + + let mut count = 0; + + let position_result = if self.reverse { + iter.rposition(|item| { + items.push(item); + + if count < self.number_of_false { + count += 1; + false + } else { + true + } + }) + } else { + iter.position(|item| { + items.push(item); + + 
if count < self.number_of_false { + count += 1; + false + } else { + true + } + }) + }; + + CallTrackingAndResult { + result: position_result, + calls: items, + } + } + } + + for reverse in [false, true] { + for number_of_false in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(PositionOp { + reverse, + number_of_false, + }); + } + } + } + + #[test] + fn assert_nth() { + setup_and_assert_cases(NoSetup, |actual, expected| { + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth(0); + #[allow(clippy::iter_nth_zero)] + let expected_val = expected.nth(0); + assert_eq!(actual_val, expected_val, "Failed on nth(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth(1); + let expected_val = expected.nth(1); + assert_eq!(actual_val, expected_val, "Failed on nth(1)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth(2); + let expected_val = expected.nth(2); + assert_eq!(actual_val, expected_val, "Failed on nth(2)"); + } + } + }); + } + + #[test] + fn assert_nth_back() { + setup_and_assert_cases(NoSetup, |actual, expected| { + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth_back(0); + #[allow(clippy::iter_nth_zero)] + let expected_val = expected.nth_back(0); + assert_eq!(actual_val, expected_val, "Failed on nth_back(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(1); + let expected_val = expected.nth_back(1); + assert_eq!(actual_val, expected_val, "Failed on nth_back(1)"); + } + } + + { + let mut actual = actual.clone(); 
+ let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(2); + let expected_val = expected.nth_back(2); + assert_eq!(actual_val, expected_val, "Failed on nth_back(2)"); + } + } + }); + } + + #[test] + fn assert_last() { + for (array, source) in get_int32_iterator_cases() { + let mut actual_forward = ArrayIter::new(&array); + let mut expected_forward = source.iter().copied(); + + for _ in 0..source.len() + 1 { + { + let actual_forward_clone = actual_forward.clone(); + let expected_forward_clone = expected_forward.clone(); + + assert_eq!(actual_forward_clone.last(), expected_forward_clone.last()); + } + + actual_forward.next(); + expected_forward.next(); + } + + let mut actual_backward = ArrayIter::new(&array); + let mut expected_backward = source.iter().copied(); + for _ in 0..source.len() + 1 { + { + assert_eq!( + actual_backward.clone().last(), + expected_backward.clone().last() + ); + } + + actual_backward.next_back(); + expected_backward.next_back(); + } + } + } + + #[test] + fn assert_for_each() { + struct ForEachOp; + + impl ArrayIteratorOp for ForEachOp { + type Output = CallTrackingOnly; + + fn name(&self) -> String { + "for_each".to_string() + } + + fn get_value(&self, iter: T) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + iter.for_each(|item| { + items.push(item); + }); + + CallTrackingAndResult { + calls: items, + result: (), + } + } + } + + assert_array_iterator_cases(ForEachOp) + } + + #[test] + fn assert_fold() { + struct FoldOp { + reverse: bool, + } + + #[derive(Debug, PartialEq)] + struct CallArgs { + acc: Option, + item: Option, + } + + impl ArrayIteratorOp for FoldOp { + type Output = CallTrackingAndResult, CallArgs>; + + fn name(&self) -> String { + if self.reverse { + "rfold".to_string() + } else { + "fold".to_string() + } + } + + fn get_value(&self, iter: T) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + let result = if self.reverse { + 
iter.rfold(Some(1), |acc, item| { + items.push(CallArgs { item, acc }); + + item.map(|val| val + 100) + }) + } else { + #[allow(clippy::manual_try_fold)] + iter.fold(Some(1), |acc, item| { + items.push(CallArgs { item, acc }); + + item.map(|val| val + 100) + }) + }; + + CallTrackingAndResult { + calls: items, + result, + } + } + } + + assert_array_iterator_cases(FoldOp { reverse: false }); + assert_array_iterator_cases(FoldOp { reverse: true }); + } + + #[test] + fn assert_count() { + struct CountOp; + + impl ArrayIteratorOp for CountOp { + type Output = usize; + + fn name(&self) -> String { + "count".to_string() + } + + fn get_value(&self, iter: T) -> Self::Output { + iter.count() + } + } + + assert_array_iterator_cases(CountOp) + } + + #[test] + fn assert_any() { + struct AnyOp { + false_count: usize, + } + + impl ArrayIteratorMutateOp for AnyOp { + type Output = CallTrackingWithInputType; + + fn name(&self) -> String { + format!("any with {} false returned", self.false_count) + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + let mut count = 0; + let res = iter.any(|item| { + items.push(item); + + if count < self.false_count { + count += 1; + false + } else { + true + } + }); + + CallTrackingWithInputType { + calls: items, + result: res, + } + } + } + + for false_count in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(AnyOp { false_count }); + } + } + + #[test] + fn assert_all() { + struct AllOp { + true_count: usize, + } + + impl ArrayIteratorMutateOp for AllOp { + type Output = CallTrackingWithInputType; + + fn name(&self) -> String { + format!("all with {} true returned", self.true_count) + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + let mut count = 0; + let res = iter.all(|item| { + items.push(item); + + if count < self.true_count { + count += 1; + true + } else { + false + } + }); + + 
CallTrackingWithInputType { + calls: items, + result: res, + } + } + } + + for true_count in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(AllOp { true_count }); + } + } + + #[test] + fn assert_find() { + struct FindOp { + reverse: bool, + false_count: usize, + } + + impl ArrayIteratorMutateOp for FindOp { + type Output = CallTrackingWithInputType>>; + + fn name(&self) -> String { + if self.reverse { + format!("rfind with {} false returned", self.false_count) + } else { + format!("find with {} false returned", self.false_count) + } + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = vec![]; + + let mut count = 0; + + let position_result = if self.reverse { + iter.rfind(|item| { + items.push(*item); + + if count < self.false_count { + count += 1; + false + } else { + true + } + }) + } else { + iter.find(|item| { + items.push(*item); + + if count < self.false_count { + count += 1; + false + } else { + true + } + }) + }; + + CallTrackingWithInputType { + calls: items, + result: position_result, + } + } + } + + for reverse in [false, true] { + for false_count in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(FindOp { + reverse, + false_count, + }); + } + } + } + + #[test] + fn assert_find_map() { + struct FindMapOp { + number_of_nones: usize, + } + + impl ArrayIteratorMutateOp for FindMapOp { + type Output = CallTrackingWithInputType>; + + fn name(&self) -> String { + format!("find_map with {} None returned", self.number_of_nones) + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = vec![]; + + let mut count = 0; + + let result = iter.find_map(|item| { + items.push(item); + + if count < self.number_of_nones { + count += 1; + None + } else { + Some("found it") + } + }); + + CallTrackingAndResult { + result, + calls: items, + } + } + } + + for number_of_nones in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(FindMapOp { number_of_nones }); + } + } + + #[test] + 
fn assert_partition() { + struct PartitionOp) -> bool> { + description: &'static str, + predicate: F, + } + + #[derive(Debug, PartialEq)] + struct PartitionResult { + left: Vec>, + right: Vec>, + } + + impl) -> bool> ArrayIteratorOp for PartitionOp { + type Output = CallTrackingWithInputType; + + fn name(&self) -> String { + format!("partition by {}", self.description) + } + + fn get_value(&self, iter: T) -> Self::Output { + let mut items = vec![]; + + let mut index = 0; + + let (left, right) = iter.partition(|item| { + items.push(*item); + + let res = (self.predicate)(index, item); + + index += 1; + res + }); + + CallTrackingAndResult { + result: PartitionResult { left, right }, + calls: items, + } + } + } + + assert_array_iterator_cases(PartitionOp { + description: "None on one side and Some(*) on the other", + predicate: |_, item| item.is_none(), + }); + + assert_array_iterator_cases(PartitionOp { + description: "all true", + predicate: |_, _| true, + }); + + assert_array_iterator_cases(PartitionOp { + description: "all false", + predicate: |_, _| false, + }); + + let random_values = (0..100).map(|_| rand::random_bool(0.5)).collect::>(); + assert_array_iterator_cases(PartitionOp { + description: "random", + predicate: |index, _| random_values[index % random_values.len()], + }); + } } From ae5e9618619a91d51a03ed51571fab6e912a49ca Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 28 Oct 2025 00:19:05 +0200 Subject: [PATCH 2/6] cleanup tests --- arrow-array/src/iterator.rs | 101 +++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 48 deletions(-) diff --git a/arrow-array/src/iterator.rs b/arrow-array/src/iterator.rs index c1026c3ad561..f96d0158768e 100644 --- a/arrow-array/src/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -391,38 +391,42 @@ mod tests { } } - /// Trait representing an operation on a BitIterator + /// Trait representing an operation on a [`ArrayIter`] /// that can be compared 
against a slice iterator - trait ArrayIteratorOp { - /// What the operation returns (e.g. Option for last/max, usize for count, etc) + /// + /// this is for consuming operations (e.g. `count`, `last`, etc) + trait ConsumingArrayIteratorOp { + /// What the operation returns (e.g. Option for last, usize for count, etc) type Output: PartialEq + Debug; /// The name of the operation, used for error messages fn name(&self) -> String; /// Get the value of the operation for the provided iterator - /// This will be either a BitIterator or a slice iterator to make sure they produce the same result + /// This will be either a [`ArrayIter`] or a slice iterator to make sure they produce the same result fn get_value(&self, iter: T) -> Self::Output; } - /// Trait representing an operation on a BitIterator - /// that can be compared against a slice iterator - trait ArrayIteratorMutateOp { - /// What the operation returns (e.g. Option for last/max, usize for count, etc) + /// Trait representing an operation on a [`ArrayIter`] + /// that can be compared against a slice iterator. + /// + /// This is for mutating operations (e.g. `position`, `any`, `find`, etc) + trait MutatingArrayIteratorOp { + /// What the operation returns (e.g. Option for last, usize for count, etc) type Output: PartialEq + Debug; /// The name of the operation, used for error messages fn name(&self) -> String; /// Get the value of the operation for the provided iterator - /// This will be either a BitIterator or a slice iterator to make sure they produce the same result + /// This will be either a [`ArrayIter`] or a slice iterator to make sure they produce the same result fn get_value(&self, iter: &mut T) -> Self::Output; } /// Helper function that will assert that the provided operation - /// produces the same result for both BitIterator and slice iterator + /// produces the same result for both [`ArrayIter`] and slice iterator /// under various consumption patterns (e.g. 
some calls to next/next_back/consume_all/etc) - fn assert_array_iterator_cases(o: O) { + fn assert_array_iterator_cases(o: O) { setup_and_assert_cases(NoSetup, |actual, expected| { let current_iterator_values: Vec> = expected.clone().collect(); assert_eq!( @@ -607,42 +611,43 @@ mod tests { } /// Helper function that will assert that the provided operation - /// produces the same result for both BitIterator and slice iterator + /// produces the same result for both [`ArrayIter`] and slice iterator /// under various consumption patterns (e.g. some calls to next/next_back/consume_all/etc) - fn assert_array_iterator_cases_mutate(o: O) { - for (array, source) in get_int32_iterator_cases() { - for i in 0..source.len() { - let mut actual = ArrayIter::new(&array); - let mut expected = source.iter().copied(); - - // calling nth(0) is the same as calling next() - // but we want to get to the ith position so we call nth(i - 1) - if i > 0 { - actual.nth(i - 1); - expected.nth(i - 1); - } + /// + /// this is different from [`assert_array_iterator_cases`] as this also check that the state after the call is correct + /// to make sure we don't leave the iterator in incorrect state + fn assert_array_iterator_cases_mutate(o: O) { + struct Adapter { + o: O, + } - let current_iterator_values: Vec> = expected.clone().collect(); + #[derive(Debug, PartialEq)] + struct AdapterOutput { + value: Value, + /// collect on the iterator after running the operation + leftover: Vec>, + } - let actual_value = o.get_value(&mut actual); - let expected_value = o.get_value(&mut expected); + impl ConsumingArrayIteratorOp for Adapter { + type Output = AdapterOutput; - assert_eq!( - actual_value, - expected_value, - "Failed on op {} for iter that advanced to i {i} (left actual, right expected) ({current_iterator_values:?})", - o.name() - ); + fn name(&self) -> String { + self.o.name() + } + + fn get_value( + &self, + mut iter: T, + ) -> Self::Output { + let value = self.o.get_value(&mut iter); - let 
left_over_actual: Vec<_> = actual.clone().collect(); - let left_over_expected: Vec<_> = expected.clone().collect(); + let leftover: Vec<_> = iter.collect(); - assert_eq!( - left_over_actual, left_over_expected, - "state after mutable should be the same" - ); + AdapterOutput { value, leftover } } } + + assert_array_iterator_cases(Adapter { o }) } #[derive(Debug, PartialEq)] @@ -660,7 +665,7 @@ mod tests { number_of_false: usize, } - impl ArrayIteratorMutateOp for PositionOp { + impl MutatingArrayIteratorOp for PositionOp { type Output = CallTrackingWithInputType>; fn name(&self) -> String { if self.reverse { @@ -830,7 +835,7 @@ mod tests { fn assert_for_each() { struct ForEachOp; - impl ArrayIteratorOp for ForEachOp { + impl ConsumingArrayIteratorOp for ForEachOp { type Output = CallTrackingOnly; fn name(&self) -> String { @@ -866,7 +871,7 @@ mod tests { item: Option, } - impl ArrayIteratorOp for FoldOp { + impl ConsumingArrayIteratorOp for FoldOp { type Output = CallTrackingAndResult, CallArgs>; fn name(&self) -> String { @@ -910,7 +915,7 @@ mod tests { fn assert_count() { struct CountOp; - impl ArrayIteratorOp for CountOp { + impl ConsumingArrayIteratorOp for CountOp { type Output = usize; fn name(&self) -> String { @@ -931,7 +936,7 @@ mod tests { false_count: usize, } - impl ArrayIteratorMutateOp for AnyOp { + impl MutatingArrayIteratorOp for AnyOp { type Output = CallTrackingWithInputType; fn name(&self) -> String { @@ -974,7 +979,7 @@ mod tests { true_count: usize, } - impl ArrayIteratorMutateOp for AllOp { + impl MutatingArrayIteratorOp for AllOp { type Output = CallTrackingWithInputType; fn name(&self) -> String { @@ -1018,7 +1023,7 @@ mod tests { false_count: usize, } - impl ArrayIteratorMutateOp for FindOp { + impl MutatingArrayIteratorOp for FindOp { type Output = CallTrackingWithInputType>>; fn name(&self) -> String { @@ -1084,7 +1089,7 @@ mod tests { number_of_nones: usize, } - impl ArrayIteratorMutateOp for FindMapOp { + impl MutatingArrayIteratorOp for 
FindMapOp { type Output = CallTrackingWithInputType>; fn name(&self) -> String { @@ -1135,7 +1140,7 @@ mod tests { right: Vec>, } - impl) -> bool> ArrayIteratorOp for PartitionOp { + impl) -> bool> ConsumingArrayIteratorOp for PartitionOp { type Output = CallTrackingWithInputType; fn name(&self) -> String { From 7f8f788d0b6fcf0840b9a7099479e2a77cd5b93b Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 28 Oct 2025 22:55:51 +0200 Subject: [PATCH 3/6] share assertion --- arrow-array/src/iterator.rs | 168 +++++++++++++++--------------------- 1 file changed, 69 insertions(+), 99 deletions(-) diff --git a/arrow-array/src/iterator.rs b/arrow-array/src/iterator.rs index f96d0158768e..e807c6e8f9f4 100644 --- a/arrow-array/src/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -366,11 +366,15 @@ mod tests { } trait SetupIter { + fn description(&self) -> String; fn setup(&self, iter: &mut I); } struct NoSetup; impl SetupIter for NoSetup { + fn description(&self) -> String { + "no setup".to_string() + } fn setup(&self, _iter: &mut I) { // none } @@ -391,6 +395,29 @@ mod tests { } } + fn setup_and_assert_cases_on_single_operation( + o: &impl ConsumingArrayIteratorOp, + setup_iterator: impl SetupIter, + ) { + for (array, source) in get_int32_iterator_cases() { + let mut actual = ArrayIter::new(&array); + let mut expected = source.iter().copied(); + + setup_iterator.setup(&mut actual); + setup_iterator.setup(&mut expected); + + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for {} (left actual, right expected) ({current_iterator_values:?})", + o.name(), + setup_iterator.description() + ); + } + } + /// Trait representing an operation on a [`ArrayIter`] /// that can be compared against a slice iterator /// @@ -427,72 +454,51 @@ mod tests { /// produces the same result for both [`ArrayIter`] and slice iterator /// under various 
consumption patterns (e.g. some calls to next/next_back/consume_all/etc) fn assert_array_iterator_cases(o: O) { - setup_and_assert_cases(NoSetup, |actual, expected| { - let current_iterator_values: Vec> = expected.clone().collect(); - assert_eq!( - o.get_value(actual), - o.get_value(expected), - "Failed on op {} for new iter (left actual, right expected) ({current_iterator_values:?})", - o.name() - ); - }); + setup_and_assert_cases_on_single_operation(&o, NoSetup); struct Next; impl SetupIter for Next { + fn description(&self) -> String { + "new iter after consuming 1 element from the start".to_string() + } fn setup(&self, iter: &mut I) { iter.next(); } } - setup_and_assert_cases(Next, |actual, expected| { - let current_iterator_values: Vec> = expected.clone().collect(); - - assert_eq!( - o.get_value(actual), - o.get_value(expected), - "Failed on op {} for new iter after consuming 1 element from the start (left actual, right expected) ({current_iterator_values:?})", - o.name() - ); - }); + setup_and_assert_cases_on_single_operation(&o, Next); struct NextBack; impl SetupIter for NextBack { + fn description(&self) -> String { + "new iter after consuming 1 element from the end".to_string() + } + fn setup(&self, iter: &mut I) { iter.next_back(); } } - setup_and_assert_cases(NextBack, |actual, expected| { - let current_iterator_values: Vec> = expected.clone().collect(); - - assert_eq!( - o.get_value(actual), - o.get_value(expected), - "Failed on op {} for new iter after consuming 1 element from the end (left actual, right expected) ({current_iterator_values:?})", - o.name() - ); - }); + setup_and_assert_cases_on_single_operation(&o, NextBack); struct NextAndBack; impl SetupIter for NextAndBack { + fn description(&self) -> String { + "new iter after consuming 1 element from start and end".to_string() + } + fn setup(&self, iter: &mut I) { iter.next(); iter.next_back(); } } - setup_and_assert_cases(NextAndBack, |actual, expected| { - let current_iterator_values: Vec> = 
expected.clone().collect(); - - assert_eq!( - o.get_value(actual), - o.get_value(expected), - "Failed on op {} for new iter after consuming 1 element from start and end (left actual, right expected) ({current_iterator_values:?})", - o.name() - ); - }); + setup_and_assert_cases_on_single_operation(&o, NextAndBack); struct NextUntilLast; impl SetupIter for NextUntilLast { + fn description(&self) -> String { + "new iter after consuming all from the start but 1".to_string() + } fn setup(&self, iter: &mut I) { let len = iter.len(); if len > 1 { @@ -500,19 +506,14 @@ mod tests { } } } - setup_and_assert_cases(NextUntilLast, |actual, expected| { - let current_iterator_values: Vec> = expected.clone().collect(); - - assert_eq!( - o.get_value(actual), - o.get_value(expected), - "Failed on op {} for new iter after consuming all from the start but 1 (left actual, right expected) ({current_iterator_values:?})", - o.name() - ); - }); + setup_and_assert_cases_on_single_operation(&o, NextUntilLast); struct NextBackUntilFirst; impl SetupIter for NextBackUntilFirst { + fn description(&self) -> String { + "new iter after consuming all from the end but 1".to_string() + } + fn setup(&self, iter: &mut I) { let len = iter.len(); if len > 1 { @@ -520,53 +521,35 @@ mod tests { } } } - setup_and_assert_cases(NextBackUntilFirst, |actual, expected| { - let current_iterator_values: Vec> = expected.clone().collect(); - - assert_eq!( - o.get_value(actual), - o.get_value(expected), - "Failed on op {} for new iter after consuming all from the end but 1 (left actual, right expected) ({current_iterator_values:?})", - o.name() - ); - }); + setup_and_assert_cases_on_single_operation(&o, NextBackUntilFirst); struct NextFinish; impl SetupIter for NextFinish { + fn description(&self) -> String { + "new iter after consuming all from the start".to_string() + } fn setup(&self, iter: &mut I) { iter.nth(iter.len()); } } - setup_and_assert_cases(NextFinish, |actual, expected| { - let current_iterator_values: 
Vec> = expected.clone().collect(); - - assert_eq!( - o.get_value(actual), - o.get_value(expected), - "Failed on op {} for new iter after consuming all from the start (left actual, right expected) ({current_iterator_values:?})", - o.name() - ); - }); + setup_and_assert_cases_on_single_operation(&o, NextFinish); struct NextBackFinish; impl SetupIter for NextBackFinish { + fn description(&self) -> String { + "new iter after consuming all from the end".to_string() + } fn setup(&self, iter: &mut I) { iter.nth_back(iter.len()); } } - setup_and_assert_cases(NextBackFinish, |actual, expected| { - let current_iterator_values: Vec> = expected.clone().collect(); - - assert_eq!( - o.get_value(actual), - o.get_value(expected), - "Failed on op {} for new iter after consuming all from the end (left actual, right expected) ({current_iterator_values:?})", - o.name() - ); - }); + setup_and_assert_cases_on_single_operation(&o, NextBackFinish); struct NextUntilLastNone; impl SetupIter for NextUntilLastNone { + fn description(&self) -> String { + "new iter that have no nulls left".to_string() + } fn setup(&self, iter: &mut I) { let last_null_position = iter.clone().rposition(|item| item.is_none()); @@ -576,19 +559,13 @@ mod tests { } } } - setup_and_assert_cases(NextUntilLastNone, |actual, expected| { - let current_iterator_values: Vec> = expected.clone().collect(); - - assert_eq!( - o.get_value(actual), - o.get_value(expected), - "Failed on op {} for iter that have no nulls left (left actual, right expected) ({current_iterator_values:?})", - o.name() - ); - }); + setup_and_assert_cases_on_single_operation(&o, NextUntilLastNone); struct NextUntilLastSome; impl SetupIter for NextUntilLastSome { + fn description(&self) -> String { + "iter that only have nulls left".to_string() + } fn setup(&self, iter: &mut I) { let last_some_position = iter.clone().rposition(|item| item.is_some()); @@ -598,16 +575,7 @@ mod tests { } } } - setup_and_assert_cases(NextUntilLastSome, |actual, expected| { - 
let current_iterator_values: Vec> = expected.clone().collect(); - - assert_eq!( - o.get_value(actual), - o.get_value(expected), - "Failed on op {} for iter that only have nulls left (left actual, right expected) ({current_iterator_values:?})", - o.name() - ); - }); + setup_and_assert_cases_on_single_operation(&o, NextUntilLastSome); } /// Helper function that will assert that the provided operation @@ -641,6 +609,7 @@ mod tests { ) -> Self::Output { let value = self.o.get_value(&mut iter); + // Get the rest of the iterator to make sure we leave the iterator in a valid state let leftover: Vec<_> = iter.collect(); AdapterOutput { value, leftover } @@ -674,6 +643,7 @@ mod tests { format!("position with {} false returned", self.number_of_false) } } + fn get_value( &self, iter: &mut T, From 60b1317966c84814bc3fe0926dd6dc274011d444 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 28 Oct 2025 23:02:39 +0200 Subject: [PATCH 4/6] add more comments --- arrow-array/src/iterator.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/arrow-array/src/iterator.rs b/arrow-array/src/iterator.rs index e807c6e8f9f4..911c763417fe 100644 --- a/arrow-array/src/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -431,6 +431,10 @@ mod tests { /// Get the value of the operation for the provided iterator /// This will be either a [`ArrayIter`] or a slice iterator to make sure they produce the same result + /// + /// Example implementation: + /// 1. for `last` it will be the last value + /// 2. for `count` it will be the returned length fn get_value(&self, iter: T) -> Self::Output; } @@ -447,6 +451,10 @@ mod tests { /// Get the value of the operation for the provided iterator /// This will be either a [`ArrayIter`] or a slice iterator to make sure they produce the same result + /// + /// Example implementation: + /// 1. for `for_each` it will be the iterator element that the function was called with + /// 2. 
for `fold` it will be the accumulator and the iterator element from each call, as well as the final result fn get_value(&self, iter: &mut T) -> Self::Output; } From f36bbbdcdf4a400eeb02bf0eb7ba5f90de4ea53a Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 28 Oct 2025 23:06:46 +0200 Subject: [PATCH 5/6] reduce duplication --- arrow-array/src/iterator.rs | 84 +++++++++++++++---------------------- 1 file changed, 34 insertions(+), 50 deletions(-) diff --git a/arrow-array/src/iterator.rs b/arrow-array/src/iterator.rs index 911c763417fe..9fe4ef2202d3 100644 --- a/arrow-array/src/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -660,28 +660,21 @@ mod tests { let mut count = 0; + let cb = |item| { + items.push(item); + + if count < self.number_of_false { + count += 1; + false + } else { + true + } + }; + let position_result = if self.reverse { - iter.rposition(|item| { - items.push(item); - - if count < self.number_of_false { - count += 1; - false - } else { - true - } - }) + iter.rposition(cb) } else { - iter.position(|item| { - items.push(item); - - if count < self.number_of_false { - count += 1; - false - } else { - true - } - }) + iter.position(cb) }; CallTrackingAndResult { @@ -863,19 +856,17 @@ mod tests { fn get_value(&self, iter: T) -> Self::Output { let mut items = Vec::with_capacity(iter.len()); - let result = if self.reverse { - iter.rfold(Some(1), |acc, item| { - items.push(CallArgs { item, acc }); + let cb = |acc, item| { + items.push(CallArgs { item, acc }); + + item.map(|val| val + 100) + }; - item.map(|val| val + 100) - }) + let result = if self.reverse { + iter.rfold(Some(1), cb) } else { #[allow(clippy::manual_try_fold)] - iter.fold(Some(1), |acc, item| { - items.push(CallArgs { item, acc }); - - item.map(|val| val + 100) - }) + iter.fold(Some(1), cb) }; CallTrackingAndResult { @@ -1020,28 +1011,21 @@ mod tests { let mut count = 0; + let cb = |item: &Option| { + items.push(*item); + + if count < 
self.false_count { + count += 1; + false + } else { + true + } + }; + let position_result = if self.reverse { - iter.rfind(|item| { - items.push(*item); - - if count < self.false_count { - count += 1; - false - } else { - true - } - }) + iter.rfind(cb) } else { - iter.find(|item| { - items.push(*item); - - if count < self.false_count { - count += 1; - false - } else { - true - } - }) + iter.find(cb) }; CallTrackingWithInputType { From 72904f7c2fcb6dcf32dbbc8d71bad7e413b41620 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 28 Oct 2025 23:10:31 +0200 Subject: [PATCH 6/6] cleanup --- arrow-array/src/iterator.rs | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/arrow-array/src/iterator.rs b/arrow-array/src/iterator.rs index 9fe4ef2202d3..e72b259ef049 100644 --- a/arrow-array/src/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -152,8 +152,6 @@ mod tests { use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use std::fmt::Debug; - use std::iter::Copied; - use std::slice::Iter; use std::sync::Arc; #[test] @@ -380,21 +378,6 @@ mod tests { } } - fn setup_and_assert_cases( - setup_iterator: impl SetupIter, - assert_fn: impl Fn(ArrayIter<&Int32Array>, Copied>>), - ) { - for (array, source) in get_int32_iterator_cases() { - let mut actual = ArrayIter::new(&array); - let mut expected = source.iter().copied(); - - setup_iterator.setup(&mut actual); - setup_iterator.setup(&mut expected); - - assert_fn(actual, expected); - } - } - fn setup_and_assert_cases_on_single_operation( o: &impl ConsumingArrayIteratorOp, setup_iterator: impl SetupIter, @@ -412,8 +395,8 @@ mod tests { o.get_value(actual), o.get_value(expected), "Failed on op {} for {} (left actual, right expected) ({current_iterator_values:?})", + o.name(), setup_iterator.description(), - o.name() ); } } @@ -696,7 +679,9 @@ mod tests { #[test] fn assert_nth() { - setup_and_assert_cases(NoSetup, |actual, expected| { + for 
(array, source) in get_int32_iterator_cases() { + let actual = ArrayIter::new(&array); + let expected = source.iter().copied(); { let mut actual = actual.clone(); let mut expected = expected.clone(); @@ -728,12 +713,14 @@ mod tests { assert_eq!(actual_val, expected_val, "Failed on nth(2)"); } } - }); + } } #[test] fn assert_nth_back() { - setup_and_assert_cases(NoSetup, |actual, expected| { + for (array, source) in get_int32_iterator_cases() { + let actual = ArrayIter::new(&array); + let expected = source.iter().copied(); { let mut actual = actual.clone(); let mut expected = expected.clone(); @@ -765,7 +752,7 @@ mod tests { assert_eq!(actual_val, expected_val, "Failed on nth_back(2)"); } } - }); + } } #[test]