def get_fragments(
    self, filter: Optional[Union[str, pa.compute.Expression]] = None
) -> List[LanceFragment]:
    """Get all fragments from the dataset.

    Parameters
    ----------
    filter : str or pa.compute.Expression, optional
        Filter expression to apply to fragments. Only fragments containing
        rows that match the filter will be returned. Can be either a string
        SQL expression or a PyArrow compute Expression.

    Returns
    -------
    List[LanceFragment]
        List of fragments that contain matching data. If no filter is
        provided, returns all fragments (backward compatible behavior).

    Examples
    --------
    >>> import lance
    >>> import pyarrow as pa
    >>>
    >>> # Create a dataset
    >>> data = pa.table({
    ...     "age": [20, 25, 30, 35],
    ...     "name": ["Alice", "Bob", "Charlie", "David"]},
    ...     schema=pa.schema([
    ...         pa.field("age", pa.int32()),
    ...         pa.field("name", pa.string())
    ...     ]))
    >>> dataset = lance.write_dataset(data, "my_dataset")
    >>>
    >>> # Get all fragments (backward compatible)
    >>> all_fragments = dataset.get_fragments()
    >>>
    >>> # Get fragments with string filter
    >>> filtered_fragments = dataset.get_fragments(filter="age > 25")
    >>>
    >>> # Get fragments with PyArrow Expression filter
    >>> expr = pa.compute.greater(pa.compute.field("age"), pa.scalar(25))
    >>> filtered_fragments = dataset.get_fragments(filter=expr)

    Notes
    -----
    This method uses the fragment's count_rows method to determine if a
    fragment contains any rows matching the filter. Fragments with zero
    matching rows are excluded from the result.

    The filtering is performed at the fragment level, which means:

    - Only fragments containing at least one matching row are returned
    - The actual filtering of rows within fragments is not performed here
    - Use ``fragment.scanner(filter=filter)`` to get filtered data from
      individual fragments
    """
    # Wrap every native fragment handle; this is the original (unfiltered)
    # behavior and the fast path when no filter is given.
    all_fragments = [
        LanceFragment(self, fragment_id=None, fragment=f)
        for f in self._ds.get_fragments()
    ]

    if filter is None:
        return all_fragments

    # Keep only fragments with at least one matching row. count_rows scans
    # each fragment with the filter, so cost grows with dataset size.
    # NOTE(review): any error from count_rows (e.g. a malformed filter
    # expression) is deliberately re-raised as ValueError, with the
    # original exception chained as the cause for debuggability.
    try:
        return [
            fragment
            for fragment in all_fragments
            if fragment.count_rows(filter) > 0
        ]
    except Exception as e:
        raise ValueError("Error counting rows in fragments while filtering.") from e
def test_get_fragments_filter_fragment_reduction(tmp_path: Path):
    """A selective filter must return strictly fewer fragments than the total."""
    # 100 rows spread over three departments; small files force many fragments,
    # and only some of them will contain young Engineering rows.
    table = pa.table(
        {
            "age": list(range(1, 101)),  # Ages 1-100
            "department": ["Engineering"] * 30 + ["Sales"] * 30 + ["Market"] * 40,
            "salary": [50000 + i * 500 for i in range(100)],
        }
    )
    dataset = lance.write_dataset(table, tmp_path, max_rows_per_file=15)

    all_fragments = dataset.get_fragments()
    assert len(all_fragments) > 3, (
        "Should have multiple fragments for effective testing"
    )

    # Matches only rows that live in the first few fragments.
    selective_filter = "department = 'Engineering' AND age <= 20"
    filtered_fragments = dataset.get_fragments(filter=selective_filter)

    # CORE ASSERTION: the filter must exclude at least one fragment.
    assert len(filtered_fragments) < len(all_fragments), (
        f"Filter should reduce fragments: "
        f"{len(filtered_fragments)} < {len(all_fragments)}"
    )

    # The surviving fragments must account for exactly the matching rows.
    expected_rows = dataset.count_rows(selective_filter)
    total_matching_rows = sum(
        frag.count_rows(selective_filter) for frag in filtered_fragments
    )
    assert total_matching_rows == expected_rows
    assert expected_rows > 0, "Filter should match some data"


def test_get_fragments_filter_functionality(tmp_path: Path):
    """String and Expression filters agree; filter=None keeps old behavior."""
    table = pa.table(
        {
            "age": list(range(20, 60, 5)),  # [20, 25, 30, 35, 40, 45, 50, 55]
            "department": ["Engineering", "Sales"] * 4,
        }
    )
    dataset = lance.write_dataset(table, tmp_path, max_rows_per_file=3)

    all_fragments = dataset.get_fragments()

    # Backward compatibility: filter=None returns every fragment.
    assert len(dataset.get_fragments(filter=None)) == len(all_fragments)

    # String filter.
    string_filtered = dataset.get_fragments(filter="age > 35")
    assert len(string_filtered) <= len(all_fragments)
    string_count = sum(frag.count_rows("age > 35") for frag in string_filtered)

    # PyArrow expression filter expressing the same condition.
    expr = pc.greater(pc.field("age"), pa.scalar(35))
    expr_filtered = dataset.get_fragments(filter=expr)
    assert len(expr_filtered) <= len(all_fragments)
    expr_count = sum(frag.count_rows(expr) for frag in expr_filtered)

    # Both filter forms must select the same rows, and at least one row.
    assert string_count == expr_count
    assert string_count > 0