Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions arrow-select/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,20 +173,17 @@ pub fn filter_record_batch(
predicate: &BooleanArray,
) -> Result<RecordBatch, ArrowError> {
let mut filter_builder = FilterBuilder::new(predicate);
if record_batch.num_columns() > 1 {
let num_cols = record_batch.num_columns();
if num_cols > 1
|| (num_cols > 0 && multiple_arrays(record_batch.schema_ref().field(0).data_type()))
{
// Only optimize if filtering more than one column
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says to optimize only when filtering >1 column, but the condition also triggers for a single column when the first column’s type spans multiple arrays; consider updating the comment to reflect this (also applies to the following line).

🤖 React with 👍 or 👎 to let us know if the comment was useful.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value:good-to-have; category:documentation; feedback:The AI agent is correct that the old comment became incomplete with the new changes in this PR and it would be good to update it

// Otherwise, the overhead of optimization can be more than the benefit
filter_builder = filter_builder.optimize();
}
let filter = filter_builder.build();

let filtered_arrays = record_batch
.columns()
.iter()
.map(|a| filter_array(a, &filter))
.collect::<Result<Vec<_>, _>>()?;
let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
filter.filter_record_batch(record_batch)
}

/// A builder to construct [`FilterPredicate`]
Expand Down Expand Up @@ -300,6 +297,31 @@ impl FilterPredicate {
filter_array(values, self)
}

/// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this
/// [`FilterPredicate`].
///
/// This is the equivalent of calling [filter] on each column of the [`RecordBatch`].
pub fn filter_record_batch(
&self,
record_batch: &RecordBatch,
) -> Result<RecordBatch, ArrowError> {
let filtered_arrays = record_batch
.columns()
.iter()
.map(|a| filter_array(a, self))
.collect::<Result<Vec<_>, _>>()?;

// SAFETY: we know that the set of filtered arrays will match the schema of the original
// record batch
unsafe {
Ok(RecordBatch::new_unchecked(
record_batch.schema(),
filtered_arrays,
self.count,
))
}
}

/// Number of rows being selected based on this [`FilterPredicate`]
pub fn count(&self) -> usize {
self.count
Expand Down
Loading