@@ -1,8 +1,7 @@
 from __future__ import annotations

 import io
-
-
+from collections import OrderedDict
 from typing import Protocol, runtime_checkable

 from obspec import (
@@ -266,8 +265,201 @@ def tell(self) -> int: |
         return self._buffer.tell()


+class ParallelStoreReader:
+    """
+    A file-like reader that uses parallel range requests for efficient chunk fetching.
+
+    This reader divides the file into fixed-size chunks and uses `get_ranges()` to
+    fetch multiple chunks in parallel. An LRU cache stores recently accessed chunks
+    to avoid redundant fetches.
+
+    This is particularly efficient for workloads that access multiple non-contiguous
+    regions of a file, such as reading Zarr/HDF5 datasets.
+
+    Works with any ReadableStore protocol implementation.
+    """
+
+    def __init__(
+        self,
+        store: ReadableStore,
+        path: str,
+        chunk_size: int = 256 * 1024,
+        max_cached_chunks: int = 64,
+    ) -> None:
+        """
+        Create a parallel reader with chunk-based caching.
+
+        Parameters
+        ----------
+        store
+            Any object implementing the [ReadableStore][obspec_utils.obspec.ReadableStore] protocol.
+        path
+            The path to the file within the store.
+        chunk_size
+            Size of each chunk in bytes. Smaller chunks mean more granular caching
+            but potentially more requests.
+        max_cached_chunks
+            Maximum number of chunks to keep in the LRU cache.
+        """
+        self._store = store
+        self._path = path
+        self._chunk_size = chunk_size
+        self._max_cached_chunks = max_cached_chunks
+        self._position = 0
+        self._size: int | None = None
+        # LRU cache: OrderedDict with chunk_index -> bytes
+        self._cache: OrderedDict[int, bytes] = OrderedDict()
+
+    def _get_size(self) -> int:
+        """Lazily fetch the file size via a get() call."""
+        if self._size is None:
+            result = self._store.get(self._path)
+            self._size = result.meta["size"]
+        return self._size
+
+    def _get_chunks(self, chunk_indices: list[int]) -> dict[int, bytes]:
+        """Fetch multiple chunks in parallel using get_ranges()."""
+        # Filter out already cached chunks
+        needed = [i for i in chunk_indices if i not in self._cache]
+
+        if needed:
+            file_size = self._get_size()
+            starts = []
+            lengths = []
+
+            for chunk_idx in needed:
+                start = chunk_idx * self._chunk_size
+                # Handle last chunk which may be smaller
+                end = min(start + self._chunk_size, file_size)
+                starts.append(start)
+                lengths.append(end - start)
+
+            # Fetch all missing chunks in parallel
+            results = self._store.get_ranges(self._path, starts=starts, lengths=lengths)
+
+            # Store the newly fetched chunks in the cache
+            for chunk_idx, data in zip(needed, results):
+                self._cache[chunk_idx] = bytes(data)
+
+        # Collect the requested chunks, marking each one as most recently used
+        chunks = {i: self._cache[i] for i in chunk_indices}
+        for i in chunk_indices:
+            self._cache.move_to_end(i)
+
+        # Evict least recently used chunks only after collecting the result, so a
+        # read spanning more chunks than the cache holds still returns them all
+        while len(self._cache) > self._max_cached_chunks:
+            self._cache.popitem(last=False)
+
+        return chunks
+
+    def read(self, size: int = -1, /) -> bytes:
+        """
+        Read up to `size` bytes from the file.
+
+        Parameters
+        ----------
+        size
+            Number of bytes to read. If -1, read the entire file.
+
+        Returns
+        -------
+        bytes
+            The data read from the file.
+        """
+        if size == -1:
+            return self.readall()
+
+        file_size = self._get_size()
+
+        # Clamp to remaining bytes
+        remaining = file_size - self._position
+        if size > remaining:
+            size = remaining
+        if size <= 0:
+            return b""
+
+        # Determine which chunks we need
+        start_chunk = self._position // self._chunk_size
+        end_pos = self._position + size
+        end_chunk = (end_pos - 1) // self._chunk_size
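+        # end_pos is exclusive, so the last byte to read is end_pos - 1; e.g. with
+        # chunk_size=4, position=6 and size=5, end_pos is 11 and chunks 1..2 are needed.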
+
+        chunk_indices = list(range(start_chunk, end_chunk + 1))
+        chunks = self._get_chunks(chunk_indices)
+
+        # Assemble the result
+        result = io.BytesIO()
+        for chunk_idx in chunk_indices:
+            chunk_data = chunks[chunk_idx]
+            chunk_start = chunk_idx * self._chunk_size
+
+            # Calculate slice within this chunk
+            local_start = max(0, self._position - chunk_start)
+            local_end = min(len(chunk_data), end_pos - chunk_start)
+
+            result.write(chunk_data[local_start:local_end])
+
+        data = result.getvalue()
+        self._position += len(data)
+        return data
+
+    def readall(self) -> bytes:
+        """
+        Read the entire file.
+
+        Returns
+        -------
+        bytes
+            The complete file contents.
+        """
+        result = self._store.get(self._path)
+        data = bytes(result.buffer())
+        self._size = len(data)
+        self._position = len(data)
+        return data
+
+    def seek(self, offset: int, whence: int = 0, /) -> int:
+        """
+        Move the file position.
+
+        Parameters
+        ----------
+        offset
+            Position offset.
+        whence
+            Reference point: 0=start (SEEK_SET), 1=current (SEEK_CUR), 2=end (SEEK_END).
+
+        Returns
+        -------
+        int
+            The new absolute position.
+        """
+        if whence == 0:  # SEEK_SET
+            self._position = offset
+        elif whence == 1:  # SEEK_CUR
+            self._position += offset
+        elif whence == 2:  # SEEK_END
+            self._position = self._get_size() + offset
+        else:
+            raise ValueError(f"Invalid whence value: {whence}")
+
+        if self._position < 0:
+            self._position = 0
+
+        return self._position
+
+    def tell(self) -> int:
+        """
+        Return the current file position.
+
+        Returns
+        -------
+        int
+            Current position in bytes from start of file.
+        """
+        return self._position
+
+
 __all__: list[str] = [
     "ReadableStore",
     "BufferedStoreReader",
     "EagerStoreReader",
+    "ParallelStoreReader",
 ]
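
Usage could look roughly like the sketch below (an editor's illustration, not part of the commit above). `_FakeGetResult`, `_FakeStore`, `payload`, and `data.bin` are hypothetical stand-ins: the fake store implements only the `get()` and `get_ranges()` calls that `ParallelStoreReader` actually makes, and any real ReadableStore implementation would be used the same way.

# Hypothetical in-memory stand-ins for the pieces of ReadableStore that the reader uses.
class _FakeGetResult:
    def __init__(self, data: bytes) -> None:
        self._data = data
        self.meta = {"size": len(data)}

    def buffer(self) -> bytes:
        return self._data


class _FakeStore:
    def __init__(self, files: dict[str, bytes]) -> None:
        self._files = files

    def get(self, path: str) -> _FakeGetResult:
        return _FakeGetResult(self._files[path])

    def get_ranges(self, path: str, *, starts, lengths):
        data = self._files[path]
        return [data[s : s + n] for s, n in zip(starts, lengths)]


payload = bytes(range(256)) * 1024  # 256 KiB of sample data
store = _FakeStore({"data.bin": payload})
reader = ParallelStoreReader(store, "data.bin", chunk_size=4096, max_cached_chunks=8)

reader.seek(10_000)
assert reader.read(100) == payload[10_000:10_100]  # served from one cached 4 KiB chunk
assert reader.tell() == 10_100

Note that whether the ranges are actually fetched concurrently is up to the store's `get_ranges()` implementation; the reader just batches all missing chunks of a read into a single call.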