Coalesce and parallelize partial shard reads #3004

Open · wants to merge 6 commits into main

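The change is driven by two new settings under `sharding.read` (see the `src/zarr/core/config.py` diff below) together with the existing `async.concurrency` limit. A minimal sketch of how a user might tune them, assuming an existing sharded array at a hypothetical path (the values shown are illustrative, not recommendations):

import zarr

# Chunks within a shard whose byte ranges are separated by less than
# coalesce_max_gap_bytes are fetched in one ranged request, each coalesced
# request is capped at coalesce_max_bytes, and groups are read concurrently
# up to async.concurrency.
zarr.config.set(
    {
        "sharding.read.coalesce_max_gap_bytes": 4 * 2**20,  # 4 MiB gap tolerance
        "sharding.read.coalesce_max_bytes": 64 * 2**20,  # 64 MiB per request
        "async.concurrency": 32,
    }
)

a = zarr.open_array("data/example.zarr")  # hypothetical sharded array
plane = a[0, :, :]  # a partial shard read now issues fewer, larger GETs

Setting `sharding.read.coalesce_max_gap_bytes` to -1 disables coalescing, which is what the new tests below do for their uncoalesced baseline.
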
133 changes: 113 additions & 20 deletions src/zarr/codecs/sharding.py
@@ -38,11 +38,13 @@
from zarr.core.common import (
ChunkCoords,
ChunkCoordsLike,
concurrent_map,
parse_enum,
parse_named_configuration,
parse_shapelike,
product,
)
from zarr.core.config import config
from zarr.core.indexing import (
BasicIndexer,
SelectorTuple,
@@ -327,6 +329,11 @@
return await shard_builder.finalize(index_location, index_encoder)


class _ChunkCoordsByteSlice(NamedTuple):
coords: ChunkCoords
byte_slice: slice


@dataclass(frozen=True)
class ShardingCodec(
ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
@@ -490,32 +497,21 @@
all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}

# reading bytes of all requested chunks
- shard_dict: ShardMapping = {}
+ shard_dict_maybe: ShardMapping | None = {}
if self._is_total_shard(all_chunk_coords, chunks_per_shard):
# read entire shard
shard_dict_maybe = await self._load_full_shard_maybe(
- byte_getter=byte_getter,
- prototype=chunk_spec.prototype,
- chunks_per_shard=chunks_per_shard,
+ byte_getter, chunk_spec.prototype, chunks_per_shard
)
- if shard_dict_maybe is None:
- return None
- shard_dict = shard_dict_maybe
else:
# read some chunks within the shard
- shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
- if shard_index is None:
- return None
- shard_dict = {}
- for chunk_coords in all_chunk_coords:
- chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
- if chunk_byte_slice:
- chunk_bytes = await byte_getter.get(
- prototype=chunk_spec.prototype,
- byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]),
- )
- if chunk_bytes:
- shard_dict[chunk_coords] = chunk_bytes
+ shard_dict_maybe = await self._load_partial_shard_maybe(
+ byte_getter, chunk_spec.prototype, chunks_per_shard, all_chunk_coords
+ )
+
+ if shard_dict_maybe is None:
+ return None
+ shard_dict = shard_dict_maybe

# decoding chunks and writing them into the output buffer
await self.codec_pipeline.read(
@@ -537,6 +533,103 @@
else:
return out

async def _load_partial_shard_maybe(
self,
byte_getter: ByteGetter,
prototype: BufferPrototype,
chunks_per_shard: ChunkCoords,
all_chunk_coords: set[ChunkCoords],
) -> ShardMapping | None:
shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
if shard_index is None:
return None

chunks = [
_ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice))
for chunk_coords in all_chunk_coords
if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords))
]
if len(chunks) == 0:
return {}

groups = self._coalesce_chunks(chunks)

shard_dicts = await concurrent_map(
[(group, byte_getter, prototype) for group in groups],
self._get_group_bytes,
config.get("async.concurrency"),
)

shard_dict: ShardMutableMapping = {}
for d in shard_dicts:
shard_dict.update(d)

return shard_dict

def _coalesce_chunks(
self,
chunks: list[_ChunkCoordsByteSlice],
) -> list[list[_ChunkCoordsByteSlice]]:
"""
Combine chunks from a single shard into groups that should be read together
in a single request.

Respects the following configuration options:
- `sharding.read.coalesce_max_gap_bytes`: The maximum gap between
chunks to coalesce into a single group.
- `sharding.read.coalesce_max_bytes`: The maximum number of bytes in a group.
"""
max_gap_bytes = config.get("sharding.read.coalesce_max_gap_bytes")
coalesce_max_bytes = config.get("sharding.read.coalesce_max_bytes")

sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start)

groups = []
current_group = [sorted_chunks[0]]

for chunk in sorted_chunks[1:]:
gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop
current_group_size = (
current_group[-1].byte_slice.stop - current_group[0].byte_slice.start
)
if gap_to_chunk < max_gap_bytes and current_group_size < coalesce_max_bytes:
current_group.append(chunk)
else:
groups.append(current_group)
current_group = [chunk]

groups.append(current_group)

return groups

async def _get_group_bytes(
self,
group: list[_ChunkCoordsByteSlice],
byte_getter: ByteGetter,
prototype: BufferPrototype,
) -> ShardMapping:
group_start = group[0].byte_slice.start
group_end = group[-1].byte_slice.stop

# A single call to retrieve the bytes for the entire group.
group_bytes = await byte_getter.get(
prototype=prototype,
byte_range=RangeByteRequest(group_start, group_end),
)
if group_bytes is None:
return {}

# Extract the bytes corresponding to each chunk in group from group_bytes.
shard_dict = {}
for chunk in group:
chunk_slice = slice(
chunk.byte_slice.start - group_start,
chunk.byte_slice.stop - group_start,
)
shard_dict[chunk.coords] = group_bytes[chunk_slice]

return shard_dict

async def _encode_single(
self,
shard_array: NDBuffer,
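As a reading aid for `_coalesce_chunks` and `_get_group_bytes` above: chunks are sorted by their byte offset within the shard, a chunk joins the current group while the gap to the previous chunk is under `coalesce_max_gap_bytes` and the group size is under `coalesce_max_bytes`, and each group then becomes a single ranged request whose bytes are sliced back into per-chunk buffers. A self-contained sketch of that grouping rule (the `Chunk` helper and the example offsets are illustrative, not part of this PR):

from dataclasses import dataclass


@dataclass
class Chunk:
    coords: tuple[int, ...]
    start: int  # byte offset of the chunk's first byte within the shard
    stop: int  # byte offset one past the chunk's last byte


def coalesce(chunks: list[Chunk], max_gap: int, max_group_bytes: int) -> list[list[Chunk]]:
    ordered = sorted(chunks, key=lambda c: c.start)
    groups = [[ordered[0]]]
    for chunk in ordered[1:]:
        gap = chunk.start - groups[-1][-1].stop
        size = groups[-1][-1].stop - groups[-1][0].start
        if gap < max_gap and size < max_group_bytes:
            groups[-1].append(chunk)  # close enough: extend the current coalesced read
        else:
            groups.append([chunk])  # too far apart or too large: start a new request
    return groups


# Two chunks 100 bytes apart and a third roughly 10 MiB away, under the default limits:
chunks = [
    Chunk((0, 0), 0, 1000),
    Chunk((0, 1), 1100, 2000),
    Chunk((1, 0), 10 * 2**20, 10 * 2**20 + 1000),
]
print([[c.coords for c in g] for g in coalesce(chunks, 2**20, 100 * 2**20)])
# [[(0, 0), (0, 1)], [(1, 0)]] -> the first two chunks share one ranged request
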
6 changes: 6 additions & 0 deletions src/zarr/core/config.py
@@ -108,6 +108,12 @@ def enable_gpu(self) -> ConfigSet:
},
"async": {"concurrency": 10, "timeout": None},
"threading": {"max_workers": None},
"sharding": {
"read": {
"coalesce_max_bytes": 100 * 2**20, # 100MiB
"coalesce_max_gap_bytes": 2**20, # 1MiB
}
},
"json_indent": 2,
"codec_pipeline": {
"path": "zarr.core.codec_pipeline.BatchedCodecPipeline",
143 changes: 142 additions & 1 deletion tests/test_codecs/test_sharding.py
@@ -1,5 +1,6 @@
import pickle
from typing import Any
from unittest.mock import AsyncMock

import numpy as np
import numpy.typing as npt
@@ -9,7 +10,7 @@
import zarr.api
import zarr.api.asynchronous
from zarr import Array
- from zarr.abc.store import Store
+ from zarr.abc.store import RangeByteRequest, Store, SuffixByteRequest
from zarr.codecs import (
BloscCodec,
ShardingCodec,
@@ -197,6 +198,146 @@ def test_sharding_partial_read(
assert np.all(read_data == 1)


@pytest.mark.skip("This is profiling rather than a test")
@pytest.mark.slow_hypothesis
@pytest.mark.parametrize("store", ["local"], indirect=["store"])
def test_partial_shard_read_performance(store: Store) -> None:
import asyncio
import json
from functools import partial
from itertools import product
from timeit import timeit
from unittest.mock import AsyncMock

# The whole test array is a single shard to keep runtime manageable while
# using a realistic shard size (256 MiB uncompressed, ~115 MiB compressed).
# In practice, the array is likely to be much larger with many shards of this
# rough order of magnitude. There are 512 chunks per shard in this example.
array_shape = (512, 512, 512)
shard_shape = (512, 512, 512) # 256 MiB uncompressed uint16s
chunk_shape = (64, 64, 64) # 512 KiB uncompressed uint16s
dtype = np.uint16

a = zarr.create_array(
StorePath(store),
shape=array_shape,
chunks=chunk_shape,
shards=shard_shape,
compressors=BloscCodec(cname="zstd"),
dtype=dtype,
fill_value=np.iinfo(dtype).max,
)
# Narrow range of values lets zstd compress to about 1/2 of uncompressed size
a[:] = np.random.default_rng(123).integers(low=0, high=50, size=array_shape, dtype=dtype)

num_calls = 20
experiments = []
for concurrency, get_latency, coalesce_max_gap, statement in product(
[1, 10, 100],
[0.0, 0.01],
[-1, 2**20, 10 * 2**20],
["a[0, :, :]", "a[:, 0, :]", "a[:, :, 0]"],
):
zarr.config.set(
{
"async.concurrency": concurrency,
"sharding.read.coalesce_max_gap_bytes": coalesce_max_gap,
}
)

async def get_with_latency(*args: Any, get_latency: float, **kwargs: Any) -> Any:
await asyncio.sleep(get_latency)
return await store.get(*args, **kwargs)

store_mock = AsyncMock(wraps=store, spec=store.__class__)
store_mock.get.side_effect = partial(get_with_latency, get_latency=get_latency)

a = zarr.open_array(StorePath(store_mock))

store_mock.reset_mock()

# Each timeit call accesses a 512x512 slice covering 64 chunks
time = timeit(statement, number=num_calls, globals={"a": a}) / num_calls
experiments.append(
{
"concurrency": concurrency,
"coalesce_max_gap": coalesce_max_gap,
"get_latency": get_latency,
"statement": statement,
"time": time,
"store_get_calls": store_mock.get.call_count,
}
)

with open("zarr-python-partial-shard-read-performance-with-coalesce.json", "w") as f:
json.dump(experiments, f)


@pytest.mark.parametrize("index_location", ["start", "end"])
@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
@pytest.mark.parametrize("coalesce_reads", [True, False])
def test_sharding_multiple_chunks_partial_shard_read(
store: Store, index_location: ShardingCodecIndexLocation, coalesce_reads: bool
) -> None:
array_shape = (16, 64)
shard_shape = (8, 32)
chunk_shape = (2, 4)
data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape)

if coalesce_reads:
# 1MiB, enough to coalesce all chunks within a shard in this example
zarr.config.set({"sharding.read.coalesce_max_gap_bytes": 2**20})
else:
zarr.config.set({"sharding.read.coalesce_max_gap_bytes": -1}) # disable coalescing

store_mock = AsyncMock(wraps=store, spec=store.__class__)
a = zarr.create_array(
StorePath(store_mock),
shape=data.shape,
chunks=chunk_shape,
shards={"shape": shard_shape, "index_location": index_location},
compressors=BloscCodec(cname="lz4"),
dtype=data.dtype,
fill_value=1,
)
a[:] = data

store_mock.reset_mock() # ignore store calls during array creation

# Reads 3 (2 full, 1 partial) chunks each from 2 shards (a subset of both shards)
# for a total of 6 chunks accessed
assert np.allclose(a[0, 22:42], np.arange(22, 42, dtype="float32"))

if coalesce_reads:
# 2 shard index requests + 2 coalesced chunk data byte ranges (one for each shard)
assert store_mock.get.call_count == 4
else:
# 2 shard index requests + 6 chunks
assert store_mock.get.call_count == 8

for method, args, kwargs in store_mock.method_calls:
assert method == "get"
assert args[0].startswith("c/") # get from a chunk
assert isinstance(kwargs["byte_range"], (SuffixByteRequest, RangeByteRequest))

store_mock.reset_mock()

# Reads 4 chunks from both shards along dimension 0 for a total of 8 chunks accessed
assert np.allclose(a[:, 0], np.arange(0, data.size, array_shape[1], dtype="float32"))

if coalesce_reads:
# 2 shard index requests + 2 coalesced chunk data byte ranges (one for each shard)
assert store_mock.get.call_count == 4
else:
# 2 shard index requests + 8 chunks
assert store_mock.get.call_count == 10

for method, args, kwargs in store_mock.method_calls:
assert method == "get"
assert args[0].startswith("c/") # get from a chunk
assert isinstance(kwargs["byte_range"], (SuffixByteRequest, RangeByteRequest))


@pytest.mark.parametrize(
"array_fixture",
[
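The skipped profiling test above dumps its measurements to `zarr-python-partial-shard-read-performance-with-coalesce.json`. One way to summarize that file, as a sketch (the field names match the dict written by the test; everything else is illustrative):

import json
from collections import defaultdict

with open("zarr-python-partial-shard-read-performance-with-coalesce.json") as f:
    experiments = json.load(f)

# Average the per-read time and store.get call count over the three slice
# orientations for each (coalesce_max_gap, concurrency, get_latency) setting.
grouped: dict[tuple, list[dict]] = defaultdict(list)
for e in experiments:
    grouped[(e["coalesce_max_gap"], e["concurrency"], e["get_latency"])].append(e)

for (max_gap, concurrency, latency), runs in sorted(grouped.items()):
    mean_time = sum(r["time"] for r in runs) / len(runs)
    mean_gets = sum(r["store_get_calls"] for r in runs) / len(runs)
    print(
        f"gap={max_gap:>9} concurrency={concurrency:>3} latency={latency:.2f}s "
        f"-> {mean_time:.3f} s/read, {mean_gets:.0f} store.get calls"
    )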