diff --git a/pyproject.toml b/pyproject.toml index be99e07..791ebd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ + "cachetools>=7.0.0", "obspec", "obstore", ] diff --git a/src/obspec_utils/kyle/__init__.py b/src/obspec_utils/kyle/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/obspec_utils/kyle/_block_cache.py b/src/obspec_utils/kyle/_block_cache.py new file mode 100644 index 0000000..738d471 --- /dev/null +++ b/src/obspec_utils/kyle/_block_cache.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Protocol + +from cachetools import LRUCache +from obspec import GetRange, GetRangeAsync, GetRanges, GetRangesAsync + +if TYPE_CHECKING: + from collections.abc import Buffer, Sequence + + +class GetRangeAndGetRanges(GetRange, GetRanges, Protocol): + """Protocol for backends supporting both GetRange and GetRanges.""" + + pass + + +class GetRangeAsyncAndGetRangesAsync(GetRangeAsync, GetRangesAsync, Protocol): + """Protocol for backends supporting both GetRangeAsync and GetRangesAsync.""" + + pass + + +@dataclass +class MemoryCache: + """Block-aligned LRU memory cache for remote data.""" + + block_size: int = 4 * 1024 * 1024 # 4 MiB + max_blocks: int = 128 # 512 MiB default + + # (path, block_index) -> block_data (may be smaller than block_size at EOF) + _blocks: LRUCache[tuple[str, int], bytes] = field(init=False) + + def __post_init__(self) -> None: + self._blocks = LRUCache(maxsize=self.max_blocks) + + def _block_index(self, offset: int) -> int: + """Which block contains this byte offset.""" + return offset // self.block_size + + def _block_start(self, block_idx: int) -> int: + """Starting byte offset of a block.""" + return block_idx * self.block_size + + def get(self, path: str, start: int, end: int) -> bytes | list[tuple[int, int]]: + """Get data 
from cache, or return missing ranges to fetch. + + Returns: + bytes if fully cached, or list of (start, end) ranges that need fetching. + Missing ranges are block-aligned; adjacent missing blocks are coalesced. + """ + start_block = self._block_index(start) + end_block = self._block_index(end - 1) # -1 because end is exclusive + + # First pass: identify which blocks are missing + missing_blocks: list[int] = [] + hit_eof = False + + for block_idx in range(start_block, end_block + 1): + key = (path, block_idx) + if key not in self._blocks: + if not hit_eof: + missing_blocks.append(block_idx) + else: + # Check if this cached block is partial (EOF marker) + if len(self._blocks[key]) < self.block_size: + hit_eof = True + + if missing_blocks: + return self._coalesce_missing_blocks(missing_blocks) + + # All blocks cached - assemble result + result = bytearray(end - start) + result_offset = 0 + + for block_idx in range(start_block, end_block + 1): + block_data = self._blocks[(path, block_idx)] + block_start = self._block_start(block_idx) + + # Calculate slice within this block + slice_start = max(0, start - block_start) + slice_end = min(len(block_data), end - block_start) + chunk = block_data[slice_start:slice_end] + + result[result_offset : result_offset + len(chunk)] = chunk + result_offset += len(chunk) + + # If this block is smaller than block_size, we hit EOF + if len(block_data) < self.block_size: + break + + # Truncate if we hit EOF before filling the buffer + return bytes(result[:result_offset]) + + def _coalesce_missing_blocks( + self, missing_blocks: list[int] + ) -> list[tuple[int, int]]: + """Coalesce consecutive missing blocks into ranges. + + Adjacent missing blocks are always coalesced. Non-adjacent missing blocks + (with cached blocks in between) are kept as separate ranges to avoid + re-fetching cached data. 
+ """ + if not missing_blocks: + return [] + + ranges: list[tuple[int, int]] = [] + range_start = missing_blocks[0] + range_end = missing_blocks[0] + + for block_idx in missing_blocks[1:]: + # Only coalesce if blocks are adjacent (gap of 1 means consecutive) + if block_idx - range_end == 1: + range_end = block_idx + else: + # There's a gap (cached block in between), start new range + ranges.append( + ( + self._block_start(range_start), + self._block_start(range_end + 1), + ) + ) + range_start = block_idx + range_end = block_idx + + # Don't forget the last range + ranges.append( + ( + self._block_start(range_start), + self._block_start(range_end + 1), + ) + ) + + return ranges + + def store(self, path: str, fetch_start: int, data: Buffer) -> None: + """Store fetched data as blocks. fetch_start must be block-aligned. + + The last block may be smaller than block_size if we hit EOF. + """ + assert fetch_start % self.block_size == 0, "fetch_start must be block-aligned" + + data_bytes = bytes(data) + offset = 0 + block_idx = fetch_start // self.block_size + + while offset < len(data_bytes): + block_data = data_bytes[offset : offset + self.block_size] + self._blocks[(path, block_idx)] = block_data + offset += self.block_size + block_idx += 1 + + +@dataclass +class SyncBlockCache: + """Synchronous block cache wrapping a GetRange backend.""" + + backend: GetRangeAndGetRanges + cache: MemoryCache = field(default_factory=MemoryCache) + + def get_range( + self, + path: str, + *, + start: int, + end: int | None = None, + length: int | None = None, + ) -> bytes: + if end is None: + if length is None: + raise ValueError("Either end or length must be provided") + end = start + length + + result = self.cache.get(path, start, end) + if isinstance(result, list): + # result is list of missing ranges - fetch them + self._fetch_missing(path, result) + # Now should be cached + result = self.cache.get(path, start, end) + assert isinstance(result, bytes) + + return result + + def 
get_ranges( + self, + path: str, + *, + starts: Sequence[int], + ends: Sequence[int] | None = None, + lengths: Sequence[int] | None = None, + ) -> Sequence[bytes]: + """Return the bytes stored at the specified location in the given byte ranges.""" + if ends is None: + if lengths is None: + raise ValueError("Either ends or lengths must be provided") + ends = [s + length for s, length in zip(starts, lengths)] + + # Collect all missing ranges across all requests + all_missing: list[tuple[int, int]] = [] + for start, end in zip(starts, ends): + result = self.cache.get(path, start, end) + if isinstance(result, list): + all_missing.extend(result) + + # Fetch all missing ranges in one batch + if all_missing: + self._fetch_missing(path, all_missing) + + # Now all should be cached - collect results + results: list[bytes] = [] + for start, end in zip(starts, ends): + result = self.cache.get(path, start, end) + assert isinstance(result, bytes) + results.append(result) + + return results + + def _fetch_missing(self, path: str, ranges: list[tuple[int, int]]) -> None: + """Fetch missing ranges from backend and store in cache.""" + if len(ranges) == 1: + start, end = ranges[0] + data = self.backend.get_range(path, start=start, end=end) + self.cache.store(path, start, data) + else: + starts = [r[0] for r in ranges] + ends = [r[1] for r in ranges] + buffers: Sequence[Buffer] = self.backend.get_ranges( + path, starts=starts, ends=ends + ) + for (range_start, _), data in zip(ranges, buffers): + self.cache.store(path, range_start, data) + + +@dataclass +class AsyncBlockCache(GetRangeAsync, GetRangesAsync): + """Async block cache wrapping a GetRangeAsync backend.""" + + backend: GetRangeAsyncAndGetRangesAsync + cache: MemoryCache = field(default_factory=MemoryCache) + + async def get_range_async( + self, + path: str, + *, + start: int, + end: int | None = None, + length: int | None = None, + ) -> bytes: + if end is None: + if length is None: + raise ValueError("Either end or length 
must be provided") + end = start + length + + result = self.cache.get(path, start, end) + if isinstance(result, list): + # result is list of missing ranges - fetch them + await self._fetch_missing(path, result) + # Now should be cached + result = self.cache.get(path, start, end) + assert isinstance(result, bytes) + + return result + + async def get_ranges_async( + self, + path: str, + *, + starts: Sequence[int], + ends: Sequence[int] | None = None, + lengths: Sequence[int] | None = None, + ) -> Sequence[bytes]: + """Return the bytes stored at the specified location in the given byte ranges.""" + if ends is None: + if lengths is None: + raise ValueError("Either ends or lengths must be provided") + ends = [s + length for s, length in zip(starts, lengths)] + + # Collect all missing ranges across all requests + all_missing: list[tuple[int, int]] = [] + for start, end in zip(starts, ends): + result = self.cache.get(path, start, end) + if isinstance(result, list): + all_missing.extend(result) + + # Fetch all missing ranges in one batch + if all_missing: + await self._fetch_missing(path, all_missing) + + # Now all should be cached - collect results + results: list[bytes] = [] + for start, end in zip(starts, ends): + result = self.cache.get(path, start, end) + assert isinstance(result, bytes) + results.append(result) + + return results + + async def _fetch_missing(self, path: str, ranges: list[tuple[int, int]]) -> None: + """Fetch missing ranges from backend and store in cache.""" + if len(ranges) == 1: + start, end = ranges[0] + data = await self.backend.get_range_async(path, start=start, end=end) + self.cache.store(path, start, data) + else: + starts = [r[0] for r in ranges] + ends = [r[1] for r in ranges] + buffers: Sequence[Buffer] = await self.backend.get_ranges_async( + path, starts=starts, ends=ends + ) + for (range_start, _), data in zip(ranges, buffers): + self.cache.store(path, range_start, data) diff --git a/tests/conftest.py b/tests/conftest.py index 
8b74e2c..8c34444 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,6 @@ from pathlib import Path import pytest -import xarray as xr def pytest_addoption(parser): @@ -105,6 +104,8 @@ def minio_bucket(container): @pytest.fixture def local_netcdf4_file(tmp_path: Path) -> str: """Create a NetCDF4 file with data in multiple groups.""" + import xarray as xr + filepath = tmp_path / "test.nc" ds1 = xr.DataArray([1, 2, 3], name="foo").to_dataset() ds1.to_netcdf(filepath) diff --git a/tests/kyle/__init__.py b/tests/kyle/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/kyle/test_block_cache.py b/tests/kyle/test_block_cache.py new file mode 100644 index 0000000..92b2126 --- /dev/null +++ b/tests/kyle/test_block_cache.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +import obstore as obs + +from obspec_utils.kyle._block_cache import MemoryCache, SyncBlockCache + + +class TestMemoryCache: + def test_block_index(self) -> None: + cache = MemoryCache(block_size=1024) + assert cache._block_index(0) == 0 + assert cache._block_index(1023) == 0 + assert cache._block_index(1024) == 1 + assert cache._block_index(2048) == 2 + + def test_store_and_retrieve_single_block(self) -> None: + cache = MemoryCache(block_size=1024) + data = b"x" * 500 + cache.store("test.bin", 0, data) + + # Should be able to retrieve exact data + result = cache.get("test.bin", 0, 500) + assert result == data + + # Should be able to retrieve a slice + result = cache.get("test.bin", 100, 200) + assert result == b"x" * 100 + + def test_store_and_retrieve_multiple_blocks(self) -> None: + cache = MemoryCache(block_size=1024) + # Store 3 full blocks + data = b"a" * 1024 + b"b" * 1024 + b"c" * 1024 + cache.store("test.bin", 0, data) + + # Retrieve spanning blocks + result = cache.get("test.bin", 1000, 2048) + assert result == b"a" * 24 + b"b" * 1024 + + # Retrieve from middle of one block to middle of another + result = cache.get("test.bin", 512, 2560) + assert result 
== b"a" * 512 + b"b" * 1024 + b"c" * 512 + + def test_get_returns_missing_ranges_for_uncached(self) -> None: + cache = MemoryCache(block_size=1024) + result = cache.get("test.bin", 0, 100) + assert isinstance(result, list) + assert result == [(0, 1024)] + + def test_get_returns_missing_ranges_spanning_blocks(self) -> None: + cache = MemoryCache(block_size=1024) + result = cache.get("test.bin", 0, 2000) + assert isinstance(result, list) + assert result == [(0, 2048)] + + def test_get_returns_missing_for_partial_cache(self) -> None: + cache = MemoryCache(block_size=1024) + # Store first block only + cache.store("test.bin", 0, b"x" * 1024) + + # Request spanning into uncached block + result = cache.get("test.bin", 0, 2000) + assert isinstance(result, list) + assert result == [(1024, 2048)] + + def test_partial_block_at_eof(self) -> None: + cache = MemoryCache(block_size=1024) + # Store a partial block (simulating EOF) + data = b"x" * 500 + cache.store("test.bin", 0, data) + + # Should retrieve the partial data + result = cache.get("test.bin", 0, 1000) + assert result == data + + def test_lru_eviction(self) -> None: + cache = MemoryCache(block_size=1024, max_blocks=2) + + cache.store("a.bin", 0, b"a" * 1024) + cache.store("b.bin", 0, b"b" * 1024) + + # Both should be cached + assert isinstance(cache.get("a.bin", 0, 1024), bytes) + assert isinstance(cache.get("b.bin", 0, 1024), bytes) + + # Add a third, should evict least recently used + cache.store("c.bin", 0, b"c" * 1024) + + assert isinstance(cache.get("c.bin", 0, 1024), bytes) + # One of a or b should be evicted + assert len(cache._blocks) == 2 + + def test_coalesce_adjacent_missing_blocks(self) -> None: + cache = MemoryCache(block_size=1024) + # Request 10 blocks worth of data - should coalesce into one range + result = cache.get("test.bin", 0, 10 * 1024) + assert isinstance(result, list) + assert result == [(0, 10 * 1024)] + + def test_coalesce_splits_on_cached_block(self) -> None: + cache = 
MemoryCache(block_size=1024) + # Cache blocks 0 and 5, leave 1-4 missing and 6 missing + cache.store("test.bin", 0, b"x" * 1024) + cache.store("test.bin", 5 * 1024, b"y" * 1024) + + # Request blocks 1-4 and 6 (missing) + result = cache.get("test.bin", 1024, 7 * 1024) + assert isinstance(result, list) + # Blocks 1-4 are missing (consecutive), block 5 is cached, block 6 is missing + # Should split into two ranges to avoid re-fetching block 5 + assert result == [(1024, 5 * 1024), (6 * 1024, 7 * 1024)] + + def test_coalesce_with_single_cached_block_gap(self) -> None: + cache = MemoryCache(block_size=1024) + # Cache block 3, leave 0-2 and 4+ missing + cache.store("test.bin", 3 * 1024, b"x" * 1024) + + # Request blocks 0-10 + result = cache.get("test.bin", 0, 11 * 1024) + assert isinstance(result, list) + # Blocks 0-2 missing, 3 cached, 4-10 missing + # Should split into two ranges + assert result == [(0, 3 * 1024), (4 * 1024, 11 * 1024)] + + def test_coalesce_multiple_separate_ranges(self) -> None: + cache = MemoryCache(block_size=1024) + # Create a pattern with cached blocks: cache blocks 0, 10, 20 + cache.store("test.bin", 0, b"a" * 1024) + cache.store("test.bin", 10 * 1024, b"b" * 1024) + cache.store("test.bin", 20 * 1024, b"c" * 1024) + + # Request everything from 0 to 25*1024 + result = cache.get("test.bin", 0, 25 * 1024) + assert isinstance(result, list) + # Missing: 1-9, 11-19, 21-24 + # Each group of consecutive missing blocks becomes a separate range + assert len(result) == 3 + assert result[0] == (1 * 1024, 10 * 1024) # blocks 1-9 + assert result[1] == (11 * 1024, 20 * 1024) # blocks 11-19 + assert result[2] == (21 * 1024, 25 * 1024) # blocks 21-24 + + +class TestSyncBlockCache: + def test_basic_get_range(self) -> None: + store = obs.store.MemoryStore() + data = b"hello world, this is test data!" 
+ obs.put(store, "test.txt", data) + + cache = SyncBlockCache(backend=store, cache=MemoryCache(block_size=16)) + result = cache.get_range("test.txt", start=0, length=5) + assert result == b"hello" + + def test_caching_avoids_refetch(self) -> None: + store = obs.store.MemoryStore() + data = b"x" * 100 + obs.put(store, "test.bin", data) + + cache = SyncBlockCache(backend=store, cache=MemoryCache(block_size=64)) + + # First fetch + result1 = cache.get_range("test.bin", start=0, length=10) + assert result1 == b"x" * 10 + + # Modify the underlying store + obs.put(store, "test.bin", b"y" * 100) + + # Should still get cached data + result2 = cache.get_range("test.bin", start=0, length=10) + assert result2 == b"x" * 10 + + # But a different file should fetch fresh + obs.put(store, "other.bin", b"z" * 100) + result3 = cache.get_range("other.bin", start=0, length=10) + assert result3 == b"z" * 10 + + def test_range_spanning_blocks(self) -> None: + store = obs.store.MemoryStore() + data = b"a" * 32 + b"b" * 32 + b"c" * 32 + obs.put(store, "test.bin", data) + + cache = SyncBlockCache(backend=store, cache=MemoryCache(block_size=32)) + + # Fetch spanning first two blocks + result = cache.get_range("test.bin", start=16, end=48) + assert result == b"a" * 16 + b"b" * 16 + + def test_eof_handling(self) -> None: + store = obs.store.MemoryStore() + data = b"short" + obs.put(store, "test.bin", data) + + cache = SyncBlockCache(backend=store, cache=MemoryCache(block_size=1024)) + + # Request more than file size - should return what's available + result = cache.get_range("test.bin", start=0, length=100) + assert result == b"short" + + def test_end_vs_length(self) -> None: + store = obs.store.MemoryStore() + data = b"0123456789" + obs.put(store, "test.bin", data) + + cache = SyncBlockCache(backend=store, cache=MemoryCache(block_size=1024)) + + # Using end + result1 = cache.get_range("test.bin", start=2, end=5) + assert result1 == b"234" + + # Using length + result2 = 
cache.get_range("test.bin", start=2, length=3) + assert result2 == b"234" + + def test_shared_cache(self) -> None: + store = obs.store.MemoryStore() + obs.put(store, "test.bin", b"x" * 100) + + shared_cache = MemoryCache(block_size=64) + cache1 = SyncBlockCache(backend=store, cache=shared_cache) + cache2 = SyncBlockCache(backend=store, cache=shared_cache) + + # Fetch via cache1 + cache1.get_range("test.bin", start=0, length=10) + + # Modify store + obs.put(store, "test.bin", b"y" * 100) + + # cache2 should see cached data from cache1 + result = cache2.get_range("test.bin", start=0, length=10) + assert result == b"x" * 10 + + def test_multiple_ranges_fetch(self) -> None: + store = obs.store.MemoryStore() + # Create data spanning many blocks + data = b"".join(bytes([i] * 32) for i in range(20)) # 20 blocks of 32 bytes + obs.put(store, "test.bin", data) + + shared_cache = MemoryCache(block_size=32) + cache = SyncBlockCache(backend=store, cache=shared_cache) + + # Pre-cache blocks 0, 10 to create gaps + cache.get_range("test.bin", start=0, length=32) + cache.get_range("test.bin", start=10 * 32, length=32) + + # Now request a range spanning the gap - should use get_ranges + result = cache.get_range("test.bin", start=0, end=15 * 32) + + # Verify we got correct data + expected = b"".join(bytes([i] * 32) for i in range(15)) + assert result == expected