diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs
index dace2bab728f..fbbc1929107d 100644
--- a/arrow-select/src/filter.rs
+++ b/arrow-select/src/filter.rs
@@ -173,20 +173,17 @@ pub fn filter_record_batch(
     predicate: &BooleanArray,
 ) -> Result<RecordBatch, ArrowError> {
     let mut filter_builder = FilterBuilder::new(predicate);
-    if record_batch.num_columns() > 1 {
-        // Only optimize if filtering more than one column
+    let num_cols = record_batch.num_columns();
+    if num_cols > 1
+        || (num_cols > 0 && multiple_arrays(record_batch.schema_ref().field(0).data_type()))
+    {
+        // Only optimize for multiple columns, or one column whose type satisfies `multiple_arrays`
         // Otherwise, the overhead of optimization can be more than the benefit
         filter_builder = filter_builder.optimize();
     }
     let filter = filter_builder.build();
 
-    let filtered_arrays = record_batch
-        .columns()
-        .iter()
-        .map(|a| filter_array(a, &filter))
-        .collect::<Result<Vec<_>, _>>()?;
-    let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
-    RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
+    filter.filter_record_batch(record_batch)
 }
 
 /// A builder to construct [`FilterPredicate`]
@@ -300,6 +297,31 @@ impl FilterPredicate {
         filter_array(values, self)
     }
 
+    /// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this
+    /// [`FilterPredicate`].
+    ///
+    /// This is the equivalent of calling [filter] on each column of the [`RecordBatch`].
+    pub fn filter_record_batch(
+        &self,
+        record_batch: &RecordBatch,
+    ) -> Result<RecordBatch, ArrowError> {
+        let filtered_arrays = record_batch
+            .columns()
+            .iter()
+            .map(|a| filter_array(a, self))
+            .collect::<Result<Vec<_>, _>>()?;
+
+        // SAFETY: we know that the set of filtered arrays will match the schema of the original
+        // record batch
+        unsafe {
+            Ok(RecordBatch::new_unchecked(
+                record_batch.schema(),
+                filtered_arrays,
+                self.count,
+            ))
+        }
+    }
+
     /// Number of rows being selected based on this [`FilterPredicate`]
     pub fn count(&self) -> usize {
         self.count