diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs
index 7258a9b3bcd..d7375cf6de6 100644
--- a/rust/lance-core/src/utils/mask.rs
+++ b/rust/lance-core/src/utils/mask.rs
@@ -621,6 +621,10 @@ impl RowAddrTreeMap {
         Ok(Self { inner })
     }
+
+    /// Returns the ids of all fragments that have an entry in this map.
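+    ///
+    /// A minimal sketch (not compiled here), assuming the map's usual `insert`
+    /// of raw row addresses (`fragment_id << 32 | offset`):
+    ///
+    /// ```ignore
+    /// let mut map = RowAddrTreeMap::default();
+    /// map.insert((1u64 << 32) | 7); // fragment 1, offset 7
+    /// map.insert((3u64 << 32) | 0); // fragment 3, offset 0
+    /// assert_eq!(map.fragments(), vec![1, 3]);
+    /// ```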
+    pub fn fragments(&self) -> Vec<u32> {
+        self.inner.keys().cloned().collect()
+    }
 }
 
 impl std::ops::BitOr for RowAddrTreeMap {
diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs
index c6c9533ea44..d4eba7bf684 100644
--- a/rust/lance/src/dataset.rs
+++ b/rust/lance/src/dataset.rs
@@ -1832,6 +1832,22 @@ impl Dataset {
             .collect()
     }
 
+    /// Prunes dataset fragments using scalar indices for the given filter expression.
+    ///
+    /// This returns the subset of manifest fragments that still need to be scanned,
+    /// in manifest order. Fragments not covered by the participating scalar indices
+    /// are always retained, and if the filter does not yield a scalar index query
+    /// (or the index result cannot safely exclude fragments), this method is
+    /// effectively a no-op and returns all fragments.
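+    ///
+    /// # Example
+    ///
+    /// A minimal sketch; the dataset URI and column name are illustrative:
+    ///
+    /// ```no_run
+    /// # use lance::Dataset;
+    /// # async fn example() -> lance::Result<()> {
+    /// let dataset = Dataset::open("/tmp/my_dataset.lance").await?;
+    /// // Fragments the filter provably cannot match are dropped; fragments
+    /// // without index coverage are always kept.
+    /// let candidates = dataset.prune_fragments("i >= 30").await?;
+    /// println!("fragments to scan: {}", candidates.len());
+    /// # Ok(())
+    /// # }
+    /// ```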
+    pub async fn prune_fragments(&self, filter: &str) -> Result<Vec<Fragment>> {
+        Scanner::scalar_indexed_prune_fragments(
+            Arc::new(self.clone()),
+            filter,
+            self.manifest.fragments.clone(),
+        )
+        .await
+    }
+
     pub fn get_fragment(&self, fragment_id: usize) -> Option<FileFragment> {
         let dataset = Arc::new(self.clone());
         let fragment = self
diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs
index a24b0032bf6..6d568946c2b 100644
--- a/rust/lance/src/dataset/scanner.rs
+++ b/rust/lance/src/dataset/scanner.rs
@@ -1,6 +1,7 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
+use std::collections::HashSet;
 use std::ops::Range;
 use std::pin::Pin;
 use std::sync::{Arc, LazyLock};
@@ -3970,6 +3971,113 @@ impl Scanner {
         Ok(format!("{}", display.indent(verbose)))
     }
+
+    /// Prune a list of fragments using scalar indices for a given filter.
+    ///
+    /// This helper builds a temporary [`Scanner`] over `dataset`, plans `filter`
+    /// with scalar index support, and, when possible, evaluates the scalar index
+    /// expression to determine which fragments contain any candidate rows.
+    ///
+    /// Returns the subset of `fragments` that still need to be scanned:
+    /// fragments that have at least one candidate row according to the index
+    /// result, plus any fragments that are not fully covered by all indices.
+    /// Fragments outside the index coverage are never dropped.
+    ///
+    /// # Notes
+    ///
+    /// - Inputs:
+    ///   - `dataset`: logical [`Dataset`] used to plan and evaluate the scalar
+    ///     index expression.
+    ///   - `filter`: SQL-like predicate string used for planning; it is not
+    ///     evaluated as a full scan in this helper.
+    ///   - `fragments`: manifest [`Fragment`]s considered for pruning.
+    /// - Pruning is driven only by scalar index results with *exact* or *at-most*
+    ///   semantics. Results with *at-least* semantics cannot safely exclude any
+    ///   fragment, so all `fragments` are returned unchanged in that case.
+    /// - When the index result is an allow-list [`RowAddrMask`], fragments that are
+    ///   fully covered by the participating indices and have no allowed rows in the
+    ///   mask are pruned. This never drops fragments that might still satisfy
+    ///   `filter`.
+    /// - When the index result is a block-list [`RowAddrMask`], only fragments that
+    ///   are both fully covered by the indices and fully blocked in the mask can be
+    ///   pruned. Partially blocked fragments, and all fragments not covered by every
+    ///   index, are always kept to avoid false negatives.
+    /// - This helper only performs scalar index planning and evaluation; it does not
+    ///   build or execute a full scan plan.
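+    ///
+    /// # Example
+    ///
+    /// A minimal sketch of calling the helper directly; [`Dataset::prune_fragments`]
+    /// is a convenience wrapper around the same call:
+    ///
+    /// ```no_run
+    /// # use std::sync::Arc;
+    /// # use lance::dataset::scanner::Scanner;
+    /// # use lance::Dataset;
+    /// # async fn example(dataset: Arc<Dataset>) -> lance::Result<()> {
+    /// let fragments = dataset.fragments().clone();
+    /// let to_scan =
+    ///     Scanner::scalar_indexed_prune_fragments(dataset, "i >= 30", fragments).await?;
+    /// println!("{} fragments still need scanning", to_scan.len());
+    /// # Ok(())
+    /// # }
+    /// ```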
+    pub async fn scalar_indexed_prune_fragments(
+        dataset: Arc<Dataset>,
+        filter: &str,
+        fragments: Arc<Vec<Fragment>>,
+    ) -> Result<Vec<Fragment>> {
+        let mut scanner = Self::new(dataset.clone());
+
+        scanner.filter(filter)?;
+        let filter_plan = scanner.create_filter_plan(true).await?;
+
+        if let Some(index_expr) = filter_plan.expr_filter_plan.index_query.as_ref() {
+            // Partition fragments into those fully covered by all scalar indices and
+            // those that are not.
+            let (covered_frags, missing_frags) = scanner
+                .partition_frags_by_coverage(index_expr, fragments.clone())
+                .await?;
+
+            // Evaluate the scalar index expression to obtain a row-address mask
+            // over the covered fragments.
+            let expr_result = index_expr
+                .evaluate(dataset.as_ref(), &NoOpMetricsCollector)
+                .await?;
+
+            match expr_result {
+                IndexExprResult::Exact(mask) | IndexExprResult::AtMost(mask) => match mask {
+                    RowAddrMask::AllowList(map) => {
+                        let allow_fragids: HashSet<u32> = map.fragments().into_iter().collect();
+
+                        // Among fully covered fragments, keep only those with at least
+                        // one allowed row.
+                        let mut allow_frags: Vec<Fragment> = covered_frags
+                            .iter()
+                            .filter(|f| allow_fragids.contains(&(f.id as u32)))
+                            .cloned()
+                            .collect();
+
+                        // Always keep fragments that are not fully covered by the indices.
+                        allow_frags.extend(missing_frags);
+                        Ok(allow_frags)
+                    }
+
+                    RowAddrMask::BlockList(map) => {
+                        if map.is_empty() {
+                            // No fragment is blocked by the mask; nothing can be pruned.
+                            Ok(fragments.to_vec())
+                        } else {
+                            let blocked_fragids: HashSet<u32> =
+                                map.fragments().into_iter().collect();
+
+                            // Fragments that are not blocked at all, or only partially
+                            // blocked (only a partial block carries a fragment bitmap),
+                            // still need to be scanned.
+                            let mut allow_frags: Vec<Fragment> = covered_frags
+                                .iter()
+                                .filter(|f| {
+                                    !blocked_fragids.contains(&(f.id as u32))
+                                        || map.get_fragment_bitmap(f.id as u32).is_some()
+                                })
+                                .cloned()
+                                .collect();
+
+                            // Always keep fragments that are not fully covered by the indices.
+                            allow_frags.extend(missing_frags);
+                            Ok(allow_frags)
+                        }
+                    }
+                },
+
+                IndexExprResult::AtLeast(_) => Ok(fragments.to_vec()),
+            }
+        } else {
+            Ok(fragments.to_vec())
+        }
+    }
 }
 
 // Search over all indexed fields including nested ones, collecting columns that have an
diff --git a/rust/lance/src/dataset/tests/dataset_scanner.rs b/rust/lance/src/dataset/tests/dataset_scanner.rs
index 9fce5f6d2ca..989438fd3f1 100644
--- a/rust/lance/src/dataset/tests/dataset_scanner.rs
+++ b/rust/lance/src/dataset/tests/dataset_scanner.rs
@@ -1,6 +1,7 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
+use std::collections::HashSet;
 use std::sync::Arc;
 use std::vec;
@@ -17,12 +18,16 @@ use lance_arrow::SchemaExt;
 use lance_index::scalar::inverted::{
     query::PhraseQuery, tokenizer::InvertedIndexParams, SCORE_FIELD,
 };
-use lance_index::scalar::FullTextSearchQuery;
+use lance_index::scalar::{BuiltinIndexType, FullTextSearchQuery, ScalarIndexParams};
 use lance_index::{vector::DIST_COL, DatasetIndexExt, IndexType};
 use lance_linalg::distance::MetricType;
 
+use crate::dataset::scanner::test_dataset::TestVectorDataset;
 use crate::dataset::scanner::{DatasetRecordBatchStream, QueryFilter};
+use crate::dataset::WriteParams;
 use crate::Dataset;
+use lance_core::utils::tempfile::TempStrDir;
+use lance_encoding::version::LanceFileVersion;
 use lance_index::scalar::inverted::query::FtsQuery;
 use lance_index::vector::ivf::IvfBuildParams;
 use lance_index::vector::pq::PQBuildParams;
@@ -465,3 +470,350 @@ async fn check_results(
         .unwrap();
     assert_eq!(ids.values(), expected_ids);
 }
+
+#[tokio::test]
+async fn test_prune_fragments_without_scalar_index_returns_all() {
+    // Build a dataset with 5 fragments of 10 rows each: i = [0, 1, ..., 49].
+    let test_uri = TempStrDir::default();
+    let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
+        "i",
+        DataType::Int32,
+        false,
+    )]));
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(Int32Array::from_iter_values(0..50))],
+    )
+    .unwrap();
+
+    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
+    let write_params = WriteParams {
+        max_rows_per_file: 10,
+        max_rows_per_group: 10,
+        ..Default::default()
+    };
+    let dataset = Dataset::write(reader, &test_uri, Some(write_params))
+        .await
+        .unwrap();
+
+    let original_fragments = dataset.fragments().clone();
+
+    // Without a scalar index, pruning should be a no-op and return all fragments.
+    let pruned = dataset.prune_fragments("i >= 30").await.unwrap();
+
+    assert_eq!(pruned.len(), original_fragments.len());
+    let original_ids: Vec<u64> = original_fragments.iter().map(|f| f.id).collect();
+    let pruned_ids: Vec<u64> = pruned.iter().map(|f| f.id).collect();
+    assert_eq!(pruned_ids, original_ids);
+}
+
+#[tokio::test]
+async fn test_prune_fragments_with_scalar_index_prunes_non_matching_fragments() {
+    // Dataset with 5 fragments of 10 rows each: i = [0, 1, ..., 49].
+    let test_uri = TempStrDir::default();
+    let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
+        "i",
+        DataType::Int32,
+        false,
+    )]));
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(Int32Array::from_iter_values(0..50))],
+    )
+    .unwrap();
+
+    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
+    let write_params = WriteParams {
+        max_rows_per_file: 10,
+        max_rows_per_group: 10,
+        ..Default::default()
+    };
+    let mut dataset = Dataset::write(reader, &test_uri, Some(write_params))
+        .await
+        .unwrap();
+
+    // Create a scalar index on i so all current fragments are indexed.
+    dataset
+        .create_index(
+            &["i"],
+            IndexType::Scalar,
+            None,
+            &ScalarIndexParams::default(),
+            true,
+        )
+        .await
+        .unwrap();
+
+    let fragments = dataset.fragments().clone();
+    assert_eq!(fragments.len(), 5);
+
+    // For filter i >= 30, all matching rows live in the last two fragments.
+    let expected_tail_ids: Vec<u64> = fragments[fragments.len() - 2..]
+        .iter()
+        .map(|f| f.id)
+        .collect();
+
+    let pruned = dataset.prune_fragments("i >= 30").await.unwrap();
+    let pruned_ids: Vec<u64> = pruned.iter().map(|f| f.id).collect();
+
+    assert_eq!(pruned_ids, expected_tail_ids);
+}
+
+#[tokio::test]
+async fn test_prune_fragments_with_scalar_index_and_mixed_or_filter_is_noop() {
+    // Multi-column dataset with predictable fragment boundaries: 5 fragments
+    // of 10 rows each, columns col_a, col_b, col_c.
+    let test_uri = TempStrDir::default();
+    let schema = Arc::new(ArrowSchema::new(vec![
+        ArrowField::new("col_a", DataType::Int32, false),
+        ArrowField::new("col_b", DataType::Int32, false),
+        ArrowField::new("col_c", DataType::Int32, false),
+    ]));
+
+    // col_a: 0..50 (monotonic sequence for range queries)
+    let col_a = Int32Array::from_iter_values(0..50);
+    // col_b: the first fragment has small values (< 10) so rows there can only
+    // match the filter via the non-indexed side `col_b < 10`; later fragments
+    // have large values.
+    let col_b = Int32Array::from_iter_values((0..50).map(|i| if i < 10 { i } else { 100 + i }));
+    // col_c: arbitrary third column, no index.
+    let col_c = Int32Array::from_iter_values((0..50).map(|i| i * 2));
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(col_a), Arc::new(col_b), Arc::new(col_c)],
+    )
+    .unwrap();
+
+    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
+    let write_params = WriteParams {
+        max_rows_per_file: 10,
+        max_rows_per_group: 10,
+        ..Default::default()
+    };
+    let mut dataset = Dataset::write(reader, &test_uri, Some(write_params))
+        .await
+        .unwrap();
+
+    // Create a scalar index only on col_a.
+    dataset
+        .create_index(
+            &["col_a"],
+            IndexType::Scalar,
+            None,
+            &ScalarIndexParams::default(),
+            true,
+        )
+        .await
+        .unwrap();
+
+    let fragments = dataset.fragments().clone();
+    assert_eq!(fragments.len(), 5);
+
+    // For filter `col_a >= 10 OR col_b < 10`, only `col_a` is indexable. The
+    // planner should treat this OR as mixed indexability and fall back to a
+    // refine-only filter plan, so scalar-index-based pruning becomes a no-op
+    // and all fragments are retained.
+    let pruned = dataset
+        .prune_fragments("col_a >= 10 OR col_b < 10")
+        .await
+        .unwrap();
+    let original_ids: Vec<u64> = fragments.iter().map(|f| f.id).collect();
+    let pruned_ids: Vec<u64> = pruned.iter().map(|f| f.id).collect();
+    assert_eq!(pruned_ids, original_ids);
+}
+
+#[tokio::test]
+async fn test_prune_fragments_keeps_fragments_outside_index_coverage() {
+    // Use TestVectorDataset to construct a multi-fragment dataset with an Int32
+    // column "i". This matches the pattern used elsewhere in dataset_index.rs
+    // for multi-fragment tests.
+    let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false)
+        .await
+        .unwrap();
+
+    // Build a scalar index on i covering the initial fragments.
+    test_ds.make_scalar_index().await.unwrap();
+
+    let before_fragments = test_ds.dataset.fragments().clone();
+    let before_ids: HashSet<u64> = before_fragments.iter().map(|f| f.id).collect();
+
+    // Append new data so the new fragment is not covered by the existing index.
+    test_ds.append_new_data().await.unwrap();
+    let all_fragments = test_ds.dataset.fragments().clone();
+    let new_fragments: Vec<_> = all_fragments
+        .iter()
+        .filter(|f| !before_ids.contains(&f.id))
+        .collect();
+
+    // Sanity check: we expect exactly one newly appended fragment.
+    assert_eq!(new_fragments.len(), 1);
+    let new_fragment_id = new_fragments[0].id;
+
+    // Use a filter that only matches early rows, which live in the original fragments.
+    let pruned = test_ds.dataset.prune_fragments("i < 10").await.unwrap();
+    let pruned_ids: HashSet<u64> = pruned.iter().map(|f| f.id).collect();
+
+    // Fragments without index coverage must not be pruned.
+    assert!(pruned_ids.contains(&new_fragment_id));
+}
+
+#[tokio::test]
+async fn test_prune_fragments_with_zonemap_scalar_index_prunes_non_matching_fragments() {
+    // Dataset with 5 fragments of 10 rows each: z = [0, 1, ..., 49].
+    let test_uri = TempStrDir::default();
+    let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
+        "z",
+        DataType::Int32,
+        false,
+    )]));
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(Int32Array::from_iter_values(0..50))],
+    )
+    .unwrap();
+
+    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
+    let write_params = WriteParams {
+        max_rows_per_file: 10,
+        max_rows_per_group: 10,
+        ..Default::default()
+    };
+    let mut dataset = Dataset::write(reader, &test_uri, Some(write_params))
+        .await
+        .unwrap();
+
+    // Create a ZoneMap scalar index on z so all current fragments are indexed.
+    let zonemap_params = ScalarIndexParams::for_builtin(BuiltinIndexType::ZoneMap);
+    dataset
+        .create_index(&["z"], IndexType::Scalar, None, &zonemap_params, true)
+        .await
+        .unwrap();
+
+    let fragments = dataset.fragments().clone();
+    assert_eq!(fragments.len(), 5);
+
+    // For filter z >= 30, all matching rows live in the last two fragments.
+    // ZoneMap returns an AtMost allow-list mask, and scalar_indexed_prune_fragments
+    // prunes covered fragments that have no allowed rows while keeping uncovered
+    // fragments, so we expect only the tail fragments to remain.
+    let expected_tail_ids: Vec<u64> = fragments[fragments.len() - 2..]
+        .iter()
+        .map(|f| f.id)
+        .collect();
+
+    let pruned = dataset.prune_fragments("z >= 30").await.unwrap();
+    let pruned_ids: Vec<u64> = pruned.iter().map(|f| f.id).collect();
+
+    assert_eq!(pruned_ids, expected_tail_ids);
+}
+
+#[tokio::test]
+async fn test_prune_fragments_with_scalar_index_blocklist_partial_keeps_all_fragments() {
+    // Dataset with 5 fragments of 10 rows each: i = [0, 1, ..., 49].
+    let test_uri = TempStrDir::default();
+    let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
+        "i",
+        DataType::Int32,
+        false,
+    )]));
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(Int32Array::from_iter_values(0..50))],
+    )
+    .unwrap();
+
+    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
+    let write_params = WriteParams {
+        max_rows_per_file: 10,
+        max_rows_per_group: 10,
+        ..Default::default()
+    };
+    let mut dataset = Dataset::write(reader, &test_uri, Some(write_params))
+        .await
+        .unwrap();
+
+    // Create a scalar BTree index on i so all current fragments are indexed.
+    dataset
+        .create_index(
+            &["i"],
+            IndexType::Scalar,
+            None,
+            &ScalarIndexParams::default(),
+            true,
+        )
+        .await
+        .unwrap();
+
+    let original_fragments = dataset.fragments().clone();
+    assert_eq!(original_fragments.len(), 5);
+
+    // Filter i != 30 is implemented as NOT(i = 30). The scalar index evaluates the
+    // equality as an exact allow-list and then negates it to an exact block-list
+    // containing only the single row with i = 30. Since no fragment is fully
+    // blocked in the resulting RowAddrMask::BlockList, scalar_indexed_prune_fragments
+    // must keep all fragments and preserve their manifest order.
+    let pruned = dataset.prune_fragments("i != 30").await.unwrap();
+
+    assert_eq!(pruned.len(), original_fragments.len());
+    let original_ids: Vec<u64> = original_fragments.iter().map(|f| f.id).collect();
+    let pruned_ids: Vec<u64> = pruned.iter().map(|f| f.id).collect();
+    assert_eq!(pruned_ids, original_ids);
+}
+
+#[tokio::test]
+async fn test_prune_fragments_with_scalar_index_blocklist_empty_keeps_all_fragments() {
+    // Dataset with 5 fragments of 10 rows each: i = [0, 1, ..., 49].
+    let test_uri = TempStrDir::default();
+    let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
+        "i",
+        DataType::Int32,
+        false,
+    )]));
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(Int32Array::from_iter_values(0..50))],
+    )
+    .unwrap();
+
+    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
+    let write_params = WriteParams {
+        max_rows_per_file: 10,
+        max_rows_per_group: 10,
+        ..Default::default()
+    };
+    let mut dataset = Dataset::write(reader, &test_uri, Some(write_params))
+        .await
+        .unwrap();
+
+    // Create a scalar BTree index on i so all current fragments are indexed.
+    dataset
+        .create_index(
+            &["i"],
+            IndexType::Scalar,
+            None,
+            &ScalarIndexParams::default(),
+            true,
+        )
+        .await
+        .unwrap();
+
+    let original_fragments = dataset.fragments().clone();
+    assert_eq!(original_fragments.len(), 5);
+
+    // Filter i != 1000 is implemented as NOT(i = 1000). The equality matches no
+    // rows, so the negated scalar index result is an exact block-list with an
+    // empty RowAddrTreeMap. scalar_indexed_prune_fragments treats an empty
+    // RowAddrMask::BlockList as "no fragment is blocked" and returns all
+    // fragments unchanged.
+    let pruned = dataset.prune_fragments("i != 1000").await.unwrap();
+
+    assert_eq!(pruned.len(), original_fragments.len());
+    let original_ids: Vec<u64> = original_fragments.iter().map(|f| f.id).collect();
+    let pruned_ids: Vec<u64> = pruned.iter().map(|f| f.id).collect();
+    assert_eq!(pruned_ids, original_ids);
+}