diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index 7b19ad72c8..e306957133 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -5189,6 +5189,34 @@ def analyze_plan(self) -> str:
         return self._scanner.analyze_plan()
 
+    def plan_splits(
+        self, max_split_size_bytes: Optional[int] = None
+    ) -> List[List["FragmentMetadata"]]:
+        """Plan splits for distributed scanning.
+
+        This method analyzes the scanner's filter and uses indices to determine
+        which fragments need to be scanned and approximately how many rows each
+        fragment will return. It then groups fragments into splits that can be
+        processed independently.
+
+        The scanner estimates the size of each row based on the output schema
+        projection and uses that to determine how many rows fit within the
+        target split size.
+
+        Parameters
+        ----------
+        max_split_size_bytes : int, optional
+            The target maximum size in bytes for each split. Defaults to 128MB.
+
+        Returns
+        -------
+        List[List[FragmentMetadata]]
+            A list of splits, where each split is a list of FragmentMetadata
+            objects. Each split can be processed independently for distributed
+            scanning.
+        """
+        return self._scanner.plan_splits(max_split_size_bytes=max_split_size_bytes)
+
 
 class DatasetOptimizer:
     def __init__(self, dataset: LanceDataset):
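For context (not part of the diff), a hedged sketch of how this API might be driven end to end: a driver plans the splits once, and each worker scans only the fragments in its assigned split. The `get_fragment` lookup, the dataset URI, and the exact scanner arguments here are assumptions for illustration, not something this diff defines.

```python
import lance

ds = lance.dataset("s3://bucket/my_dataset")  # hypothetical dataset URI

# Driver: plan splits of roughly 64MB, honoring the filter and projection
scanner = ds.scanner(columns=["a", "b"], filter="a > 100")
splits = scanner.plan_splits(max_split_size_bytes=64 * 1024 * 1024)

# Each worker scans only the fragments in its assigned split
for split in splits:
    fragments = [ds.get_fragment(meta.id) for meta in split]
    worker_scanner = ds.scanner(
        columns=["a", "b"], filter="a > 100", fragments=fragments
    )
    for batch in worker_scanner.to_batches():
        ...  # process each RecordBatch
```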
diff --git a/python/src/scanner.rs b/python/src/scanner.rs
index 1e85af6711..bf49106dc4 100644
--- a/python/src/scanner.rs
+++ b/python/src/scanner.rs
@@ -20,7 +20,7 @@ use std::sync::Arc;
 
 use arrow::pyarrow::*;
 use arrow_array::RecordBatchReader;
-use lance::dataset::scanner::ExecutionSummaryCounts;
+use lance::dataset::scanner::{ExecutionSummaryCounts, SplitPackStrategy};
 use pyo3::prelude::*;
 use pyo3::pyclass;
@@ -30,6 +30,7 @@ use pyo3::exceptions::PyValueError;
 use crate::reader::LanceReader;
 use crate::rt;
 use crate::schema::logical_arrow_schema;
+use crate::utils::PyLance;
 
 /// This will be wrapped by a python class to provide
 /// additional functionality
@@ -150,4 +151,29 @@ impl Scanner {
 
         Ok(PyArrowType(Box::new(reader)))
     }
+
+    #[pyo3(signature = (max_split_size_bytes=None))]
+    fn plan_splits<'py>(
+        self_: PyRef<'py, Self>,
+        max_split_size_bytes: Option<usize>,
+    ) -> PyResult<Vec<Vec<Bound<'py, PyAny>>>> {
+        let scanner = self_.scanner.clone();
+        let strategy = max_split_size_bytes.map(SplitPackStrategy::MaxSizeBytes);
+        let splits = rt()
+            .spawn(Some(self_.py()), async move {
+                scanner.plan_splits(strategy).await
+            })?
+            .map_err(|err| PyValueError::new_err(err.to_string()))?;
+
+        splits
+            .into_iter()
+            .map(|split| {
+                split
+                    .fragments
+                    .into_iter()
+                    .map(|sf| PyLance(sf.fragment).into_pyobject(self_.py()))
+                    .collect::<Result<Vec<_>, _>>()
+            })
+            .collect::<Result<Vec<_>, _>>()
+    }
 }
diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs
index a1f56d48a8..0c21027b38 100644
--- a/rust/lance-core/src/utils/mask.rs
+++ b/rust/lance-core/src/utils/mask.rs
@@ -657,6 +657,10 @@ impl RowAddrTreeMap {
             }),
         })
     }
+
+    pub fn fragments(&self) -> Vec<u32> {
+        self.inner.keys().cloned().collect()
+    }
 }
 
 impl std::ops::BitOr for RowAddrTreeMap {
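The scanner changes below use this new accessor to enumerate the fragment ids present in an allow-list or block-list mask. As a rough illustration of the per-fragment bound that is derived from the mask, here is a toy Python model; `bound_rows` and its inputs are hypothetical names, and the real code below additionally handles a "whole fragment" bitmap edge case.

```python
# Toy model of the per-fragment row-count bounds derived from an index query
# result. These are illustrative inputs, not Lance APIs.
def bound_rows(covered, mask_kind, mask_rows, frag_rows):
    """covered: fragment ids covered by the index expression
    mask_kind: "allow" or "block"
    mask_rows: {frag_id: rows present in the mask for that fragment}
    frag_rows: {frag_id: total rows in the fragment}
    Returns {frag_id: max matching rows}; pruned fragments are omitted."""
    bounds = {}
    for fid in covered:
        if mask_kind == "allow":
            if fid in mask_rows:
                # Only the rows in the allow list can match
                bounds[fid] = mask_rows[fid]
            # else: nothing allowed in this fragment -> prune it
        else:
            if fid not in mask_rows:
                # Nothing blocked: every row may match
                bounds[fid] = frag_rows[fid]
            elif mask_rows[fid] < frag_rows[fid]:
                # Some rows blocked: the rest may match
                bounds[fid] = frag_rows[fid] - mask_rows[fid]
            # else: every row blocked -> prune it
    return bounds

# An allow list with 10 matching rows in fragment 0 and none in fragment 1
# keeps fragment 0 (bound 10) and prunes fragment 1 entirely.
assert bound_rows([0, 1], "allow", {0: 10}, {0: 100, 1: 100}) == {0: 10}
```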
diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs
index a0812d6caf..866f0f774a 100644
--- a/rust/lance/src/dataset/scanner.rs
+++ b/rust/lance/src/dataset/scanner.rs
@@ -1,6 +1,7 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
+use std::collections::{HashMap, HashSet};
 use std::ops::Range;
 use std::pin::Pin;
 use std::sync::{Arc, LazyLock};
@@ -600,6 +601,37 @@ pub struct Scanner {
     autoproject_scoring_columns: bool,
 }
 
+/// Represents a split for parallel scanning of fragments
+///
+/// A split contains one or more fragments that can be scanned together.
+/// Splits can be used to distribute scanning work across multiple workers or threads.
+#[derive(Debug, Clone)]
+pub struct Split {
+    pub fragments: Vec<SplitFragment>,
+}
+
+/// A fragment within a [`Split`], along with metadata about the expected
+/// number of rows that will be scanned from it.
+#[derive(Debug, Clone)]
+pub struct SplitFragment {
+    /// The fragment to scan.
+    pub fragment: Fragment,
+    /// An upper bound on the number of rows that will be read from this
+    /// fragment after applying any filters or index pruning.
+    pub max_row_count: usize,
+}
+
+/// Strategy for packing fragments into splits.
+#[derive(Debug, Clone)]
+pub enum SplitPackStrategy {
+    /// Target a maximum size in bytes per split. The scanner estimates the row
+    /// size from the output schema and calculates how many rows fit within this
+    /// budget.
+    MaxSizeBytes(usize),
+    /// Target a maximum number of rows per split.
+    MaxRowCount(usize),
+}
+
 /// Represents a user-requested take operation
 #[derive(Debug, Clone)]
 pub enum TakeOperation {
@@ -858,6 +890,214 @@ impl Scanner {
         self.fragments.is_some()
     }
 
+    /// Plan splits for distributed or parallel scanning of the dataset.
+    ///
+    /// This method analyzes the fragments to be scanned and groups them into [`Split`]s
+    /// that can be processed independently by multiple workers or threads. It uses a
+    /// bin-packing algorithm to create balanced splits based on estimated row counts.
+    pub async fn plan_splits(&self, strategy: Option<SplitPackStrategy>) -> Result<Vec<Split>> {
+        // Collect initial set of fragments to scan
+        let fragments = if let Some(fragments) = self.fragments.as_ref() {
+            Arc::new(fragments.clone())
+        } else {
+            Arc::new(self.dataset.fragments().as_ref().clone())
+        };
+
+        // Use indices to prune fragments and compute max row counts per fragment
+        let mut frag_max_row_counts: HashMap<u64, usize> = HashMap::new();
+        let mut covered_frag_ids: HashSet<_> = HashSet::new();
+
+        let filter_plan = self.create_filter_plan(true).await?;
+        if let Some(index_expr) = filter_plan.expr_filter_plan.index_query.as_ref() {
+            // Partition fragments by coverage of the index expression
+            let (covered_frags, _) = self
+                .partition_frags_by_coverage(index_expr, fragments.clone())
+                .await?;
+            covered_frags.iter().for_each(|frag| {
+                covered_frag_ids.insert(frag.id);
+            });
+
+            // Evaluate the index expression to retrieve a bitmask of matching rows
+            let expr_result = index_expr
+                .evaluate(self.dataset.as_ref(), &NoOpMetricsCollector)
+                .await?;
+            match expr_result {
+                IndexExprResult::Exact(mask) | IndexExprResult::AtMost(mask) => {
+                    match mask {
+                        RowAddrMask::AllowList(bitmap) => {
+                            // Iterate over covered fragments and update row counts
+                            let allow_frag_ids: HashSet<u32> =
+                                bitmap.fragments().into_iter().collect();
+
+                            for frag in &covered_frags {
+                                if allow_frag_ids.contains(&(frag.id as u32)) {
+                                    let row_count = match bitmap.get_fragment_bitmap(frag.id as u32)
+                                    {
+                                        Some(frag_bitmap) => frag_bitmap.len() as usize,
+                                        None => {
+                                            // Since we know `frag.id` is in the bitmap, this
+                                            // `None` means the fragment bitmap is full. Use the
+                                            // total number of rows in the fragment.
+                                            frag.num_rows().unwrap_or(usize::MAX)
+                                        }
+                                    };
+
+                                    frag_max_row_counts.insert(frag.id, row_count);
+                                } else {
+                                    // Prune the fragment: no rows match
+                                }
+                            }
+                        }
+                        RowAddrMask::BlockList(bitmap) => {
+                            // Iterate over covered fragments and update row counts
+                            let blocked_frag_ids: HashSet<u32> =
+                                bitmap.fragments().into_iter().collect();
+
+                            for frag in &covered_frags {
+                                if !blocked_frag_ids.contains(&(frag.id as u32)) {
+                                    // Fragment is not blocked, so all rows are allowed
+                                    frag_max_row_counts
+                                        .insert(frag.id, frag.num_rows().unwrap_or(usize::MAX));
+                                } else {
+                                    match bitmap.get_fragment_bitmap(frag.id as u32) {
+                                        Some(frag_bitmap) => {
+                                            let blocked_row_count = frag_bitmap.len() as usize;
+                                            let row_count = match frag.num_rows() {
+                                                Some(row_count) => row_count - blocked_row_count,
+                                                None => usize::MAX,
+                                            };
+                                            frag_max_row_counts.insert(frag.id, row_count);
+                                        }
+                                        None => {
+                                            // Prune the fragment: since we know `frag.id` is in
+                                            // the bitmap, this `None` means the fragment bitmap
+                                            // is full and no rows match.
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                IndexExprResult::AtLeast(_) => {
+                    // In the `AtLeast` case, rows excluded by the mask (e.g. entries in a
+                    // block list) may be false positives. Therefore, we cannot prune any
+                    // fragments: we cannot guarantee that a fragment has no matching rows.
+                    // Fall back to full row counts for the covered fragments.
+                    for frag in &covered_frags {
+                        frag_max_row_counts
+                            .insert(frag.id, frag.num_rows().unwrap_or(usize::MAX));
+                    }
+                }
+            }
+        }
+
+        // Estimate row counts for fragments not covered by indices.
+        fragments
+            .iter()
+            .filter(|frag| !covered_frag_ids.contains(&frag.id))
+            .for_each(|frag| {
+                // Without index information, assume every row in the fragment may match
+                let max_row_count = match frag.num_rows() {
+                    Some(count) => count,
+                    None => usize::MAX,
+                };
+                frag_max_row_counts.insert(frag.id, max_row_count);
+            });
+
+        // Bin pack fragments into splits for parallel processing
+        const DEFAULT_SPLIT_SIZE: usize = 128 * 1024 * 1024;
+        const DEFAULT_VARIABLE_FIELD_SIZE: usize = 64;
+
+        let strategy = strategy.unwrap_or(SplitPackStrategy::MaxSizeBytes(DEFAULT_SPLIT_SIZE));
+        let max_rows_per_split = match strategy {
+            SplitPackStrategy::MaxRowCount(max_row_count) => max_row_count,
+            SplitPackStrategy::MaxSizeBytes(max_bytes) => {
+                let output_schema = self.projection_plan.output_schema()?;
+                let estimated_row_size: usize = output_schema
+                    .fields()
+                    .iter()
+                    .map(|f| {
+                        f.data_type()
+                            .byte_width_opt()
+                            .unwrap_or(DEFAULT_VARIABLE_FIELD_SIZE)
+                    })
+                    .sum();
+                max_bytes / estimated_row_size.max(1)
+            }
+        };
+
+        let bins = Self::bin_pack(&frag_max_row_counts, max_rows_per_split);
+
+        // Convert bins to splits
+        let fragment_map: HashMap<u64, &Fragment> =
+            fragments.iter().map(|f| (f.id, f)).collect();
+        let splits = bins
+            .into_iter()
+            .map(|bin| {
+                let frags = bin
+                    .into_iter()
+                    .filter_map(|id| {
+                        fragment_map.get(&id).map(|&f| SplitFragment {
+                            fragment: f.clone(),
+                            max_row_count: frag_max_row_counts
+                                .get(&id)
+                                .copied()
+                                .unwrap_or(usize::MAX),
+                        })
+                    })
+                    .collect();
+                Split { fragments: frags }
+            })
+            .collect();
+
+        Ok(splits)
+    }
+
+    /// Packs IDs into bins where each bin's total count is less than `maximum_count`.
+    ///
+    /// Uses a first-fit decreasing algorithm: items are sorted by count in descending
+    /// order, then each item is placed in the first bin that has room for it.
+    fn bin_pack(items: &HashMap<u64, usize>, maximum_count: usize) -> Vec<Vec<u64>> {
+        // Convert to vec and sort by count descending for better packing
+        let mut items: Vec<(u64, usize)> = items.iter().map(|(&k, &v)| (k, v)).collect();
+        items.sort_by(|a, b| b.1.cmp(&a.1));
+
+        let mut bins: Vec<(Vec<u64>, usize)> = Vec::new(); // (ids, current_count)
+
+        for (id, count) in items {
+            // Items that exceed the maximum get their own bin
+            if count > maximum_count {
+                bins.push((vec![id], count));
+                continue;
+            }
+
+            // Find the first bin with enough remaining capacity. The check is a
+            // strict `<`, so a bin at capacity never accepts another item.
+            let mut placed = false;
+            for (bin_ids, bin_count) in &mut bins {
+                if *bin_count + count < maximum_count {
+                    bin_ids.push(id);
+                    *bin_count += count;
+                    placed = true;
+                    break;
+                }
+            }
+
+            // Create a new bin if no existing bin has room
+            if !placed {
+                bins.push((vec![id], count));
+            }
+        }
+
+        bins.into_iter().map(|(ids, _)| ids).collect()
+    }
+
     /// Empty Projection (useful for count queries)
     ///
     /// The row_address will be scanned (no I/O required) but not included in the output
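Bin packing drives the split layout. Below is a compact Python sketch of the same first-fit-decreasing procedure as the `bin_pack` helper above, using the strict less-than capacity check that the tests below rely on; the names are illustrative only.

```python
# First-fit decreasing bin packing, mirroring `bin_pack` above (sketch only).
def bin_pack(items: dict[int, int], maximum_count: int) -> list[list[int]]:
    # Place large items first: sort by count, descending
    ordered = sorted(items.items(), key=lambda kv: kv[1], reverse=True)
    bins: list[list] = []  # each entry is [ids, current_count]
    for item_id, count in ordered:
        if count > maximum_count:
            bins.append([[item_id], count])  # oversized items get their own bin
            continue
        for entry in bins:
            # Strict `<`: a bin at capacity never accepts another item
            if entry[1] + count < maximum_count:
                entry[0].append(item_id)
                entry[1] += count
                break
        else:
            bins.append([[item_id], count])  # no bin had room: open a new one
    return [ids for ids, _ in bins]

# Mirrors test_bin_pack_mixed_sizes below: with maximum 100, items
# 70/50/40/20/10 pack as [70, 20], [50, 40], [10]
print(bin_pack({1: 70, 2: 50, 3: 40, 4: 20, 5: 10}, 100))
```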
@@ -9083,4 +9323,171 @@ mod test {
             runtime.handle().metrics().num_alive_tasks()
         );
     }
+
+    #[test]
+    fn test_bin_pack_empty() {
+        let items: HashMap<u64, usize> = HashMap::new();
+        let bins = Scanner::bin_pack(&items, 100);
+        assert!(bins.is_empty());
+    }
+
+    #[test]
+    fn test_bin_pack_mixed_sizes() {
+        // maximum = 100
+        // Items sorted descending: 70, 50, 40, 20, 10
+        // Bin 1: 70, try 50 (70+50=120, no), try 40 (70+40=110, no),
+        //   try 20 (70+20=90 < 100, yes), try 10 (90+10=100, not < 100, no) -> [70, 20]
+        // Bin 2: 50, try 40 (50+40=90 < 100, yes), try 10 (90+10=100, not < 100, no) -> [50, 40]
+        // Bin 3: [10]
+        let items = HashMap::from([(1, 70), (2, 50), (3, 40), (4, 20), (5, 10)]);
+        let bins = Scanner::bin_pack(&items, 100);
+
+        // Verify total items across all bins
+        let total_items: usize = bins.iter().map(|b| b.len()).sum();
+        assert_eq!(total_items, 5);
+
+        // Each bin's total count should be <= maximum (or a single oversized item)
+        let item_map: HashMap<u64, usize> = [(1, 70), (2, 50), (3, 40), (4, 20), (5, 10)].into();
+        for bin in &bins {
+            let bin_total: usize = bin.iter().map(|id| item_map[id]).sum();
+            assert!(bin_total <= 100, "bin total {} exceeds maximum", bin_total);
+        }
+    }
+
+    #[test]
+    fn test_bin_pack_oversized_items_get_own_bin() {
+        let items = HashMap::from([(1, 200), (2, 150), (3, 30)]);
+        let bins = Scanner::bin_pack(&items, 100);
+
+        // Oversized items (200, 150) each get their own bin,
+        // and item 30 ends up alone in a new bin
+        assert_eq!(bins.len(), 3);
+
+        // Each bin should have exactly one item
+        for bin in &bins {
+            assert_eq!(bin.len(), 1);
+        }
+    }
+
+    #[tokio::test]
+    async fn test_plan_splits_basic() {
+        // Create a dataset with 4 fragments of 100 rows each, single i32 column (4 bytes)
+        let dataset = lance_datagen::gen_batch()
+            .col("i", array::step::<Int32Type>())
+            .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(100))
+            .await
+            .unwrap();
+
+        let splits = dataset.scan().plan_splits(None).await.unwrap();
+
+        // Default split size is 128MB, each fragment has 100 rows of 4 bytes = 400 bytes
+        // max_rows_per_split = 128*1024*1024 / 4 = 33554432
+        // All 400 rows fit in one split
+        assert_eq!(splits.len(), 1);
+        assert_eq!(splits[0].fragments.len(), 4);
+    }
+
+    #[tokio::test]
+    async fn test_plan_splits_with_small_split_size() {
+        // Create 4 fragments of 100 rows, single i32 column (4 bytes per row)
+        let dataset = lance_datagen::gen_batch()
+            .col("i", array::step::<Int32Type>())
+            .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(100))
+            .await
+            .unwrap();
+
+        // Set split size to 200 bytes -> max_rows = 200/4 = 50 rows per split
+        // Each fragment has 100 rows > 50, so each is oversized and gets its own bin
+        let options = SplitPackStrategy::MaxSizeBytes(200);
+        let splits = dataset.scan().plan_splits(Some(options)).await.unwrap();
+
+        assert_eq!(splits.len(), 4);
+        for split in &splits {
+            assert_eq!(split.fragments.len(), 1);
+        }
+    }
+
+    #[tokio::test]
+    async fn test_plan_splits_grouping() {
+        // Create 4 fragments of 50 rows, single i32 column (4 bytes per row)
+        let dataset = lance_datagen::gen_batch()
+            .col("i", array::step::<Int32Type>())
+            .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(50))
+            .await
+            .unwrap();
+
+        // Set split size to 400 bytes -> max_rows = 400/4 = 100 rows per split
+        // Each fragment has 50 rows, but 50+50=100 is NOT < 100, so each fragment
+        // still gets its own bin
+        let options = SplitPackStrategy::MaxSizeBytes(400);
+        let splits = dataset.scan().plan_splits(Some(options)).await.unwrap();
+        assert_eq!(splits.len(), 4);
+
+        // With 404 bytes: max_rows = 404/4 = 101 rows per split
+        // Each fragment has 50 rows, 50+50=100 < 101, so two fragments per split
+        let options = SplitPackStrategy::MaxSizeBytes(404);
+        let splits = dataset.scan().plan_splits(Some(options)).await.unwrap();
+        assert_eq!(splits.len(), 2);
+        for split in &splits {
+            assert_eq!(split.fragments.len(), 2);
+        }
+    }
+
+    #[tokio::test]
+    async fn test_plan_splits_with_projection() {
+        // Create dataset with two i32 columns (8 bytes per row)
+        let dataset = lance_datagen::gen_batch()
+            .col("a", array::step::<Int32Type>())
+            .col("b", array::step::<Int32Type>())
+            .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(100))
+            .await
+            .unwrap();
+
+        // Full projection: 8 bytes per row, max_rows = 800/8 = 100
+        // Each fragment has 100 rows and 100+100=200 is NOT < 100, so each gets its own bin
+        let options = SplitPackStrategy::MaxSizeBytes(800);
+        let splits = dataset.scan().plan_splits(Some(options)).await.unwrap();
+        assert_eq!(splits.len(), 4);
+
+        // Single column projection: 4 bytes per row, max_rows = 800/4 = 200
+        // Each fragment has 100 rows, 100+100=200 NOT < 200, so each gets its own bin
+        let mut scanner = dataset.scan();
+        scanner.project(&["a"]).unwrap();
+        let splits = scanner
+            .plan_splits(Some(SplitPackStrategy::MaxSizeBytes(800)))
+            .await
+            .unwrap();
+        assert_eq!(splits.len(), 4);
+
+        // Single column projection with a larger budget: max_rows = 804/4 = 201
+        // 100+100=200 < 201, so two fragments per split
+        let mut scanner = dataset.scan();
+        scanner.project(&["a"]).unwrap();
+        let splits = scanner
+            .plan_splits(Some(SplitPackStrategy::MaxSizeBytes(804)))
+            .await
+            .unwrap();
+        assert_eq!(splits.len(), 2);
+    }
+
+    #[tokio::test]
+    async fn test_plan_splits_with_fragments() {
+        // Create dataset with 4 fragments
+        let dataset = lance_datagen::gen_batch()
+            .col("i", array::step::<Int32Type>())
+            .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(100))
+            .await
+            .unwrap();
+
+        // Only scan 2 specific fragments
+        let frags: Vec<_> = dataset.fragments()[..2].to_vec();
+
+        let mut scanner = dataset.scan();
+        scanner.with_fragments(frags);
+
+        // Default split size (128MB) is large enough that everything fits in one split
+        let splits = scanner.plan_splits(None).await.unwrap();
+        assert_eq!(splits.len(), 1);
+        assert_eq!(splits[0].fragments.len(), 2);
+    }
+}
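For reference, the byte-budget arithmetic these tests step through can be reproduced in a few lines. This is a sketch of the estimate only, assuming fixed-width columns report their byte width and variable-width columns fall back to the 64-byte default above; the function name is illustrative, not a Lance API.

```python
DEFAULT_VARIABLE_FIELD_SIZE = 64  # fallback for variable-width columns

def max_rows_per_split(field_byte_widths, max_split_size_bytes=128 * 1024 * 1024):
    # One entry per projected column; None marks a variable-width column
    row_size = sum(
        w if w is not None else DEFAULT_VARIABLE_FIELD_SIZE
        for w in field_byte_widths
    )
    return max_split_size_bytes // max(row_size, 1)

# The same numbers as the tests above:
assert max_rows_per_split([4], 200) == 50       # one int32, 200-byte budget
assert max_rows_per_split([4], 404) == 101      # one int32, 404-byte budget
assert max_rows_per_split([4, 4], 800) == 100   # two int32 columns
assert max_rows_per_split([4]) == 33554432      # default 128MB budget
```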