79 changes: 70 additions & 9 deletions python/python/lance/dataset.py
@@ -70,8 +70,6 @@
from .util import _target_partition_size_to_num_partitions, td_to_micros

if TYPE_CHECKING:
from pyarrow._compute import Expression

from .commit import CommitLock
from .progress import FragmentWriteProgress
from .types import ReaderLike
@@ -1000,18 +998,82 @@ def replace_field_metadata(self, field_name: str, new_metadata: Dict[str, str]):
"""
self._ds.replace_field_metadata(field_name, new_metadata)

def get_fragments(self, filter: Optional[Expression] = None) -> List[LanceFragment]:
def get_fragments(
self, filter: Optional[Union[str, pa.compute.Expression]] = None
) -> List[LanceFragment]:
"""Get all fragments from the dataset.

Note: filter is not supported yet.
"""
if filter is not None:
raise ValueError("get_fragments() does not support filter yet")
return [
Parameters
----------
filter : str or pa.compute.Expression, optional
Filter expression to apply to fragments. Only fragments containing
rows that match the filter will be returned. Can be either a string
SQL expression or a PyArrow compute Expression.

Returns
-------
List[LanceFragment]
List of fragments that contain matching data. If no filter is provided,
returns all fragments (backward compatible behavior).

Examples
--------
>>> import lance
>>> import pyarrow as pa
>>>
>>> # Create a dataset
>>> data = pa.table({
... "age": [20, 25, 30, 35],
... "name": ["Alice", "Bob", "Charlie", "David"]},
... schema=pa.schema([
... pa.field("age", pa.int32()),
... pa.field("name", pa.string())
... ]))
>>> dataset = lance.write_dataset(data, "my_dataset")
>>>
>>> # Get all fragments (backward compatible)
>>> all_fragments = dataset.get_fragments()
>>>
>>> # Get fragments with string filter
>>> filtered_fragments = dataset.get_fragments(filter="age > 25")
>>>
>>> # Get fragments with PyArrow Expression filter
>>> expr = pa.compute.greater(pa.compute.field("age"), pa.scalar(25))
>>> filtered_fragments = dataset.get_fragments(filter=expr)

Notes
-----
This method uses the fragment's count_rows method to determine if a
fragment contains any rows matching the filter. Fragments with zero
matching rows are excluded from the result.

The filtering is performed at the fragment level, which means:
- Only fragments containing at least one matching row are returned
- The actual filtering of rows within fragments is not performed here
- Use fragment.scanner(filter=filter) to get filtered data from
  individual fragments
"""
# Get all fragments first (same as original implementation)
Contributor:
I don't quite understand. Does this mean you are not really improving anything from a perf perspective? It would also be quite expensive to check every fragment against the filter. We have added the zone map now; I think we should select fragments (or even better, zones) to distribute based on the zone map index.

Contributor Author:
@jackye1995 Actually, this is a tradeoff. When the number of fragments is relatively large, it causes some problems. For example, when the number of fragments in a dataset exceeds 20,000, and because the dataset currently caches the manifest by default, there is significant overhead each time physical plan tasks are distributed.

Therefore the idea here is to filter down to the relevant fragments first; meanwhile, the filter can still be left unset.
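A minimal sketch of the distributed usage being described above, assuming a generic driver/worker split; the task helpers (plan_tasks, run_task) are illustrative and not part of Lance:

```python
# Hypothetical sketch of the distributed pattern: the coordinator keeps only
# fragments that may contain matches and ships lightweight
# (uri, filter, fragment_id) task specs; each worker re-opens the dataset and
# scans just its own fragment.
import lance


def plan_tasks(uri: str, filt: str) -> list[tuple[str, str, int]]:
    """Coordinator side: prune fragments so no empty tasks are scheduled."""
    ds = lance.dataset(uri)
    fragments = ds.get_fragments(filter=filt)  # filter support added in this PR
    return [(uri, filt, f.fragment_id) for f in fragments]


def run_task(uri: str, filt: str, fragment_id: int):
    """Worker side: exact filtering happens here, pushed into the fragment scan."""
    ds = lance.dataset(uri)
    fragment = ds.get_fragment(fragment_id)
    return fragment.scanner(filter=filt).to_table()
```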

Contributor:
oh I am all for adding this filter. To be more specific, I am saying:

filtered_fragments = [
    fragment
    for fragment in all_fragments
    if fragment.count_rows(filter) > 0
]

seems expensive since you are running a count_rows per fragment.

I think we can expose a coarse-grained plan_fragments API that plans the fragments to distribute based on a filter, and we can apply that filter once against a zone map or bloom filter index.
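For illustration only, one way to "apply the filter once" without any secondary index is a single row-id scan over the whole dataset. This is a hypothetical helper, not part of this PR, and it assumes the default address-based _rowid encoding (fragment id in the upper 32 bits), which does not hold with move-stable row ids:

```python
# Hypothetical single-pass pruning (not in this PR): run ONE filtered scan
# that returns only row ids, then derive which fragments contain matches.
# Assumes _rowid == fragment_id << 32 | offset (the default encoding); this
# does NOT hold when move-stable row ids are enabled. If an empty projection
# is rejected, project one small column instead of columns=[].
import lance


def prune_fragments_one_pass(ds: lance.LanceDataset, filt: str):
    row_ids = ds.to_table(columns=[], filter=filt, with_row_id=True)["_rowid"]
    matching = {rid >> 32 for rid in row_ids.to_pylist()}
    return [f for f in ds.get_fragments() if f.fragment_id in matching]
```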

Contributor Author (@Jay-ju), Sep 23, 2025:
@jackye1995 Do you mean that after the dataset creates a zone map index, dataset.scanner can use the index but fragment.scanner cannot, so we need to make the fragment use the index as well? Is my understanding correct?

Contributor Author (@Jay-ju), Sep 23, 2025:
fragment.scanner should also be able to use ScalarIndexQuery, but in my test the time taken by the fragment's count_rows is indeed much worse than that of the dataset's count_rows. I guess it is due to multi-threaded operations.
[image: timing comparison screenshot]
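A minimal way to reproduce that comparison (a sketch, not the exact benchmark behind the screenshot; the dataset path and filter are placeholders):

```python
# Compare one dataset-level filtered count against summing the same filter
# per fragment. With many fragments the per-fragment loop is typically much
# slower, since every call plans and runs its own scan.
import time

import lance

ds = lance.dataset("my_dataset")  # placeholder path
filt = "age > 25"

t0 = time.perf_counter()
total = ds.count_rows(filt)
t1 = time.perf_counter()
per_fragment = sum(f.count_rows(filt) for f in ds.get_fragments())
t2 = time.perf_counter()

assert total == per_fragment
print(f"dataset.count_rows:         {t1 - t0:.3f}s")
print(f"sum of fragment.count_rows: {t2 - t1:.3f}s")
```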

Contributor (@jackye1995), Sep 24, 2025:
Maybe we need to align on the goal here first. In my understanding, we want to use get_fragments with a filter because of use cases like https://github.com/lancedb/lance-ray/blame/main/lance_ray/io.py#L265.

So what happens is that you (1) get the list of fragments to distribute, and then (2) process each fragment in a worker.

For step 1 you don't necessarily need an exact match; you just want to prune down the list of fragments that might match the filter. In many cases this already discards a large portion of the fragments. The actual filtering can then happen in the worker.

That is why I think using inexact indexes like the zone map index and bloom filter index is ideal here: we can very quickly answer the question of "which fragments might match the filter" and then distribute the work. In each worker you can push down the filter and do the exact scan filtering there. If there are exact indexes like btree or bitmap we can still use them; they may be slower but are still okay. If no index exists on the filtered columns, we can just return the full list of fragments, which means all fragments might match the filter.

Compared to that, what the current code does is very heavyweight. You are doing an exact count, which actually triggers a scan, for each fragment. You are also doing that for each fragment separately, which is even worse. I don't think that really achieves the goal of "reduce empty task overhead in distributed engines (Ray/Daft/Spark)", because you are doing much more compute in the coordinator.

I guess from an implementation perspective, get_fragments should do what you have implemented. But for the purpose of reducing empty task overhead, we might want to create another method like prune_fragments.

What do you think?
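A compact sketch of that two-step split, using only APIs that appear in this PR (get_fragments(filter=...) for the coarse prune, fragment.scanner(filter=...) for the exact filtering); the path and filter are placeholders, and step 2 is what would run on workers:

```python
# Step 1 (coordinator): coarse prune. Over-selecting fragments is harmless;
# dropping a fragment that has matches would not be. Step 2 (worker): exact
# filtering pushed down into each fragment scan. The final assertion states
# the contract: the two-step result matches a plain dataset-level filter.
import pyarrow as pa

import lance

filt = "department = 'Engineering' AND age <= 20"
ds = lance.dataset("my_dataset")  # placeholder path

candidates = ds.get_fragments(filter=filt)                       # step 1
parts = [f.scanner(filter=filt).to_table() for f in candidates]  # step 2

result = pa.concat_tables(parts) if parts else ds.schema.empty_table()
assert result.num_rows == ds.count_rows(filt)
```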

Contributor Author:
Hmm, I can agree with the goal. My scenario is indeed that the driver uses a coarse-grained index and the worker executes the actual filtering. However, after looking at the current implementation of this part of the index, there are some minor issues that may need to be clarified first.

This means that if we want to pick a specific index to use for filtering, we would need to add an interface similar to:

index_selection = {
    "column_hints": {
        "col1": ["btree"],
        "col2": ["bloom_filter", "btree"]
    },
}

scanner = dataset.scanner(index_selection=index_selection)
  • Is it necessary to support multiple indexes for one column here?

  • Suppose we have obtained the index for this column and, after applying the filter to the index, we get a RowIdTreeMap. I wonder whether we then need to add an approx_count_rows() interface to count the number of row ids in the RowIdTreeMap. Additionally, could this interface also be exposed on the fragment?
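Purely to illustrate that proposal (neither approx_count_rows() nor index_selection exists in Lance today), the pruning loop might then look something like:

```python
# Hypothetical sketch of the proposed interface; approx_count_rows() does
# NOT exist in Lance today. The idea is an inexact, index-only count per
# fragment (derived from the RowIdTreeMap of an index query), so pruning
# never has to read data pages.
from typing import List

import lance
from lance.fragment import LanceFragment


def prune_with_approx_counts(ds: lance.LanceDataset, filt: str) -> List[LanceFragment]:
    kept = []
    for fragment in ds.get_fragments():
        # Proposed (non-existent) API: upper-bound match count for `filt`,
        # answered from a zone map / bloom filter / btree index alone.
        if fragment.approx_count_rows(filt) > 0:
            kept.append(fragment)
    return kept
```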

Contributor:
Yeah, today in general there is an implicit assumption of 1 index per column. With these inexact indexes this is no longer true, since we might want to have a zone map or bloom filter for many columns, and then have exact indexes to use at the worker. I created #4805 listing a few pieces of work we need to do.

For here, I think we can assume 1 index per column for now so we are not blocked by it.

Contributor (@fangbo), Jan 9, 2026:
> (quoting @jackye1995's comment above about pruning with inexact indexes and adding a separate prune_fragments method)

@jackye1995 I submitted PR #5625 to prune fragments using the filter and index. Could you please take a look and let me know if this seems reasonable? Thank you.

cc @majin1102

all_fragments = [
LanceFragment(self, fragment_id=None, fragment=f)
for f in self._ds.get_fragments()
]

if filter is None:
return all_fragments

# Filter fragments that contain matching rows
try:
filtered_fragments = [
fragment
for fragment in all_fragments
if fragment.count_rows(filter) > 0
]
except Exception as e:
raise ValueError("Error counting rows in fragments while filtering.") from e

return filtered_fragments

def get_fragment(self, fragment_id: int) -> Optional[LanceFragment]:
"""Get the fragment with fragment id."""
raw_fragment = self._ds.get_fragment(fragment_id)
@@ -1299,7 +1361,6 @@ def alter_columns(self, *alterations: Iterable[AlterColumn]):
----------
alterations : Iterable[Dict[str, Any]]
A sequence of dictionaries, each with the following keys:

- "path": str
The column path to alter. For a top-level column, this is the name.
For a nested column, this is the dot-separated path, e.g. "a.b.c".
67 changes: 67 additions & 0 deletions python/python/tests/test_dataset.py
@@ -4116,3 +4116,70 @@ def test_diff_meta(tmp_path: Path):
# Now try to diff with the cleaned up version 1 (should raise error)
with pytest.raises(ValueError):
dataset.diff_meta(1)


def test_get_fragments_filter_fragment_reduction(tmp_path: Path):
# Create a dataset with multiple fragments
# where some fragments won't match the filter
data = {
"age": list(range(1, 101)), # Ages 1-100
"department": ["Engineering"] * 30 + ["Sales"] * 30 + ["Market"] * 40,
"salary": [50000 + i * 500 for i in range(100)],
}
table = pa.table(data)
# Force multiple fragments with small max_rows_per_file
dataset = lance.write_dataset(table, tmp_path, max_rows_per_file=15)

all_fragments = dataset.get_fragments()
assert len(all_fragments) > 3, (
"Should have multiple fragments for effective testing"
)

# Apply a selective filter that should exclude some fragments
# This filter should only match data in some fragments, not all
selective_filter = "department = 'Engineering' AND age <= 20"
filtered_fragments = dataset.get_fragments(filter=selective_filter)

# CORE ASSERTION: Filter should reduce fragment count
assert len(filtered_fragments) < len(all_fragments), (
f"Filter should reduce fragments: "
f"{len(filtered_fragments)} < {len(all_fragments)}"
)

# Verify the filtered fragments contain only relevant data
total_matching_rows = sum(
fragment.count_rows(selective_filter) for fragment in filtered_fragments
)
expected_rows = dataset.count_rows(selective_filter)
assert total_matching_rows == expected_rows
assert expected_rows > 0, "Filter should match some data"


def test_get_fragments_filter_functionality(tmp_path: Path):
data = {
"age": list(range(20, 60, 5)), # [20, 25, 30, 35, 40, 45, 50, 55]
"department": ["Engineering", "Sales"] * 4,
}
table = pa.table(data)
dataset = lance.write_dataset(table, tmp_path, max_rows_per_file=3)

all_fragments = dataset.get_fragments()

# Test backward compatibility: filter=None should return all fragments
all_fragments_none = dataset.get_fragments(filter=None)
assert len(all_fragments_none) == len(all_fragments)

# Test string filter
string_filtered = dataset.get_fragments(filter="age > 35")
assert len(string_filtered) <= len(all_fragments)
string_count = sum(fragment.count_rows("age > 35") for fragment in string_filtered)

# Test PyArrow expression filter (same condition)
expr = pc.greater(pc.field("age"), pa.scalar(35))
expr_filtered = dataset.get_fragments(filter=expr)
assert len(expr_filtered) <= len(all_fragments)
expr_count = sum(fragment.count_rows(expr) for fragment in expr_filtered)

# Both should return the same number of matching rows
assert string_count == expr_count
assert string_count > 0 # Should match some data