From 408288812d32df20cc0c0f8601aadcb3e35f3e1f Mon Sep 17 00:00:00 2001 From: balbasty Date: Thu, 10 Apr 2025 13:04:16 +0100 Subject: [PATCH 1/6] Sharding: do not merge new/old if old is empty + avoid sequential buffer concatenation --- src/zarr/codecs/sharding.py | 19 ++++++++-------- src/zarr/core/buffer/cpu.py | 42 ++++++++++++++++++++++++++++++++++++ src/zarr/core/buffer/gpu.py | 43 +++++++++++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 9 deletions(-) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 42b1313fac..b4407e740e 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -251,7 +251,7 @@ def create_empty( if buffer_prototype is None: buffer_prototype = default_buffer_prototype() obj = cls() - obj.buf = buffer_prototype.buffer.create_zero_length() + obj.buf = buffer_prototype.buffer.Delayed.create_zero_length() obj.index = _ShardIndex.create_empty(chunks_per_shard) return obj @@ -585,15 +585,16 @@ async def _encode_partial_single( chunks_per_shard = self._get_chunks_per_shard(shard_spec) chunk_spec = self._get_chunk_spec(shard_spec) - shard_dict = _MergingShardBuilder( - await self._load_full_shard_maybe( - byte_getter=byte_setter, - prototype=chunk_spec.prototype, - chunks_per_shard=chunks_per_shard, - ) - or _ShardReader.create_empty(chunks_per_shard), - _ShardBuilder.create_empty(chunks_per_shard), + shard_read = await self._load_full_shard_maybe( + byte_getter=byte_setter, + prototype=chunk_spec.prototype, + chunks_per_shard=chunks_per_shard, ) + shard_build = _ShardBuilder.create_empty(chunks_per_shard) + if shard_read: + shard_dict = _MergingShardBuilder(shard_read, shard_build) + else: + shard_dict = shard_build indexer = list( get_indexer( diff --git a/src/zarr/core/buffer/cpu.py b/src/zarr/core/buffer/cpu.py index 225adb6f5c..4ac437b39b 100644 --- a/src/zarr/core/buffer/cpu.py +++ b/src/zarr/core/buffer/cpu.py @@ -185,6 +185,48 @@ def __setitem__(self, key: Any, value: Any) -> None: self._data.__setitem__(key, value) +class DelayedBuffer(Buffer): + """ + A Buffer that is the virtual concatenation of other buffers. + """ + + def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None: + if array is None: + self._data_list = [] + elif isinstance(array, list): + self._data_list = list(array) + else: + self._data_list = [array] + for array in self._data_list: + if array.ndim != 1: + raise ValueError("array: only 1-dim allowed") + if array.dtype != np.dtype("b"): + raise ValueError("array: only byte dtype allowed") + + @property + def _data(self) -> npt.NDArray[Any]: + return np.concatenate(self._data_list) + + @classmethod + def from_buffer(cls, buffer: core.Buffer) -> Self: + if isinstance(buffer, cls): + return cls(buffer._data_list) + else: + return cls(buffer._data) + + def __add__(self, other: core.Buffer) -> Self: + if isinstance(other, self.__class__): + return self.__class__(self._data_list + other._data_list) + else: + return self.__class__(self._data_list + [other._data]) + + def __len__(self) -> int: + return sum(map(len, self._data_list)) + + +Buffer.Delayed = DelayedBuffer + + def as_numpy_array_wrapper( func: Callable[[npt.NDArray[Any]], bytes], buf: core.Buffer, prototype: core.BufferPrototype ) -> core.Buffer: diff --git a/src/zarr/core/buffer/gpu.py b/src/zarr/core/buffer/gpu.py index aac6792cff..9d4c5d5d48 100644 --- a/src/zarr/core/buffer/gpu.py +++ b/src/zarr/core/buffer/gpu.py @@ -218,6 +218,49 @@ def __setitem__(self, key: Any, value: Any) -> None: self._data.__setitem__(key, value) +class DelayedBuffer(Buffer): + """ + A Buffer that is the virtual concatenation of other buffers. + """ + + def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None: + if array is None: + self._data_list = [] + elif isinstance(array, list): + self._data_list = list(array) + else: + self._data_list = [array] + for array in self._data_list: + if array.ndim != 1: + raise ValueError("array: only 1-dim allowed") + if array.dtype != np.dtype("b"): + raise ValueError("array: only byte dtype allowed") + self._data_list = list(map(cp.asarray, self._data_list)) + + @property + def _data(self) -> npt.NDArray[Any]: + return cp.concatenate(self._data_list) + + @classmethod + def from_buffer(cls, buffer: core.Buffer) -> Self: + if isinstance(buffer, cls): + return cls(buffer._data_list) + else: + return cls(buffer._data) + + def __add__(self, other: core.Buffer) -> Self: + if isinstance(other, self.__class__): + return self.__class__(self._data_list + other._data_list) + else: + return self.__class__(self._data_list + [other._data]) + + def __len__(self) -> int: + return sum(map(len, self._data_list)) + + +Buffer.Delayed = DelayedBuffer + + buffer_prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer) register_buffer(Buffer) From 09bc21ded5f92fb27bc912a95b49c4b863ee7298 Mon Sep 17 00:00:00 2001 From: balbasty Date: Thu, 10 Apr 2025 13:04:44 +0100 Subject: [PATCH 2/6] Preliminary benchmark for sharded writes --- bench/write_shard.py | 66 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 bench/write_shard.py diff --git a/bench/write_shard.py b/bench/write_shard.py new file mode 100644 index 0000000000..99e64a0deb --- /dev/null +++ b/bench/write_shard.py @@ -0,0 +1,66 @@ +import itertools +import os.path +import shutil +import sys +import tempfile +import timeit + +import line_profiler +import numpy as np + +import zarr +import zarr.codecs +import zarr.codecs.sharding + +if __name__ == "__main__": + sys.path.insert(0, "..") + + # setup + with tempfile.TemporaryDirectory() as path: + + ndim = 3 + opt = { + 'shape': [1024]*ndim, + 'chunks': [128]*ndim, + 'shards': [512]*ndim, + 'dtype': np.float64, + } + + store = zarr.storage.LocalStore(path) + z = zarr.create_array(store, **opt) + print(z) + + def cleanup() -> None: + for elem in os.listdir(path): + elem = os.path.join(path, elem) + if not elem.endswith(".json"): + if os.path.isdir(elem): + shutil.rmtree(elem) + else: + os.remove(elem) + + def write() -> None: + wchunk = [512]*ndim + nwchunks = [n//s for n, s in zip(opt['shape'], wchunk, strict=True)] + for shard in itertools.product(*(range(n) for n in nwchunks)): + slicer = tuple( + slice(i*n, (i+1)*n) + for i, n in zip(shard, wchunk, strict=True) + ) + d = np.random.rand(*wchunk).astype(opt['dtype']) + z[slicer] = d + + print("*" * 79) + + # time + vars = {"write": write, "cleanup": cleanup, "z": z, "opt": opt} + t = timeit.repeat("write()", "cleanup()", repeat=2, number=1, globals=vars) + print(t) + print(min(t)) + print(z) + + # profile + # f = zarr.codecs.sharding.ShardingCodec._encode_partial_single + # profile = line_profiler.LineProfiler(f) + # profile.run("write()") + # profile.print_stats() From 1a9158a97f0ea5e7ce2e95b0b3f74dd34bf8d410 Mon Sep 17 00:00:00 2001 From: balbasty Date: Thu, 10 Apr 2025 15:13:23 +0100 Subject: [PATCH 3/6] Refactor DelayedBuffer + implement efficient __getitem__ --- src/zarr/core/buffer/core.py | 83 ++++++++++++++++++++++++++++++++++++ src/zarr/core/buffer/cpu.py | 37 +++------------- src/zarr/core/buffer/gpu.py | 36 ++-------------- 3 files changed, 92 insertions(+), 64 deletions(-) diff --git a/src/zarr/core/buffer/core.py b/src/zarr/core/buffer/core.py index ccab103e0f..148e8ccc1b 100644 --- a/src/zarr/core/buffer/core.py +++ b/src/zarr/core/buffer/core.py @@ -502,6 +502,89 @@ class BufferPrototype(NamedTuple): nd_buffer: type[NDBuffer] +class DelayedBuffer(Buffer): + """ + A Buffer that is the virtual concatenation of other buffers. + """ + _BufferImpl: type + _concatenate: callable + + def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None: + if array is None: + self._data_list = [] + elif isinstance(array, list): + self._data_list = list(array) + else: + self._data_list = [array] + for array in self._data_list: + if array.ndim != 1: + raise ValueError("array: only 1-dim allowed") + if array.dtype != np.dtype("b"): + raise ValueError("array: only byte dtype allowed") + + @property + def _data(self) -> npt.NDArray[Any]: + return type(self)._concatenate(self._data_list) + + @classmethod + def from_buffer(cls, buffer: Buffer) -> Self: + if isinstance(buffer, cls): + return cls(buffer._data_list) + else: + return cls(buffer._data) + + def __add__(self, other: Buffer) -> Self: + if isinstance(other, self.__class__): + return self.__class__(self._data_list + other._data_list) + else: + return self.__class__(self._data_list + [other._data]) + + def __radd__(self, other: Buffer) -> Self: + if isinstance(other, self.__class__): + return self.__class__(other._data_list + self._data_list) + else: + return self.__class__([other._data] + self._data_list) + + def __len__(self) -> int: + return sum(map(len, self._data_list)) + + def __getitem__(self, key: slice) -> Self: + check_item_key_is_1d_contiguous(key) + start, stop = key.start, key.stop + if start is None: + start = 0 + if stop is None: + stop = len(self) + new_list = [] + offset = 0 + found_last = False + for chunk in self._data_list: + chunk_size = len(chunk) + skip = False + if offset <= start < offset + chunk_size: + # first chunk + if stop <= offset + chunk_size: + # also last chunk + chunk = chunk[start-offset:stop-offset] + found_last = True + else: + chunk = chunk[start-offset:] + elif offset <= stop <= offset + chunk_size: + # last chunk + chunk = chunk[:stop-offset] + found_last = True + elif offset + chunk_size <= start: + skip = True + + if not skip: + new_list.append(chunk) + if found_last: + break + offset += chunk_size + assert sum(map(len, new_list)) == stop - start + return self.__class__(new_list) + + # The default buffer prototype used throughout the Zarr codebase. def default_buffer_prototype() -> BufferPrototype: from zarr.registry import ( diff --git a/src/zarr/core/buffer/cpu.py b/src/zarr/core/buffer/cpu.py index 4ac437b39b..32a4a08642 100644 --- a/src/zarr/core/buffer/cpu.py +++ b/src/zarr/core/buffer/cpu.py @@ -185,43 +185,16 @@ def __setitem__(self, key: Any, value: Any) -> None: self._data.__setitem__(key, value) -class DelayedBuffer(Buffer): +class DelayedBuffer(core.DelayedBuffer, Buffer): """ A Buffer that is the virtual concatenation of other buffers. """ + _BufferImpl = Buffer + _concatenate = np.concatenate def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None: - if array is None: - self._data_list = [] - elif isinstance(array, list): - self._data_list = list(array) - else: - self._data_list = [array] - for array in self._data_list: - if array.ndim != 1: - raise ValueError("array: only 1-dim allowed") - if array.dtype != np.dtype("b"): - raise ValueError("array: only byte dtype allowed") - - @property - def _data(self) -> npt.NDArray[Any]: - return np.concatenate(self._data_list) - - @classmethod - def from_buffer(cls, buffer: core.Buffer) -> Self: - if isinstance(buffer, cls): - return cls(buffer._data_list) - else: - return cls(buffer._data) - - def __add__(self, other: core.Buffer) -> Self: - if isinstance(other, self.__class__): - return self.__class__(self._data_list + other._data_list) - else: - return self.__class__(self._data_list + [other._data]) - - def __len__(self) -> int: - return sum(map(len, self._data_list)) + core.DelayedBuffer.__init__(self, array) + self._data_list = list(map(np.asanyarray, self._data_list)) Buffer.Delayed = DelayedBuffer diff --git a/src/zarr/core/buffer/gpu.py b/src/zarr/core/buffer/gpu.py index 9d4c5d5d48..473f8ffdd4 100644 --- a/src/zarr/core/buffer/gpu.py +++ b/src/zarr/core/buffer/gpu.py @@ -218,45 +218,17 @@ def __setitem__(self, key: Any, value: Any) -> None: self._data.__setitem__(key, value) -class DelayedBuffer(Buffer): +class DelayedBuffer(core.DelayedBuffer, Buffer): """ A Buffer that is the virtual concatenation of other buffers. """ + _BufferImpl = Buffer + _concatenate = getattr(cp, 'concatenate', None) def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None: - if array is None: - self._data_list = [] - elif isinstance(array, list): - self._data_list = list(array) - else: - self._data_list = [array] - for array in self._data_list: - if array.ndim != 1: - raise ValueError("array: only 1-dim allowed") - if array.dtype != np.dtype("b"): - raise ValueError("array: only byte dtype allowed") + core.DelayedBuffer.__init__(self, array) self._data_list = list(map(cp.asarray, self._data_list)) - @property - def _data(self) -> npt.NDArray[Any]: - return cp.concatenate(self._data_list) - - @classmethod - def from_buffer(cls, buffer: core.Buffer) -> Self: - if isinstance(buffer, cls): - return cls(buffer._data_list) - else: - return cls(buffer._data) - - def __add__(self, other: core.Buffer) -> Self: - if isinstance(other, self.__class__): - return self.__class__(self._data_list + other._data_list) - else: - return self.__class__(self._data_list + [other._data]) - - def __len__(self) -> int: - return sum(map(len, self._data_list)) - Buffer.Delayed = DelayedBuffer From 23fbc5500233b90d896d9b8a1adbb53701f328ca Mon Sep 17 00:00:00 2001 From: balbasty Date: Thu, 10 Apr 2025 15:14:04 +0100 Subject: [PATCH 4/6] Revert to original MergingShardBuilder (if/else not needed with faster buffer) --- src/zarr/codecs/sharding.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index b4407e740e..1507cac74a 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -202,7 +202,7 @@ def create_empty( buffer_prototype = default_buffer_prototype() index = _ShardIndex.create_empty(chunks_per_shard) obj = cls() - obj.buf = buffer_prototype.buffer.create_zero_length() + obj.buf = buffer_prototype.buffer.Delayed.create_zero_length() obj.index = index return obj @@ -585,16 +585,15 @@ async def _encode_partial_single( chunks_per_shard = self._get_chunks_per_shard(shard_spec) chunk_spec = self._get_chunk_spec(shard_spec) - shard_read = await self._load_full_shard_maybe( - byte_getter=byte_setter, - prototype=chunk_spec.prototype, - chunks_per_shard=chunks_per_shard, + shard_dict = _MergingShardBuilder( + await self._load_full_shard_maybe( + byte_getter=byte_setter, + prototype=chunk_spec.prototype, + chunks_per_shard=chunks_per_shard, + ) + or _ShardReader.create_empty(chunks_per_shard), + _ShardBuilder.create_empty(chunks_per_shard), ) - shard_build = _ShardBuilder.create_empty(chunks_per_shard) - if shard_read: - shard_dict = _MergingShardBuilder(shard_read, shard_build) - else: - shard_dict = shard_build indexer = list( get_indexer( From 050477daada343cfc35c3ba0435041297f8536fe Mon Sep 17 00:00:00 2001 From: balbasty Date: Thu, 10 Apr 2025 15:25:35 +0100 Subject: [PATCH 5/6] Implement __setitem__ in DelayedBuffer --- src/zarr/core/buffer/core.py | 41 ++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/zarr/core/buffer/core.py b/src/zarr/core/buffer/core.py index 148e8ccc1b..e3f5bd9864 100644 --- a/src/zarr/core/buffer/core.py +++ b/src/zarr/core/buffer/core.py @@ -584,6 +584,47 @@ def __getitem__(self, key: slice) -> Self: assert sum(map(len, new_list)) == stop - start return self.__class__(new_list) + def __setitem__(self, key: slice, value: Any) -> None: + # This assumes that `value` is a broadcasted array + check_item_key_is_1d_contiguous(key) + start, stop = key.start, key.stop + if start is None: + start = 0 + if stop is None: + stop = len(self) + new_list = [] + offset = 0 + found_last = False + value = memoryview(np.asanyarray(value)) + for chunk in self._data_list: + chunk_size = len(chunk) + skip = False + if offset <= start < offset + chunk_size: + # first chunk + if stop <= offset + chunk_size: + # also last chunk + chunk = chunk[start-offset:stop-offset] + found_last = True + else: + chunk = chunk[start-offset:] + elif offset <= stop <= offset + chunk_size: + # last chunk + chunk = chunk[:stop-offset] + found_last = True + elif offset + chunk_size <= start: + skip = True + + if not skip: + chunk[:] = value[:len(chunk)] + value = value[len(chunk):] + if len(value) == 0: + # nothing left to write + break + if found_last: + break + offset += chunk_size + return self.__class__(new_list) + # The default buffer prototype used throughout the Zarr codebase. def default_buffer_prototype() -> BufferPrototype: From 9b660733044726fe14f27ef471103c28e8c91245 Mon Sep 17 00:00:00 2001 From: balbasty Date: Fri, 11 Apr 2025 12:06:07 +0100 Subject: [PATCH 6/6] Pass tests (except one) --- src/zarr/core/buffer/core.py | 45 +++++++++++++++++++++++++++--------- src/zarr/core/buffer/cpu.py | 18 +++++++++++++++ src/zarr/core/buffer/gpu.py | 18 +++++++++++++++ 3 files changed, 70 insertions(+), 11 deletions(-) diff --git a/src/zarr/core/buffer/core.py b/src/zarr/core/buffer/core.py index e3f5bd9864..9ab4cb6bf3 100644 --- a/src/zarr/core/buffer/core.py +++ b/src/zarr/core/buffer/core.py @@ -551,30 +551,44 @@ def __len__(self) -> int: def __getitem__(self, key: slice) -> Self: check_item_key_is_1d_contiguous(key) start, stop = key.start, key.stop + this_len = len(self) if start is None: start = 0 + if start < 0: + start = this_len + start if stop is None: - stop = len(self) + stop = this_len + if stop < 0: + stop = this_len + stop + if stop > this_len: + stop = this_len + if stop <= start: + return Buffer.from_buffer(b'') + new_list = [] offset = 0 found_last = False for chunk in self._data_list: chunk_size = len(chunk) skip = False - if offset <= start < offset + chunk_size: + if 0 <= start - offset < chunk_size: # first chunk - if stop <= offset + chunk_size: + if stop - offset <= chunk_size: # also last chunk chunk = chunk[start-offset:stop-offset] found_last = True else: chunk = chunk[start-offset:] - elif offset <= stop <= offset + chunk_size: + elif 0 <= stop - offset <= chunk_size: # last chunk chunk = chunk[:stop-offset] found_last = True - elif offset + chunk_size <= start: + elif chunk_size <= start - offset: + # before first chunk skip = True + else: + # middle chunk + pass if not skip: new_list.append(chunk) @@ -590,29 +604,39 @@ def __setitem__(self, key: slice, value: Any) -> None: start, stop = key.start, key.stop if start is None: start = 0 + if start < 0: + start = len(self) + start if stop is None: stop = len(self) - new_list = [] + if stop < 0: + stop = len(self) + stop + if stop <= start: + return + offset = 0 found_last = False value = memoryview(np.asanyarray(value)) for chunk in self._data_list: chunk_size = len(chunk) skip = False - if offset <= start < offset + chunk_size: + if 0 <= start - offset < chunk_size: # first chunk - if stop <= offset + chunk_size: + if stop - offset <= chunk_size: # also last chunk chunk = chunk[start-offset:stop-offset] found_last = True else: chunk = chunk[start-offset:] - elif offset <= stop <= offset + chunk_size: + elif 0 <= stop - offset <= chunk_size: # last chunk chunk = chunk[:stop-offset] found_last = True - elif offset + chunk_size <= start: + elif chunk_size <= start - offset: + # before first chunk skip = True + else: + # middle chunk + pass if not skip: chunk[:] = value[:len(chunk)] @@ -623,7 +647,6 @@ def __setitem__(self, key: slice, value: Any) -> None: if found_last: break offset += chunk_size - return self.__class__(new_list) # The default buffer prototype used throughout the Zarr codebase. diff --git a/src/zarr/core/buffer/cpu.py b/src/zarr/core/buffer/cpu.py index 32a4a08642..651c8c1796 100644 --- a/src/zarr/core/buffer/cpu.py +++ b/src/zarr/core/buffer/cpu.py @@ -196,6 +196,24 @@ def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None: core.DelayedBuffer.__init__(self, array) self._data_list = list(map(np.asanyarray, self._data_list)) + @classmethod + def create_zero_length(cls) -> Self: + return cls(np.array([], dtype="b")) + + @classmethod + def from_buffer(cls, buffer: core.Buffer) -> Self: + if isinstance(buffer, cls): + return cls(buffer._data_list) + else: + return cls(buffer._data) + + @classmethod + def from_bytes(cls, bytes_like: BytesLike) -> Self: + return cls(np.asarray(bytes_like, dtype="b")) + + def as_numpy_array(self) -> npt.NDArray[Any]: + return np.asanyarray(self._data) + Buffer.Delayed = DelayedBuffer diff --git a/src/zarr/core/buffer/gpu.py b/src/zarr/core/buffer/gpu.py index 473f8ffdd4..333f6440a5 100644 --- a/src/zarr/core/buffer/gpu.py +++ b/src/zarr/core/buffer/gpu.py @@ -229,6 +229,24 @@ def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None: core.DelayedBuffer.__init__(self, array) self._data_list = list(map(cp.asarray, self._data_list)) + @classmethod + def create_zero_length(cls) -> Self: + return cls(np.array([], dtype="b")) + + @classmethod + def from_buffer(cls, buffer: core.Buffer) -> Self: + if isinstance(buffer, cls): + return cls(buffer._data_list) + else: + return cls(buffer._data) + + @classmethod + def from_bytes(cls, bytes_like: BytesLike) -> Self: + return cls(np.asarray(bytes_like, dtype="b")) + + def as_numpy_array(self) -> npt.NDArray[Any]: + return np.asanyarray(self._data) + Buffer.Delayed = DelayedBuffer