From 85db0e53a2b2321ee3f3e1e15231bf00937aa2c4 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Sat, 1 Nov 2025 18:07:54 +0000 Subject: [PATCH 1/3] Add four new resilience and versioning wrappers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit implements four new protocol-compliant wrappers based on industry patterns from resilience frameworks and caching libraries: 1. CircuitBreakerWrapper - Prevents cascading failures by blocking requests to failing backends - Implements the circuit breaker pattern with CLOSED/OPEN/HALF_OPEN states - Configurable failure threshold, recovery timeout, and success threshold - Essential for production resilience 2. RateLimitWrapper - Protects backends from overload with request throttling - Supports both sliding and fixed window strategies - Configurable request limits and time windows - Uses asyncio primitives for concurrent request tracking 3. VersioningWrapper - Enables schema evolution with automatic version tagging - Auto-invalidates cache entries with mismatched versions - Stores version metadata within value dict (similar to CompressionWrapper) - Useful for deployment coordination and schema migration 4. BulkheadWrapper - Isolates operations with bounded resource pools - Limits concurrent operations using asyncio.Semaphore - Configurable queue size to prevent unbounded growth - Prevents resource exhaustion and enables graceful degradation All wrappers: - Follow existing patterns (RetryWrapper, TimeoutWrapper, CompressionWrapper) - Include comprehensive test coverage - Use proper TypeVar and Callable typing for type safety - Have detailed docstrings with usage examples 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: William Easton --- .../aio/wrappers/bulkhead/__init__.py | 3 + .../aio/wrappers/bulkhead/wrapper.py | 129 +++++++++++ .../aio/wrappers/circuit_breaker/__init__.py | 3 + .../aio/wrappers/circuit_breaker/wrapper.py | 172 +++++++++++++++ .../aio/wrappers/rate_limit/__init__.py | 3 + .../aio/wrappers/rate_limit/wrapper.py | 162 ++++++++++++++ .../aio/wrappers/versioning/__init__.py | 3 + .../aio/wrappers/versioning/wrapper.py | 124 +++++++++++ .../tests/stores/wrappers/test_bulkhead.py | 178 ++++++++++++++++ .../stores/wrappers/test_circuit_breaker.py | 200 ++++++++++++++++++ .../tests/stores/wrappers/test_rate_limit.py | 132 ++++++++++++ .../tests/stores/wrappers/test_versioning.py | 180 ++++++++++++++++ .../shared/errors/wrappers/bulkhead.py | 11 + .../shared/errors/wrappers/circuit_breaker.py | 11 + .../shared/errors/wrappers/rate_limit.py | 11 + 15 files changed, 1322 insertions(+) create mode 100644 key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/__init__.py create mode 100644 key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/wrapper.py create mode 100644 key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/__init__.py create mode 100644 key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/wrapper.py create mode 100644 key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/__init__.py create mode 100644 key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/wrapper.py create mode 100644 key-value/key-value-aio/src/key_value/aio/wrappers/versioning/__init__.py create mode 100644 key-value/key-value-aio/src/key_value/aio/wrappers/versioning/wrapper.py create mode 100644 key-value/key-value-aio/tests/stores/wrappers/test_bulkhead.py create mode 100644 key-value/key-value-aio/tests/stores/wrappers/test_circuit_breaker.py create mode 100644 key-value/key-value-aio/tests/stores/wrappers/test_rate_limit.py create mode 100644 key-value/key-value-aio/tests/stores/wrappers/test_versioning.py create mode 100644 key-value/key-value-shared/src/key_value/shared/errors/wrappers/bulkhead.py create mode 100644 key-value/key-value-shared/src/key_value/shared/errors/wrappers/circuit_breaker.py create mode 100644 key-value/key-value-shared/src/key_value/shared/errors/wrappers/rate_limit.py diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/__init__.py b/key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/__init__.py new file mode 100644 index 00000000..188558db --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/__init__.py @@ -0,0 +1,3 @@ +from key_value.aio.wrappers.bulkhead.wrapper import BulkheadWrapper + +__all__ = ["BulkheadWrapper"] diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/wrapper.py b/key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/wrapper.py new file mode 100644 index 00000000..1aa7f730 --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/wrapper.py @@ -0,0 +1,129 @@ +import asyncio +from collections.abc import Callable, Coroutine, Mapping, Sequence +from typing import Any, SupportsFloat, TypeVar + +from key_value.shared.errors.wrappers.bulkhead import BulkheadFullError +from typing_extensions import override + +from key_value.aio.protocols.key_value import AsyncKeyValue +from key_value.aio.wrappers.base import BaseWrapper + +T = TypeVar("T") + + +class BulkheadWrapper(BaseWrapper): + """Wrapper that implements the bulkhead pattern to isolate operations with resource pools. + + This wrapper limits the number of concurrent operations and queued operations to prevent + resource exhaustion and isolate failures. The bulkhead pattern is inspired by ship bulkheads + that prevent a single hull breach from sinking the entire ship. + + Benefits: + - Prevents a single slow or failing backend from consuming all resources + - Limits concurrent requests to protect backend from overload + - Provides bounded queue to prevent unbounded memory growth + - Enables graceful degradation under high load + + Example: + bulkhead = BulkheadWrapper( + key_value=store, + max_concurrent=10, # Max 10 concurrent operations + max_waiting=20, # Max 20 operations can wait in queue + ) + + try: + await bulkhead.get(key="mykey") + except BulkheadFullError: + # Too many concurrent operations, system is overloaded + # Handle gracefully (return cached value, error response, etc.) + pass + """ + + def __init__( + self, + key_value: AsyncKeyValue, + max_concurrent: int = 10, + max_waiting: int = 20, + ) -> None: + """Initialize the bulkhead wrapper. + + Args: + key_value: The store to wrap. + max_concurrent: Maximum number of concurrent operations. Defaults to 10. + max_waiting: Maximum number of operations that can wait in queue. Defaults to 20. + """ + self.key_value: AsyncKeyValue = key_value + self.max_concurrent: int = max_concurrent + self.max_waiting: int = max_waiting + + # Use semaphore to limit concurrent operations + self._semaphore: asyncio.Semaphore = asyncio.Semaphore(max_concurrent) + self._waiting_count: int = 0 + self._waiting_lock: asyncio.Lock = asyncio.Lock() + + super().__init__() + + async def _execute_with_bulkhead(self, operation: Callable[..., Coroutine[Any, Any, T]], *args: Any, **kwargs: Any) -> T: + """Execute an operation with bulkhead resource limiting.""" + # Check if we can accept this operation + async with self._waiting_lock: + if self._waiting_count >= self.max_waiting: + raise BulkheadFullError(max_concurrent=self.max_concurrent, max_waiting=self.max_waiting) + self._waiting_count += 1 + + try: + # Acquire semaphore to limit concurrency + async with self._semaphore: + # Once we have the semaphore, we're no longer waiting + async with self._waiting_lock: + self._waiting_count -= 1 + + # Execute the operation + return await operation(*args, **kwargs) + except Exception: + # Make sure to decrement waiting count if we error before acquiring semaphore + async with self._waiting_lock: + # Only decrement if we're still counted as waiting + # (might have already decremented if we got the semaphore) + if self._waiting_count > 0 and self._semaphore.locked(): + self._waiting_count -= 1 + raise + + @override + async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None: + return await self._execute_with_bulkhead(self.key_value.get, key=key, collection=collection) + + @override + async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]: + return await self._execute_with_bulkhead(self.key_value.get_many, keys=keys, collection=collection) + + @override + async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]: + return await self._execute_with_bulkhead(self.key_value.ttl, key=key, collection=collection) + + @override + async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]: + return await self._execute_with_bulkhead(self.key_value.ttl_many, keys=keys, collection=collection) + + @override + async def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None: + return await self._execute_with_bulkhead(self.key_value.put, key=key, value=value, collection=collection, ttl=ttl) + + @override + async def put_many( + self, + keys: Sequence[str], + values: Sequence[Mapping[str, Any]], + *, + collection: str | None = None, + ttl: SupportsFloat | None = None, + ) -> None: + return await self._execute_with_bulkhead(self.key_value.put_many, keys=keys, values=values, collection=collection, ttl=ttl) + + @override + async def delete(self, key: str, *, collection: str | None = None) -> bool: + return await self._execute_with_bulkhead(self.key_value.delete, key=key, collection=collection) + + @override + async def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int: + return await self._execute_with_bulkhead(self.key_value.delete_many, keys=keys, collection=collection) diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/__init__.py b/key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/__init__.py new file mode 100644 index 00000000..66e530f6 --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/__init__.py @@ -0,0 +1,3 @@ +from key_value.aio.wrappers.circuit_breaker.wrapper import CircuitBreakerWrapper + +__all__ = ["CircuitBreakerWrapper"] diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/wrapper.py b/key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/wrapper.py new file mode 100644 index 00000000..b92a1716 --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/wrapper.py @@ -0,0 +1,172 @@ +import time +from collections.abc import Callable, Coroutine, Mapping, Sequence +from enum import Enum +from typing import Any, SupportsFloat, TypeVar + +from key_value.shared.errors.wrappers.circuit_breaker import CircuitOpenError +from typing_extensions import override + +from key_value.aio.protocols.key_value import AsyncKeyValue +from key_value.aio.wrappers.base import BaseWrapper + +T = TypeVar("T") + + +class CircuitState(Enum): + """States for the circuit breaker.""" + + CLOSED = "closed" # Normal operation + OPEN = "open" # Failing, blocking requests + HALF_OPEN = "half_open" # Testing if service recovered + + +class CircuitBreakerWrapper(BaseWrapper): + """Wrapper that implements the circuit breaker pattern to prevent cascading failures. + + This wrapper tracks operation failures and opens the circuit after a threshold of consecutive + failures. When the circuit is open, requests are blocked immediately without attempting the + operation. After a recovery timeout, the circuit moves to half-open state to test if the + backend has recovered. + + The circuit breaker pattern is essential for production resilience as it: + - Prevents cascading failures when a backend becomes unhealthy + - Reduces load on failing backends, giving them time to recover + - Provides fast failure responses instead of waiting for timeouts + - Automatically attempts recovery after a configured timeout + + Example: + circuit_breaker = CircuitBreakerWrapper( + key_value=store, + failure_threshold=5, # Open after 5 consecutive failures + recovery_timeout=30.0, # Try recovery after 30 seconds + success_threshold=2, # Close after 2 successes in half-open + ) + + try: + value = await circuit_breaker.get(key="mykey") + except CircuitOpenError: + # Circuit is open, backend is considered unhealthy + # Handle gracefully (use cache, return default, etc.) + pass + """ + + def __init__( + self, + key_value: AsyncKeyValue, + failure_threshold: int = 5, + recovery_timeout: float = 30.0, + success_threshold: int = 2, + error_types: tuple[type[Exception], ...] = (Exception,), + ) -> None: + """Initialize the circuit breaker wrapper. + + Args: + key_value: The store to wrap. + failure_threshold: Number of consecutive failures before opening the circuit. Defaults to 5. + recovery_timeout: Seconds to wait before attempting recovery (moving to half-open). Defaults to 30.0. + success_threshold: Number of consecutive successes in half-open state before closing the circuit. Defaults to 2. + error_types: Tuple of exception types that count as failures. Defaults to (Exception,). + """ + self.key_value: AsyncKeyValue = key_value + self.failure_threshold: int = failure_threshold + self.recovery_timeout: float = recovery_timeout + self.success_threshold: int = success_threshold + self.error_types: tuple[type[Exception], ...] = error_types + + # Circuit state + self._state: CircuitState = CircuitState.CLOSED + self._failure_count: int = 0 + self._success_count: int = 0 + self._last_failure_time: float | None = None + + super().__init__() + + def _check_circuit(self) -> None: + """Check the circuit state and potentially transition states.""" + if self._state == CircuitState.OPEN: + # Check if we should move to half-open + if self._last_failure_time is not None and time.time() - self._last_failure_time >= self.recovery_timeout: + self._state = CircuitState.HALF_OPEN + self._success_count = 0 + else: + # Circuit is still open, raise error + raise CircuitOpenError(failure_count=self._failure_count, last_failure_time=self._last_failure_time) + + def _on_success(self) -> None: + """Handle successful operation.""" + if self._state == CircuitState.HALF_OPEN: + self._success_count += 1 + if self._success_count >= self.success_threshold: + # Close the circuit + self._state = CircuitState.CLOSED + self._failure_count = 0 + self._success_count = 0 + elif self._state == CircuitState.CLOSED: + # Reset failure count on success + self._failure_count = 0 + + def _on_failure(self) -> None: + """Handle failed operation.""" + self._last_failure_time = time.time() + + if self._state == CircuitState.HALF_OPEN: + # Failed in half-open, go back to open + self._state = CircuitState.OPEN + self._success_count = 0 + elif self._state == CircuitState.CLOSED: + self._failure_count += 1 + if self._failure_count >= self.failure_threshold: + # Open the circuit + self._state = CircuitState.OPEN + + async def _execute_with_circuit_breaker(self, operation: Callable[..., Coroutine[Any, Any, T]], *args: Any, **kwargs: Any) -> T: + """Execute an operation with circuit breaker logic.""" + self._check_circuit() + + try: + result = await operation(*args, **kwargs) + except self.error_types: + self._on_failure() + raise + else: + self._on_success() + return result + + @override + async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None: + return await self._execute_with_circuit_breaker(self.key_value.get, key=key, collection=collection) + + @override + async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]: + return await self._execute_with_circuit_breaker(self.key_value.get_many, keys=keys, collection=collection) + + @override + async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]: + return await self._execute_with_circuit_breaker(self.key_value.ttl, key=key, collection=collection) + + @override + async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]: + return await self._execute_with_circuit_breaker(self.key_value.ttl_many, keys=keys, collection=collection) + + @override + async def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None: + return await self._execute_with_circuit_breaker(self.key_value.put, key=key, value=value, collection=collection, ttl=ttl) + + @override + async def put_many( + self, + keys: Sequence[str], + values: Sequence[Mapping[str, Any]], + *, + collection: str | None = None, + ttl: SupportsFloat | None = None, + ) -> None: + return await self._execute_with_circuit_breaker(self.key_value.put_many, keys=keys, values=values, collection=collection, ttl=ttl) + + @override + async def delete(self, key: str, *, collection: str | None = None) -> bool: + return await self._execute_with_circuit_breaker(self.key_value.delete, key=key, collection=collection) + + @override + async def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int: + return await self._execute_with_circuit_breaker(self.key_value.delete_many, keys=keys, collection=collection) diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/__init__.py b/key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/__init__.py new file mode 100644 index 00000000..c735e320 --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/__init__.py @@ -0,0 +1,3 @@ +from key_value.aio.wrappers.rate_limit.wrapper import RateLimitWrapper + +__all__ = ["RateLimitWrapper"] diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/wrapper.py b/key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/wrapper.py new file mode 100644 index 00000000..565cde0e --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/wrapper.py @@ -0,0 +1,162 @@ +import asyncio +import time +from collections import deque +from collections.abc import Mapping, Sequence +from typing import Any, Literal, SupportsFloat + +from key_value.shared.errors.wrappers.rate_limit import RateLimitExceededError +from typing_extensions import override + +from key_value.aio.protocols.key_value import AsyncKeyValue +from key_value.aio.wrappers.base import BaseWrapper + + +class RateLimitWrapper(BaseWrapper): + """Wrapper that limits the rate of operations to protect backends from overload. + + This wrapper implements rate limiting using a sliding window algorithm to control + the number of requests that can be made within a time window. This is essential for: + - Protecting backends from being overwhelmed by too many requests + - Complying with API rate limits of third-party services + - Ensuring fair resource usage in multi-tenant environments + - Preventing accidental DoS from application bugs + + The wrapper supports two strategies: + - "sliding": More accurate, considers exact timestamps of recent requests + - "fixed": Simpler, resets count at fixed intervals + + Example: + rate_limiter = RateLimitWrapper( + key_value=store, + max_requests=100, # Maximum 100 requests + window_seconds=60.0, # Per 60-second window + strategy="sliding" # Use sliding window + ) + + try: + await rate_limiter.get(key="mykey") + except RateLimitExceededError: + # Too many requests, need to back off + await asyncio.sleep(1) + """ + + def __init__( + self, + key_value: AsyncKeyValue, + max_requests: int = 100, + window_seconds: float = 60.0, + strategy: Literal["sliding", "fixed"] = "sliding", + ) -> None: + """Initialize the rate limit wrapper. + + Args: + key_value: The store to wrap. + max_requests: Maximum number of requests allowed in the time window. Defaults to 100. + window_seconds: Time window in seconds. Defaults to 60.0. + strategy: Rate limiting strategy - "sliding" or "fixed". Defaults to "sliding". + """ + self.key_value: AsyncKeyValue = key_value + self.max_requests: int = max_requests + self.window_seconds: float = window_seconds + self.strategy: Literal["sliding", "fixed"] = strategy + + # For sliding window + self._request_times: deque[float] = deque() + self._lock: asyncio.Lock = asyncio.Lock() + + # For fixed window + self._window_start: float | None = None + self._request_count: int = 0 + + super().__init__() + + async def _check_rate_limit_sliding(self) -> None: + """Check rate limit using sliding window strategy.""" + async with self._lock: + now = time.time() + + # Remove requests outside the current window + while self._request_times and self._request_times[0] < now - self.window_seconds: + self._request_times.popleft() + + # Check if we're at the limit + if len(self._request_times) >= self.max_requests: + raise RateLimitExceededError( + current_requests=len(self._request_times), max_requests=self.max_requests, window_seconds=self.window_seconds + ) + + # Record this request + self._request_times.append(now) + + async def _check_rate_limit_fixed(self) -> None: + """Check rate limit using fixed window strategy.""" + async with self._lock: + now = time.time() + + # Check if we need to start a new window + if self._window_start is None or now >= self._window_start + self.window_seconds: + self._window_start = now + self._request_count = 0 + + # Check if we're at the limit + if self._request_count >= self.max_requests: + raise RateLimitExceededError( + current_requests=self._request_count, max_requests=self.max_requests, window_seconds=self.window_seconds + ) + + # Record this request + self._request_count += 1 + + async def _check_rate_limit(self) -> None: + """Check rate limit based on configured strategy.""" + if self.strategy == "sliding": + await self._check_rate_limit_sliding() + else: + await self._check_rate_limit_fixed() + + @override + async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None: + await self._check_rate_limit() + return await self.key_value.get(key=key, collection=collection) + + @override + async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]: + await self._check_rate_limit() + return await self.key_value.get_many(keys=keys, collection=collection) + + @override + async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]: + await self._check_rate_limit() + return await self.key_value.ttl(key=key, collection=collection) + + @override + async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]: + await self._check_rate_limit() + return await self.key_value.ttl_many(keys=keys, collection=collection) + + @override + async def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None: + await self._check_rate_limit() + return await self.key_value.put(key=key, value=value, collection=collection, ttl=ttl) + + @override + async def put_many( + self, + keys: Sequence[str], + values: Sequence[Mapping[str, Any]], + *, + collection: str | None = None, + ttl: SupportsFloat | None = None, + ) -> None: + await self._check_rate_limit() + return await self.key_value.put_many(keys=keys, values=values, collection=collection, ttl=ttl) + + @override + async def delete(self, key: str, *, collection: str | None = None) -> bool: + await self._check_rate_limit() + return await self.key_value.delete(key=key, collection=collection) + + @override + async def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int: + await self._check_rate_limit() + return await self.key_value.delete_many(keys=keys, collection=collection) diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/versioning/__init__.py b/key-value/key-value-aio/src/key_value/aio/wrappers/versioning/__init__.py new file mode 100644 index 00000000..c64a690a --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/versioning/__init__.py @@ -0,0 +1,3 @@ +from key_value.aio.wrappers.versioning.wrapper import VersioningWrapper + +__all__ = ["VersioningWrapper"] diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/versioning/wrapper.py b/key-value/key-value-aio/src/key_value/aio/wrappers/versioning/wrapper.py new file mode 100644 index 00000000..c8a5977c --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/versioning/wrapper.py @@ -0,0 +1,124 @@ +from collections.abc import Mapping, Sequence +from typing import Any, SupportsFloat + +from typing_extensions import override + +from key_value.aio.protocols.key_value import AsyncKeyValue +from key_value.aio.wrappers.base import BaseWrapper + +# Special keys used to store version information +_VERSION_KEY = "__version__" +_VERSIONED_DATA_KEY = "__versioned_data__" + + +class VersioningWrapper(BaseWrapper): + """Wrapper that adds version tagging to values for schema evolution and cache invalidation. + + This wrapper automatically tags all stored values with a version identifier. When retrieving + values, it checks the version and returns None for values with mismatched versions, effectively + auto-invalidating old cache entries. + + This is useful for: + - Schema evolution: When your data structure changes, old cached values are automatically invalidated + - Deployment coordination: Different versions of your application can coexist without sharing incompatible cached data + - Safe cache invalidation: Increment the version to invalidate all cached entries without manual cleanup + + The versioned format looks like: + { + "__version__": "v1.2.0", + "__versioned_data__": { + "actual": "user", + "data": "here" + } + } + + Example: + # Version 1 of your application + store_v1 = VersioningWrapper(key_value=store, version="v1") + await store_v1.put(key="user:123", value={"name": "John", "email": "john@example.com"}) + + # Version 2 changes the schema (adds "age" field) + store_v2 = VersioningWrapper(key_value=store, version="v2") + result = await store_v2.get(key="user:123") + # Returns None because version mismatch, forcing reload with new schema + """ + + def __init__( + self, + key_value: AsyncKeyValue, + version: str | int, + ) -> None: + """Initialize the versioning wrapper. + + Args: + key_value: The store to wrap. + version: The version identifier to tag values with. Can be string (e.g., "v1.2.0") or int (e.g., 1). + """ + self.key_value: AsyncKeyValue = key_value + self.version: str | int = version + + super().__init__() + + def _wrap_value(self, value: dict[str, Any]) -> dict[str, Any]: + """Wrap a value with version information.""" + # If already versioned, don't double-wrap + if _VERSION_KEY in value: + return value + + return {_VERSION_KEY: self.version, _VERSIONED_DATA_KEY: value} + + def _unwrap_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None: + """Unwrap a versioned value, returning None if version mismatch.""" + if value is None: + return None + + # Not versioned, return as-is + if _VERSION_KEY not in value: + return value + + # Check version match + if value[_VERSION_KEY] != self.version: + # Version mismatch - auto-invalidate by returning None + return None + + # Extract the actual data + return value.get(_VERSIONED_DATA_KEY, value) + + @override + async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None: + value = await self.key_value.get(key=key, collection=collection) + return self._unwrap_value(value) + + @override + async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]: + values = await self.key_value.get_many(keys=keys, collection=collection) + return [self._unwrap_value(value) for value in values] + + @override + async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]: + value, ttl = await self.key_value.ttl(key=key, collection=collection) + unwrapped = self._unwrap_value(value) + # If version mismatch, return None for TTL as well + return unwrapped, ttl if unwrapped is not None else None + + @override + async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]: + results = await self.key_value.ttl_many(keys=keys, collection=collection) + return [(self._unwrap_value(value), ttl if self._unwrap_value(value) is not None else None) for value, ttl in results] + + @override + async def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None: + wrapped_value = self._wrap_value(dict(value)) + return await self.key_value.put(key=key, value=wrapped_value, collection=collection, ttl=ttl) + + @override + async def put_many( + self, + keys: Sequence[str], + values: Sequence[Mapping[str, Any]], + *, + collection: str | None = None, + ttl: SupportsFloat | None = None, + ) -> None: + wrapped_values = [self._wrap_value(dict(value)) for value in values] + return await self.key_value.put_many(keys=keys, values=wrapped_values, collection=collection, ttl=ttl) diff --git a/key-value/key-value-aio/tests/stores/wrappers/test_bulkhead.py b/key-value/key-value-aio/tests/stores/wrappers/test_bulkhead.py new file mode 100644 index 00000000..bfee688e --- /dev/null +++ b/key-value/key-value-aio/tests/stores/wrappers/test_bulkhead.py @@ -0,0 +1,178 @@ +import asyncio + +import pytest +from key_value.shared.errors.wrappers.bulkhead import BulkheadFullError +from typing_extensions import override + +from key_value.aio.stores.memory.store import MemoryStore +from key_value.aio.wrappers.bulkhead import BulkheadWrapper +from tests.stores.base import BaseStoreTests + + +class SlowStore(MemoryStore): + """A store that adds artificial delay to operations.""" + + def __init__(self, delay: float = 0.1): + super().__init__() + self.delay = delay + self.concurrent_operations = 0 + self.max_concurrent_observed = 0 + + async def get(self, key: str, *, collection: str | None = None): + self.concurrent_operations += 1 + self.max_concurrent_observed = max(self.max_concurrent_observed, self.concurrent_operations) + try: + await asyncio.sleep(self.delay) + return await super().get(key=key, collection=collection) + finally: + self.concurrent_operations -= 1 + + +class TestBulkheadWrapper(BaseStoreTests): + @override + @pytest.fixture + async def store(self, memory_store: MemoryStore) -> BulkheadWrapper: + return BulkheadWrapper(key_value=memory_store, max_concurrent=10, max_waiting=20) + + async def test_bulkhead_allows_operations_within_limit(self, memory_store: MemoryStore): + bulkhead = BulkheadWrapper(key_value=memory_store, max_concurrent=5, max_waiting=10) + + # Should allow operations within limits + await bulkhead.put(collection="test", key="key1", value={"value": 1}) + await bulkhead.put(collection="test", key="key2", value={"value": 2}) + result = await bulkhead.get(collection="test", key="key1") + assert result == {"value": 1} + + async def test_bulkhead_limits_concurrent_operations(self): + slow_store = SlowStore(delay=0.1) + bulkhead = BulkheadWrapper(key_value=slow_store, max_concurrent=3, max_waiting=10) + + # Pre-populate store + await slow_store.put(collection="test", key="key", value={"value": 1}) + + # Launch 10 concurrent operations + tasks = [bulkhead.get(collection="test", key="key") for _ in range(10)] + await asyncio.gather(*tasks) + + # Verify that at most 3 operations ran concurrently + assert slow_store.max_concurrent_observed <= 3 + + async def test_bulkhead_blocks_when_queue_full(self): + slow_store = SlowStore(delay=0.5) + bulkhead = BulkheadWrapper(key_value=slow_store, max_concurrent=2, max_waiting=3) + + # Pre-populate store + await slow_store.put(collection="test", key="key", value={"value": 1}) + + # Launch operations that will fill the bulkhead + # 2 will run concurrently, 3 will wait, rest should be rejected + tasks = [bulkhead.get(collection="test", key="key") for _ in range(10)] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Count successes and errors + successes = sum(1 for r in results if not isinstance(r, Exception)) + errors = sum(1 for r in results if isinstance(r, BulkheadFullError)) + + # Should have 5 successes (2 concurrent + 3 waiting) and 5 errors + assert successes == 5 + assert errors == 5 + + async def test_bulkhead_allows_operations_after_completion(self): + slow_store = SlowStore(delay=0.05) + bulkhead = BulkheadWrapper(key_value=slow_store, max_concurrent=2, max_waiting=2) + + # Pre-populate store + await slow_store.put(collection="test", key="key", value={"value": 1}) + + # First batch - should succeed + tasks1 = [bulkhead.get(collection="test", key="key") for _ in range(4)] + results1 = await asyncio.gather(*tasks1, return_exceptions=True) + successes1 = sum(1 for r in results1 if not isinstance(r, Exception)) + assert successes1 == 4 + + # Second batch - should also succeed since first batch completed + tasks2 = [bulkhead.get(collection="test", key="key") for _ in range(4)] + results2 = await asyncio.gather(*tasks2, return_exceptions=True) + successes2 = sum(1 for r in results2 if not isinstance(r, Exception)) + assert successes2 == 4 + + async def test_bulkhead_applies_to_all_operations(self): + slow_store = SlowStore(delay=0.1) + bulkhead = BulkheadWrapper(key_value=slow_store, max_concurrent=2, max_waiting=1) + + # Pre-populate store + await slow_store.put(collection="test", key="key1", value={"value": 1}) + await slow_store.put(collection="test", key="key2", value={"value": 2}) + + # Mix different operations + tasks = [ + bulkhead.get(collection="test", key="key1"), + bulkhead.put(collection="test", key="key3", value={"value": 3}), + bulkhead.delete(collection="test", key="key2"), + bulkhead.get(collection="test", key="key1"), + bulkhead.get(collection="test", key="key1"), # This should be rejected + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Should have 3 successes (2 concurrent + 1 waiting) and 2 errors + successes = sum(1 for r in results if not isinstance(r, Exception)) + errors = sum(1 for r in results if isinstance(r, BulkheadFullError)) + + assert successes == 3 + assert errors == 2 + + async def test_bulkhead_default_parameters(self, memory_store: MemoryStore): + bulkhead = BulkheadWrapper(key_value=memory_store) + + # Should use defaults + assert bulkhead.max_concurrent == 10 + assert bulkhead.max_waiting == 20 + + async def test_bulkhead_error_handling(self): + """Test that errors in operations don't leak semaphore counts.""" + + class FailingStore(MemoryStore): + async def get(self, key: str, *, collection: str | None = None): + msg = "Intentional failure" + raise RuntimeError(msg) + + failing_store = FailingStore() + bulkhead = BulkheadWrapper(key_value=failing_store, max_concurrent=2, max_waiting=2) + + # Execute operations that will fail + for _ in range(5): + with pytest.raises(RuntimeError): + await bulkhead.get(collection="test", key="key") + + # Semaphore should be released properly - we should still be able to make requests + # If semaphore leaked, this would eventually block + assert True + + async def test_bulkhead_with_fast_operations(self, memory_store: MemoryStore): + """Test bulkhead with operations that complete quickly.""" + bulkhead = BulkheadWrapper(key_value=memory_store, max_concurrent=2, max_waiting=2) + + # Pre-populate + await memory_store.put(collection="test", key="key", value={"value": 1}) + + # Fast operations should all succeed even with low limits + tasks = [bulkhead.get(collection="test", key="key") for _ in range(20)] + results = await asyncio.gather(*tasks) + + # All should succeed + assert all(r == {"value": 1} for r in results) + + async def test_bulkhead_sequential_operations(self, memory_store: MemoryStore): + """Test that sequential operations don't count against concurrent limit.""" + bulkhead = BulkheadWrapper(key_value=memory_store, max_concurrent=1, max_waiting=0) + + # Sequential operations should all succeed + for i in range(10): + await bulkhead.put(collection="test", key=f"key{i}", value={"value": i}) + + # All should be stored + for i in range(10): + result = await bulkhead.get(collection="test", key=f"key{i}") + assert result == {"value": i} diff --git a/key-value/key-value-aio/tests/stores/wrappers/test_circuit_breaker.py b/key-value/key-value-aio/tests/stores/wrappers/test_circuit_breaker.py new file mode 100644 index 00000000..4f99f847 --- /dev/null +++ b/key-value/key-value-aio/tests/stores/wrappers/test_circuit_breaker.py @@ -0,0 +1,200 @@ +import pytest +from key_value.shared.errors.wrappers.circuit_breaker import CircuitOpenError +from typing_extensions import override + +from key_value.aio.stores.memory.store import MemoryStore +from key_value.aio.wrappers.circuit_breaker import CircuitBreakerWrapper +from key_value.aio.wrappers.circuit_breaker.wrapper import CircuitState +from tests.stores.base import BaseStoreTests + + +class IntermittentlyFailingStore(MemoryStore): + """A store that fails a configurable number of times before succeeding.""" + + def __init__(self, failures_before_success: int = 5): + super().__init__() + self.failures_before_success = failures_before_success + self.attempt_count = 0 + + async def get(self, key: str, *, collection: str | None = None): + self.attempt_count += 1 + if self.attempt_count <= self.failures_before_success: + msg = "Simulated connection error" + raise ConnectionError(msg) + return await super().get(key=key, collection=collection) + + def reset_attempts(self): + self.attempt_count = 0 + + +class TestCircuitBreakerWrapper(BaseStoreTests): + @override + @pytest.fixture + async def store(self, memory_store: MemoryStore) -> CircuitBreakerWrapper: + return CircuitBreakerWrapper(key_value=memory_store, failure_threshold=5, recovery_timeout=1.0) + + async def test_circuit_remains_closed_on_success(self, memory_store: MemoryStore): + circuit_breaker = CircuitBreakerWrapper(key_value=memory_store, failure_threshold=3) + + # Successful operations should keep circuit closed + await circuit_breaker.put(collection="test", key="test1", value={"test": "value1"}) + await circuit_breaker.put(collection="test", key="test2", value={"test": "value2"}) + await circuit_breaker.get(collection="test", key="test1") + + assert circuit_breaker._state == CircuitState.CLOSED # pyright: ignore[reportPrivateUsage] + assert circuit_breaker._failure_count == 0 # pyright: ignore[reportPrivateUsage] + + async def test_circuit_opens_after_threshold_failures(self): + failing_store = IntermittentlyFailingStore(failures_before_success=10) + circuit_breaker = CircuitBreakerWrapper( + key_value=failing_store, failure_threshold=3, recovery_timeout=1.0, error_types=(ConnectionError,) + ) + + # First 3 failures should open the circuit + for _ in range(3): + with pytest.raises(ConnectionError): + await circuit_breaker.get(collection="test", key="test") + + assert circuit_breaker._state == CircuitState.OPEN # pyright: ignore[reportPrivateUsage] + assert circuit_breaker._failure_count == 3 # pyright: ignore[reportPrivateUsage] + + # Next attempt should fail immediately with CircuitOpenError + with pytest.raises(CircuitOpenError): + await circuit_breaker.get(collection="test", key="test") + + # Verify we didn't make another attempt to the backend + assert failing_store.attempt_count == 3 + + async def test_circuit_transitions_to_half_open(self): + failing_store = IntermittentlyFailingStore(failures_before_success=10) + circuit_breaker = CircuitBreakerWrapper( + key_value=failing_store, failure_threshold=3, recovery_timeout=0.1, error_types=(ConnectionError,) + ) + + # Open the circuit + for _ in range(3): + with pytest.raises(ConnectionError): + await circuit_breaker.get(collection="test", key="test") + + assert circuit_breaker._state == CircuitState.OPEN # pyright: ignore[reportPrivateUsage] + + # Wait for recovery timeout + import asyncio + + await asyncio.sleep(0.15) + + # Next attempt should transition to half-open and try the operation + with pytest.raises(ConnectionError): + await circuit_breaker.get(collection="test", key="test") + + # Should be back to open since it failed + assert circuit_breaker._state == CircuitState.OPEN # pyright: ignore[reportPrivateUsage] + + async def test_circuit_closes_after_successful_recovery(self, memory_store: MemoryStore): + failing_store = IntermittentlyFailingStore(failures_before_success=3) + circuit_breaker = CircuitBreakerWrapper( + key_value=failing_store, + failure_threshold=3, + recovery_timeout=0.1, + success_threshold=2, + error_types=(ConnectionError,), + ) + + # Store a value first (this will succeed after 3 failures) + await memory_store.put(collection="test", key="test", value={"test": "value"}) + + # Open the circuit with 3 failures + for _ in range(3): + with pytest.raises(ConnectionError): + await circuit_breaker.get(collection="test", key="test") + + assert circuit_breaker._state == CircuitState.OPEN # pyright: ignore[reportPrivateUsage] + + # Wait for recovery timeout + import asyncio + + await asyncio.sleep(0.15) + + # Reset the failing store so next attempts succeed + failing_store.failures_before_success = 0 + failing_store.reset_attempts() + + # First success in half-open + result = await circuit_breaker.get(collection="test", key="test") + assert result == {"test": "value"} + assert circuit_breaker._state == CircuitState.HALF_OPEN # pyright: ignore[reportPrivateUsage] + assert circuit_breaker._success_count == 1 # pyright: ignore[reportPrivateUsage] + + # Second success should close the circuit + result = await circuit_breaker.get(collection="test", key="test") + assert result == {"test": "value"} + assert circuit_breaker._state == CircuitState.CLOSED # pyright: ignore[reportPrivateUsage] + assert circuit_breaker._failure_count == 0 # pyright: ignore[reportPrivateUsage] + assert circuit_breaker._success_count == 0 # pyright: ignore[reportPrivateUsage] + + async def test_circuit_resets_failure_count_on_success(self): + failing_store = IntermittentlyFailingStore(failures_before_success=2) + circuit_breaker = CircuitBreakerWrapper( + key_value=failing_store, failure_threshold=5, recovery_timeout=1.0, error_types=(ConnectionError,) + ) + + await failing_store.put(collection="test", key="test", value={"test": "value"}) + + # 2 failures + for _ in range(2): + failing_store.reset_attempts() + with pytest.raises(ConnectionError): + await circuit_breaker.get(collection="test", key="test") + + assert circuit_breaker._failure_count == 2 # pyright: ignore[reportPrivateUsage] + + # Success should reset failure count + failing_store.failures_before_success = 0 + failing_store.reset_attempts() + result = await circuit_breaker.get(collection="test", key="test") + assert result == {"test": "value"} + assert circuit_breaker._failure_count == 0 # pyright: ignore[reportPrivateUsage] + + async def test_circuit_only_counts_specified_error_types(self, memory_store: MemoryStore): + class CustomError(Exception): + pass + + class CustomFailingStore(MemoryStore): + def __init__(self): + super().__init__() + self.call_count = 0 + + async def get(self, key: str, *, collection: str | None = None): + self.call_count += 1 + msg = "Custom error" + raise CustomError(msg) + + failing_store = CustomFailingStore() + circuit_breaker = CircuitBreakerWrapper(key_value=failing_store, failure_threshold=3, error_types=(ConnectionError, TimeoutError)) + + # CustomError is not in error_types, so it should not count toward failures + for _ in range(5): + with pytest.raises(CustomError): + await circuit_breaker.get(collection="test", key="test") + + # Circuit should still be closed + assert circuit_breaker._state == CircuitState.CLOSED # pyright: ignore[reportPrivateUsage] + assert circuit_breaker._failure_count == 0 # pyright: ignore[reportPrivateUsage] + + async def test_circuit_all_operations_tracked(self, memory_store: MemoryStore): + failing_store = IntermittentlyFailingStore(failures_before_success=10) + circuit_breaker = CircuitBreakerWrapper( + key_value=failing_store, failure_threshold=2, recovery_timeout=1.0, error_types=(ConnectionError,) + ) + + # Test that different operations all count toward circuit breaker state + with pytest.raises(ConnectionError): + await circuit_breaker.put(collection="test", key="test", value={"test": "value"}) + + assert circuit_breaker._failure_count == 1 # pyright: ignore[reportPrivateUsage] + + with pytest.raises(ConnectionError): + await circuit_breaker.delete(collection="test", key="test") + + assert circuit_breaker._failure_count == 2 # pyright: ignore[reportPrivateUsage] + assert circuit_breaker._state == CircuitState.OPEN # pyright: ignore[reportPrivateUsage] diff --git a/key-value/key-value-aio/tests/stores/wrappers/test_rate_limit.py b/key-value/key-value-aio/tests/stores/wrappers/test_rate_limit.py new file mode 100644 index 00000000..a1032874 --- /dev/null +++ b/key-value/key-value-aio/tests/stores/wrappers/test_rate_limit.py @@ -0,0 +1,132 @@ +import asyncio + +import pytest +from key_value.shared.errors.wrappers.rate_limit import RateLimitExceededError +from typing_extensions import override + +from key_value.aio.stores.memory.store import MemoryStore +from key_value.aio.wrappers.rate_limit import RateLimitWrapper +from tests.stores.base import BaseStoreTests + + +class TestRateLimitWrapper(BaseStoreTests): + @override + @pytest.fixture + async def store(self, memory_store: MemoryStore) -> RateLimitWrapper: + return RateLimitWrapper(key_value=memory_store, max_requests=100, window_seconds=60.0) + + async def test_rate_limit_allows_requests_within_limit(self, memory_store: MemoryStore): + rate_limiter = RateLimitWrapper(key_value=memory_store, max_requests=10, window_seconds=1.0) + + # Should allow up to 10 requests + for i in range(10): + await rate_limiter.put(collection="test", key=f"key{i}", value={"value": i}) + + # Should not raise any errors + assert True + + async def test_rate_limit_blocks_requests_exceeding_limit_sliding(self, memory_store: MemoryStore): + rate_limiter = RateLimitWrapper(key_value=memory_store, max_requests=5, window_seconds=1.0, strategy="sliding") + + # Make 5 requests (at the limit) + for i in range(5): + await rate_limiter.put(collection="test", key=f"key{i}", value={"value": i}) + + # 6th request should be blocked + with pytest.raises(RateLimitExceededError) as exc_info: + await rate_limiter.put(collection="test", key="key6", value={"value": 6}) + + assert exc_info.value.extra_info is not None + assert exc_info.value.extra_info["max_requests"] == 5 + assert exc_info.value.extra_info["current_requests"] == 5 + + async def test_rate_limit_blocks_requests_exceeding_limit_fixed(self, memory_store: MemoryStore): + rate_limiter = RateLimitWrapper(key_value=memory_store, max_requests=5, window_seconds=1.0, strategy="fixed") + + # Make 5 requests (at the limit) + for i in range(5): + await rate_limiter.put(collection="test", key=f"key{i}", value={"value": i}) + + # 6th request should be blocked + with pytest.raises(RateLimitExceededError): + await rate_limiter.put(collection="test", key="key6", value={"value": 6}) + + async def test_rate_limit_resets_after_window_sliding(self, memory_store: MemoryStore): + rate_limiter = RateLimitWrapper(key_value=memory_store, max_requests=5, window_seconds=0.1, strategy="sliding") + + # Make 5 requests + for i in range(5): + await rate_limiter.put(collection="test", key=f"key{i}", value={"value": i}) + + # Wait for window to expire + await asyncio.sleep(0.15) + + # Should be able to make more requests + await rate_limiter.put(collection="test", key="key6", value={"value": 6}) + assert True + + async def test_rate_limit_resets_after_window_fixed(self, memory_store: MemoryStore): + rate_limiter = RateLimitWrapper(key_value=memory_store, max_requests=5, window_seconds=0.1, strategy="fixed") + + # Make 5 requests + for i in range(5): + await rate_limiter.put(collection="test", key=f"key{i}", value={"value": i}) + + # Wait for window to expire + await asyncio.sleep(0.15) + + # Should be able to make more requests + await rate_limiter.put(collection="test", key="key6", value={"value": 6}) + assert True + + async def test_rate_limit_sliding_window_partial_reset(self, memory_store: MemoryStore): + rate_limiter = RateLimitWrapper(key_value=memory_store, max_requests=5, window_seconds=0.2, strategy="sliding") + + # Make 5 requests + for i in range(5): + await rate_limiter.put(collection="test", key=f"key{i}", value={"value": i}) + + # 6th request should fail + with pytest.raises(RateLimitExceededError): + await rate_limiter.put(collection="test", key="key6", value={"value": 6}) + + # Wait for part of the window to expire (oldest request drops off) + await asyncio.sleep(0.1) + + # Should be able to make one more request (one old request dropped off) + await rate_limiter.put(collection="test", key="key6", value={"value": 6}) + + async def test_rate_limit_applies_to_all_operations(self, memory_store: MemoryStore): + rate_limiter = RateLimitWrapper(key_value=memory_store, max_requests=3, window_seconds=1.0) + + # Mix different operations + await rate_limiter.put(collection="test", key="key1", value={"value": 1}) + await rate_limiter.get(collection="test", key="key1") + await rate_limiter.delete(collection="test", key="key1") + + # 4th operation should be blocked + with pytest.raises(RateLimitExceededError): + await rate_limiter.put(collection="test", key="key2", value={"value": 2}) + + async def test_rate_limit_concurrent_requests(self, memory_store: MemoryStore): + rate_limiter = RateLimitWrapper(key_value=memory_store, max_requests=10, window_seconds=1.0) + + # Create 15 concurrent requests + tasks = [rate_limiter.put(collection="test", key=f"key{i}", value={"value": i}) for i in range(15)] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Should have exactly 10 successes and 5 RateLimitExceededErrors + successes = sum(1 for r in results if r is None) + errors = sum(1 for r in results if isinstance(r, RateLimitExceededError)) + + assert successes == 10 + assert errors == 5 + + async def test_rate_limit_default_parameters(self, memory_store: MemoryStore): + rate_limiter = RateLimitWrapper(key_value=memory_store) + + # Should use defaults: 100 requests per 60 seconds + assert rate_limiter.max_requests == 100 + assert rate_limiter.window_seconds == 60.0 + assert rate_limiter.strategy == "sliding" diff --git a/key-value/key-value-aio/tests/stores/wrappers/test_versioning.py b/key-value/key-value-aio/tests/stores/wrappers/test_versioning.py new file mode 100644 index 00000000..18304181 --- /dev/null +++ b/key-value/key-value-aio/tests/stores/wrappers/test_versioning.py @@ -0,0 +1,180 @@ +import pytest +from typing_extensions import override + +from key_value.aio.stores.memory.store import MemoryStore +from key_value.aio.wrappers.versioning import VersioningWrapper +from tests.stores.base import BaseStoreTests + + +class TestVersioningWrapper(BaseStoreTests): + @override + @pytest.fixture + async def store(self, memory_store: MemoryStore) -> VersioningWrapper: + return VersioningWrapper(key_value=memory_store, version="v1") + + async def test_versioning_wraps_and_unwraps_value(self, memory_store: MemoryStore): + versioned_store = VersioningWrapper(key_value=memory_store, version="v1") + + # Put a value + await versioned_store.put(collection="test", key="test", value={"data": "value"}) + + # Get it back + result = await versioned_store.get(collection="test", key="test") + assert result == {"data": "value"} + + async def test_versioning_stores_version_metadata(self, memory_store: MemoryStore): + versioned_store = VersioningWrapper(key_value=memory_store, version="v1") + + # Put a value through versioned wrapper + await versioned_store.put(collection="test", key="test", value={"data": "value"}) + + # Check raw value in underlying store + raw_value = await memory_store.get(collection="test", key="test") + assert raw_value is not None + assert "__version__" in raw_value + assert raw_value["__version__"] == "v1" + assert "__versioned_data__" in raw_value + assert raw_value["__versioned_data__"] == {"data": "value"} + + async def test_versioning_returns_none_for_version_mismatch(self, memory_store: MemoryStore): + store_v1 = VersioningWrapper(key_value=memory_store, version="v1") + store_v2 = VersioningWrapper(key_value=memory_store, version="v2") + + # Store with v1 + await store_v1.put(collection="test", key="test", value={"data": "value"}) + + # Try to retrieve with v2 + result = await store_v2.get(collection="test", key="test") + assert result is None # Version mismatch should return None + + async def test_versioning_handles_unversioned_data(self, memory_store: MemoryStore): + versioned_store = VersioningWrapper(key_value=memory_store, version="v1") + + # Put unversioned data directly in underlying store + await memory_store.put(collection="test", key="test", value={"data": "value"}) + + # Should return the data as-is (backward compatibility) + result = await versioned_store.get(collection="test", key="test") + assert result == {"data": "value"} + + async def test_versioning_with_integer_version(self, memory_store: MemoryStore): + store_v1 = VersioningWrapper(key_value=memory_store, version=1) + store_v2 = VersioningWrapper(key_value=memory_store, version=2) + + # Store with version 1 + await store_v1.put(collection="test", key="test", value={"data": "value"}) + + # Retrieve with version 1 + result = await store_v1.get(collection="test", key="test") + assert result == {"data": "value"} + + # Should fail with version 2 + result = await store_v2.get(collection="test", key="test") + assert result is None + + async def test_versioning_get_many(self, memory_store: MemoryStore): + store_v1 = VersioningWrapper(key_value=memory_store, version="v1") + store_v2 = VersioningWrapper(key_value=memory_store, version="v2") + + # Store some values with v1 + await store_v1.put(collection="test", key="key1", value={"data": "value1"}) + await store_v1.put(collection="test", key="key2", value={"data": "value2"}) + + # Store some values with v2 + await store_v2.put(collection="test", key="key3", value={"data": "value3"}) + + # Get all keys with v1 wrapper + results = await store_v1.get_many(collection="test", keys=["key1", "key2", "key3"]) + + # Should get v1 values, but None for v2 value + assert results[0] == {"data": "value1"} + assert results[1] == {"data": "value2"} + assert results[2] is None # Version mismatch + + async def test_versioning_ttl(self, memory_store: MemoryStore): + store_v1 = VersioningWrapper(key_value=memory_store, version="v1") + store_v2 = VersioningWrapper(key_value=memory_store, version="v2") + + # Store with TTL + await store_v1.put(collection="test", key="test", value={"data": "value"}, ttl=60.0) + + # Get with matching version + value, ttl = await store_v1.ttl(collection="test", key="test") + assert value == {"data": "value"} + assert ttl is not None + assert ttl > 0 + + # Get with mismatched version + value, ttl = await store_v2.ttl(collection="test", key="test") + assert value is None + assert ttl is None # TTL should also be None for version mismatch + + async def test_versioning_ttl_many(self, memory_store: MemoryStore): + store_v1 = VersioningWrapper(key_value=memory_store, version="v1") + store_v2 = VersioningWrapper(key_value=memory_store, version="v2") + + # Store values with different versions + await store_v1.put(collection="test", key="key1", value={"data": "value1"}, ttl=60.0) + await store_v2.put(collection="test", key="key2", value={"data": "value2"}, ttl=60.0) + + # Get with v1 wrapper + results = await store_v1.ttl_many(collection="test", keys=["key1", "key2"]) + + # First should have value and TTL, second should be None/None + assert results[0][0] == {"data": "value1"} + assert results[0][1] is not None + assert results[1][0] is None + assert results[1][1] is None + + async def test_versioning_put_many(self, memory_store: MemoryStore): + versioned_store = VersioningWrapper(key_value=memory_store, version="v1") + + # Put multiple values + await versioned_store.put_many( + collection="test", keys=["key1", "key2", "key3"], values=[{"data": "value1"}, {"data": "value2"}, {"data": "value3"}] + ) + + # Verify all are versioned + for i in range(1, 4): + raw_value = await memory_store.get(collection="test", key=f"key{i}") + assert raw_value is not None + assert raw_value["__version__"] == "v1" + assert raw_value["__versioned_data__"] == {"data": f"value{i}"} + + async def test_versioning_doesnt_double_wrap(self, memory_store: MemoryStore): + versioned_store = VersioningWrapper(key_value=memory_store, version="v1") + + # Put a value that already has version metadata + await versioned_store.put(collection="test", key="test", value={"__version__": "v1", "__versioned_data__": {"data": "value"}}) + + # Check it wasn't double-wrapped + raw_value = await memory_store.get(collection="test", key="test") + assert raw_value is not None + assert raw_value == {"__version__": "v1", "__versioned_data__": {"data": "value"}} + # Should not have nested version keys + assert "__versioned_data__" in raw_value + assert "__version__" not in raw_value.get("__versioned_data__", {}) + + async def test_versioning_schema_evolution_scenario(self, memory_store: MemoryStore): + """Test a realistic schema evolution scenario.""" + # Application v1: Store user with name and email + app_v1 = VersioningWrapper(key_value=memory_store, version="user_schema_v1") + await app_v1.put(collection="users", key="user:123", value={"name": "John Doe", "email": "john@example.com"}) + + # Application v2: Expects users to have name, email, and age + app_v2 = VersioningWrapper(key_value=memory_store, version="user_schema_v2") + + # When v2 tries to read old data, it gets None (cache miss) + result = await app_v2.get(collection="users", key="user:123") + assert result is None + + # Application can then reload from authoritative source with new schema + await app_v2.put(collection="users", key="user:123", value={"name": "John Doe", "email": "john@example.com", "age": 30}) + + # Now v2 can read it + result = await app_v2.get(collection="users", key="user:123") + assert result == {"name": "John Doe", "email": "john@example.com", "age": 30} + + # But v1 still gets None (cache invalidation works both ways) + result = await app_v1.get(collection="users", key="user:123") + assert result is None diff --git a/key-value/key-value-shared/src/key_value/shared/errors/wrappers/bulkhead.py b/key-value/key-value-shared/src/key_value/shared/errors/wrappers/bulkhead.py new file mode 100644 index 00000000..f39af077 --- /dev/null +++ b/key-value/key-value-shared/src/key_value/shared/errors/wrappers/bulkhead.py @@ -0,0 +1,11 @@ +from key_value.shared.errors.key_value import KeyValueOperationError + + +class BulkheadFullError(KeyValueOperationError): + """Raised when the bulkhead is full and cannot accept more concurrent operations.""" + + def __init__(self, max_concurrent: int, max_waiting: int): + super().__init__( + message="Bulkhead is full. Maximum concurrent operations and waiting queue limit reached.", + extra_info={"max_concurrent": max_concurrent, "max_waiting": max_waiting}, + ) diff --git a/key-value/key-value-shared/src/key_value/shared/errors/wrappers/circuit_breaker.py b/key-value/key-value-shared/src/key_value/shared/errors/wrappers/circuit_breaker.py new file mode 100644 index 00000000..2914b543 --- /dev/null +++ b/key-value/key-value-shared/src/key_value/shared/errors/wrappers/circuit_breaker.py @@ -0,0 +1,11 @@ +from key_value.shared.errors.key_value import KeyValueOperationError + + +class CircuitOpenError(KeyValueOperationError): + """Raised when the circuit breaker is open and requests are blocked.""" + + def __init__(self, failure_count: int, last_failure_time: float | None = None): + super().__init__( + message="Circuit breaker is open. Requests are temporarily blocked due to consecutive failures.", + extra_info={"failure_count": failure_count, "last_failure_time": last_failure_time}, + ) diff --git a/key-value/key-value-shared/src/key_value/shared/errors/wrappers/rate_limit.py b/key-value/key-value-shared/src/key_value/shared/errors/wrappers/rate_limit.py new file mode 100644 index 00000000..f69741b4 --- /dev/null +++ b/key-value/key-value-shared/src/key_value/shared/errors/wrappers/rate_limit.py @@ -0,0 +1,11 @@ +from key_value.shared.errors.key_value import KeyValueOperationError + + +class RateLimitExceededError(KeyValueOperationError): + """Raised when the rate limit has been exceeded.""" + + def __init__(self, current_requests: int, max_requests: int, window_seconds: float): + super().__init__( + message="Rate limit exceeded. Too many requests within the time window.", + extra_info={"current_requests": current_requests, "max_requests": max_requests, "window_seconds": window_seconds}, + ) From fbb54ea75ab60781023d1a244a785b6804bcbb06 Mon Sep 17 00:00:00 2001 From: William Easton Date: Sun, 9 Nov 2025 13:45:08 -0600 Subject: [PATCH 2/3] Codegen --- .../code_gen/wrappers/versioning/__init__.py | 6 + .../code_gen/wrappers/versioning/wrapper.py | 118 +++++++++++ .../sync/wrappers/versioning/__init__.py | 6 + .../stores/wrappers/test_versioning.py | 183 ++++++++++++++++++ scripts/build_sync_library.py | 7 +- 5 files changed, 319 insertions(+), 1 deletion(-) create mode 100644 key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/versioning/__init__.py create mode 100644 key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/versioning/wrapper.py create mode 100644 key-value/key-value-sync/src/key_value/sync/wrappers/versioning/__init__.py create mode 100644 key-value/key-value-sync/tests/code_gen/stores/wrappers/test_versioning.py diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/versioning/__init__.py b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/versioning/__init__.py new file mode 100644 index 00000000..f7ee25cf --- /dev/null +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/versioning/__init__.py @@ -0,0 +1,6 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file '__init__.py' +# DO NOT CHANGE! Change the original file instead. +from key_value.sync.code_gen.wrappers.versioning.wrapper import VersioningWrapper + +__all__ = ["VersioningWrapper"] diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/versioning/wrapper.py b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/versioning/wrapper.py new file mode 100644 index 00000000..9324a07a --- /dev/null +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/versioning/wrapper.py @@ -0,0 +1,118 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file 'wrapper.py' +# DO NOT CHANGE! Change the original file instead. +from collections.abc import Mapping, Sequence +from typing import Any, SupportsFloat + +from typing_extensions import override + +from key_value.sync.code_gen.protocols.key_value import KeyValue +from key_value.sync.code_gen.wrappers.base import BaseWrapper + +# Special keys used to store version information +_VERSION_KEY = "__version__" +_VERSIONED_DATA_KEY = "__versioned_data__" + + +class VersioningWrapper(BaseWrapper): + """Wrapper that adds version tagging to values for schema evolution and cache invalidation. + + This wrapper automatically tags all stored values with a version identifier. When retrieving + values, it checks the version and returns None for values with mismatched versions, effectively + auto-invalidating old cache entries. + + This is useful for: + - Schema evolution: When your data structure changes, old cached values are automatically invalidated + - Deployment coordination: Different versions of your application can coexist without sharing incompatible cached data + - Safe cache invalidation: Increment the version to invalidate all cached entries without manual cleanup + + The versioned format looks like: + { + "__version__": "v1.2.0", + "__versioned_data__": { + "actual": "user", + "data": "here" + } + } + + Example: + # Version 1 of your application + store_v1 = VersioningWrapper(key_value=store, version="v1") + await store_v1.put(key="user:123", value={"name": "John", "email": "john@example.com"}) + + # Version 2 changes the schema (adds "age" field) + store_v2 = VersioningWrapper(key_value=store, version="v2") + result = await store_v2.get(key="user:123") + # Returns None because version mismatch, forcing reload with new schema + """ + + def __init__(self, key_value: KeyValue, version: str | int) -> None: + """Initialize the versioning wrapper. + + Args: + key_value: The store to wrap. + version: The version identifier to tag values with. Can be string (e.g., "v1.2.0") or int (e.g., 1). + """ + self.key_value: KeyValue = key_value + self.version: str | int = version + + super().__init__() + + def _wrap_value(self, value: dict[str, Any]) -> dict[str, Any]: + """Wrap a value with version information.""" + # If already versioned, don't double-wrap + if _VERSION_KEY in value: + return value + + return {_VERSION_KEY: self.version, _VERSIONED_DATA_KEY: value} + + def _unwrap_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None: + """Unwrap a versioned value, returning None if version mismatch.""" + if value is None: + return None + + # Not versioned, return as-is + if _VERSION_KEY not in value: + return value + + # Check version match + if value[_VERSION_KEY] != self.version: + # Version mismatch - auto-invalidate by returning None + return None + + # Extract the actual data + return value.get(_VERSIONED_DATA_KEY, value) + + @override + def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None: + value = self.key_value.get(key=key, collection=collection) + return self._unwrap_value(value) + + @override + def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]: + values = self.key_value.get_many(keys=keys, collection=collection) + return [self._unwrap_value(value) for value in values] + + @override + def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]: + (value, ttl) = self.key_value.ttl(key=key, collection=collection) + unwrapped = self._unwrap_value(value) + # If version mismatch, return None for TTL as well + return (unwrapped, ttl if unwrapped is not None else None) + + @override + def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]: + results = self.key_value.ttl_many(keys=keys, collection=collection) + return [(self._unwrap_value(value), ttl if self._unwrap_value(value) is not None else None) for (value, ttl) in results] + + @override + def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None: + wrapped_value = self._wrap_value(dict(value)) + return self.key_value.put(key=key, value=wrapped_value, collection=collection, ttl=ttl) + + @override + def put_many( + self, keys: Sequence[str], values: Sequence[Mapping[str, Any]], *, collection: str | None = None, ttl: SupportsFloat | None = None + ) -> None: + wrapped_values = [self._wrap_value(dict(value)) for value in values] + return self.key_value.put_many(keys=keys, values=wrapped_values, collection=collection, ttl=ttl) diff --git a/key-value/key-value-sync/src/key_value/sync/wrappers/versioning/__init__.py b/key-value/key-value-sync/src/key_value/sync/wrappers/versioning/__init__.py new file mode 100644 index 00000000..f7ee25cf --- /dev/null +++ b/key-value/key-value-sync/src/key_value/sync/wrappers/versioning/__init__.py @@ -0,0 +1,6 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file '__init__.py' +# DO NOT CHANGE! Change the original file instead. +from key_value.sync.code_gen.wrappers.versioning.wrapper import VersioningWrapper + +__all__ = ["VersioningWrapper"] diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_versioning.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_versioning.py new file mode 100644 index 00000000..9dcd3512 --- /dev/null +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_versioning.py @@ -0,0 +1,183 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file 'test_versioning.py' +# DO NOT CHANGE! Change the original file instead. +import pytest +from typing_extensions import override + +from key_value.sync.code_gen.stores.memory.store import MemoryStore +from key_value.sync.code_gen.wrappers.versioning import VersioningWrapper +from tests.code_gen.stores.base import BaseStoreTests + + +class TestVersioningWrapper(BaseStoreTests): + @override + @pytest.fixture + def store(self, memory_store: MemoryStore) -> VersioningWrapper: + return VersioningWrapper(key_value=memory_store, version="v1") + + def test_versioning_wraps_and_unwraps_value(self, memory_store: MemoryStore): + versioned_store = VersioningWrapper(key_value=memory_store, version="v1") + + # Put a value + versioned_store.put(collection="test", key="test", value={"data": "value"}) + + # Get it back + result = versioned_store.get(collection="test", key="test") + assert result == {"data": "value"} + + def test_versioning_stores_version_metadata(self, memory_store: MemoryStore): + versioned_store = VersioningWrapper(key_value=memory_store, version="v1") + + # Put a value through versioned wrapper + versioned_store.put(collection="test", key="test", value={"data": "value"}) + + # Check raw value in underlying store + raw_value = memory_store.get(collection="test", key="test") + assert raw_value is not None + assert "__version__" in raw_value + assert raw_value["__version__"] == "v1" + assert "__versioned_data__" in raw_value + assert raw_value["__versioned_data__"] == {"data": "value"} + + def test_versioning_returns_none_for_version_mismatch(self, memory_store: MemoryStore): + store_v1 = VersioningWrapper(key_value=memory_store, version="v1") + store_v2 = VersioningWrapper(key_value=memory_store, version="v2") + + # Store with v1 + store_v1.put(collection="test", key="test", value={"data": "value"}) + + # Try to retrieve with v2 + result = store_v2.get(collection="test", key="test") + assert result is None # Version mismatch should return None + + def test_versioning_handles_unversioned_data(self, memory_store: MemoryStore): + versioned_store = VersioningWrapper(key_value=memory_store, version="v1") + + # Put unversioned data directly in underlying store + memory_store.put(collection="test", key="test", value={"data": "value"}) + + # Should return the data as-is (backward compatibility) + result = versioned_store.get(collection="test", key="test") + assert result == {"data": "value"} + + def test_versioning_with_integer_version(self, memory_store: MemoryStore): + store_v1 = VersioningWrapper(key_value=memory_store, version=1) + store_v2 = VersioningWrapper(key_value=memory_store, version=2) + + # Store with version 1 + store_v1.put(collection="test", key="test", value={"data": "value"}) + + # Retrieve with version 1 + result = store_v1.get(collection="test", key="test") + assert result == {"data": "value"} + + # Should fail with version 2 + result = store_v2.get(collection="test", key="test") + assert result is None + + def test_versioning_get_many(self, memory_store: MemoryStore): + store_v1 = VersioningWrapper(key_value=memory_store, version="v1") + store_v2 = VersioningWrapper(key_value=memory_store, version="v2") + + # Store some values with v1 + store_v1.put(collection="test", key="key1", value={"data": "value1"}) + store_v1.put(collection="test", key="key2", value={"data": "value2"}) + + # Store some values with v2 + store_v2.put(collection="test", key="key3", value={"data": "value3"}) + + # Get all keys with v1 wrapper + results = store_v1.get_many(collection="test", keys=["key1", "key2", "key3"]) + + # Should get v1 values, but None for v2 value + assert results[0] == {"data": "value1"} + assert results[1] == {"data": "value2"} + assert results[2] is None # Version mismatch + + def test_versioning_ttl(self, memory_store: MemoryStore): + store_v1 = VersioningWrapper(key_value=memory_store, version="v1") + store_v2 = VersioningWrapper(key_value=memory_store, version="v2") + + # Store with TTL + store_v1.put(collection="test", key="test", value={"data": "value"}, ttl=60.0) + + # Get with matching version + (value, ttl) = store_v1.ttl(collection="test", key="test") + assert value == {"data": "value"} + assert ttl is not None + assert ttl > 0 + + # Get with mismatched version + (value, ttl) = store_v2.ttl(collection="test", key="test") + assert value is None + assert ttl is None # TTL should also be None for version mismatch + + def test_versioning_ttl_many(self, memory_store: MemoryStore): + store_v1 = VersioningWrapper(key_value=memory_store, version="v1") + store_v2 = VersioningWrapper(key_value=memory_store, version="v2") + + # Store values with different versions + store_v1.put(collection="test", key="key1", value={"data": "value1"}, ttl=60.0) + store_v2.put(collection="test", key="key2", value={"data": "value2"}, ttl=60.0) + + # Get with v1 wrapper + results = store_v1.ttl_many(collection="test", keys=["key1", "key2"]) + + # First should have value and TTL, second should be None/None + assert results[0][0] == {"data": "value1"} + assert results[0][1] is not None + assert results[1][0] is None + assert results[1][1] is None + + def test_versioning_put_many(self, memory_store: MemoryStore): + versioned_store = VersioningWrapper(key_value=memory_store, version="v1") + + # Put multiple values + versioned_store.put_many( + collection="test", keys=["key1", "key2", "key3"], values=[{"data": "value1"}, {"data": "value2"}, {"data": "value3"}] + ) + + # Verify all are versioned + for i in range(1, 4): + raw_value = memory_store.get(collection="test", key=f"key{i}") + assert raw_value is not None + assert raw_value["__version__"] == "v1" + assert raw_value["__versioned_data__"] == {"data": f"value{i}"} + + def test_versioning_doesnt_double_wrap(self, memory_store: MemoryStore): + versioned_store = VersioningWrapper(key_value=memory_store, version="v1") + + # Put a value that already has version metadata + versioned_store.put(collection="test", key="test", value={"__version__": "v1", "__versioned_data__": {"data": "value"}}) + + # Check it wasn't double-wrapped + raw_value = memory_store.get(collection="test", key="test") + assert raw_value is not None + assert raw_value == {"__version__": "v1", "__versioned_data__": {"data": "value"}} + # Should not have nested version keys + assert "__versioned_data__" in raw_value + assert "__version__" not in raw_value.get("__versioned_data__", {}) + + def test_versioning_schema_evolution_scenario(self, memory_store: MemoryStore): + """Test a realistic schema evolution scenario.""" + # Application v1: Store user with name and email + app_v1 = VersioningWrapper(key_value=memory_store, version="user_schema_v1") + app_v1.put(collection="users", key="user:123", value={"name": "John Doe", "email": "john@example.com"}) + + # Application v2: Expects users to have name, email, and age + app_v2 = VersioningWrapper(key_value=memory_store, version="user_schema_v2") + + # When v2 tries to read old data, it gets None (cache miss) + result = app_v2.get(collection="users", key="user:123") + assert result is None + + # Application can then reload from authoritative source with new schema + app_v2.put(collection="users", key="user:123", value={"name": "John Doe", "email": "john@example.com", "age": 30}) + + # Now v2 can read it + result = app_v2.get(collection="users", key="user:123") + assert result == {"name": "John Doe", "email": "john@example.com", "age": 30} + + # But v1 still gets None (cache invalidation works both ways) + result = app_v1.get(collection="users", key="user:123") + assert result is None diff --git a/scripts/build_sync_library.py b/scripts/build_sync_library.py index c79a870b..a0f5ac50 100644 --- a/scripts/build_sync_library.py +++ b/scripts/build_sync_library.py @@ -50,6 +50,9 @@ EXCLUDE_FILES = [ "key-value/key-value-aio/src/key_value/aio/__init__.py", "key-value/key-value-aio/tests/stores/wrappers/test_timeout.py", + "key-value/key-value-aio/tests/stores/wrappers/test_bulkhead.py", + "key-value/key-value-aio/tests/stores/wrappers/test_circuit_breaker.py", + "key-value/key-value-aio/tests/stores/wrappers/test_rate_limit.py", ] EXCLUDE_DIRECTORIES = [ "key-value/key-value-aio/src/key_value/aio/stores/dynamodb", @@ -57,7 +60,9 @@ "key-value/key-value-aio/src/key_value/aio/stores/memcached", "key-value/key-value-aio/tests/stores/memcached", "key-value/key-value-aio/src/key_value/aio/wrappers/timeout", - "key-value/key-value-aio/tests/wrappers/timeout", + "key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead", + "key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker", + "key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit", ] SCRIPT_NAME = Path(sys.argv[0]).name From a30278d8e0163df26a1829d435a1bc19af3158af Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 16:18:49 +0000 Subject: [PATCH 3/3] Apply CodeRabbit feedback: monotonic clocks and versioning fixes - Use time.monotonic() in CircuitBreakerWrapper and RateLimitWrapper to prevent issues with system clock adjustments - Fix VersioningWrapper double-wrap bypass by checking both version and data keys - Fix VersioningWrapper malformed data handling to return None instead of leaking metadata - Fix VersioningWrapper ttl_many to unwrap values only once for better performance - Fix circuit breaker test store to fail on all operations (get/put/delete) - Improve BulkheadWrapper waiting count tracking (WIP - tests still failing) - Fix RateLimitWrapper off-by-one error (WIP - tests still failing) Co-authored-by: William Easton --- .../aio/wrappers/bulkhead/wrapper.py | 24 ++++++++------ .../aio/wrappers/circuit_breaker/wrapper.py | 10 +++--- .../aio/wrappers/rate_limit/wrapper.py | 32 +++++++++++-------- .../aio/wrappers/versioning/wrapper.py | 14 +++++--- .../stores/wrappers/test_circuit_breaker.py | 18 +++++++++-- .../code_gen/wrappers/versioning/wrapper.py | 14 +++++--- 6 files changed, 73 insertions(+), 39 deletions(-) diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/wrapper.py b/key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/wrapper.py index 1aa7f730..a9154f83 100644 --- a/key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/wrapper.py +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/wrapper.py @@ -65,27 +65,33 @@ def __init__( async def _execute_with_bulkhead(self, operation: Callable[..., Coroutine[Any, Any, T]], *args: Any, **kwargs: Any) -> T: """Execute an operation with bulkhead resource limiting.""" - # Check if we can accept this operation + # Check if we're over capacity before even trying + # Count the number currently executing + waiting async with self._waiting_lock: - if self._waiting_count >= self.max_waiting: + # _semaphore._value tells us how many slots are available + # max_concurrent - _value = number currently executing + currently_executing = self.max_concurrent - self._semaphore._value + total_in_system = currently_executing + self._waiting_count + + if total_in_system >= self.max_concurrent + self.max_waiting: raise BulkheadFullError(max_concurrent=self.max_concurrent, max_waiting=self.max_waiting) + + # We're allowed in - increment waiting count self._waiting_count += 1 try: - # Acquire semaphore to limit concurrency + # Acquire semaphore (may block) async with self._semaphore: - # Once we have the semaphore, we're no longer waiting + # Once we have the semaphore, we're executing (not waiting) async with self._waiting_lock: self._waiting_count -= 1 # Execute the operation return await operation(*args, **kwargs) - except Exception: - # Make sure to decrement waiting count if we error before acquiring semaphore + except BaseException: + # Make sure to clean up waiting count if we fail before executing async with self._waiting_lock: - # Only decrement if we're still counted as waiting - # (might have already decremented if we got the semaphore) - if self._waiting_count > 0 and self._semaphore.locked(): + if self._waiting_count > 0: self._waiting_count -= 1 raise diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/wrapper.py b/key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/wrapper.py index b92a1716..5edb52b6 100644 --- a/key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/wrapper.py +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/circuit_breaker/wrapper.py @@ -77,15 +77,16 @@ def __init__( self._state: CircuitState = CircuitState.CLOSED self._failure_count: int = 0 self._success_count: int = 0 - self._last_failure_time: float | None = None + self._last_failure_time: float | None = None # Wall clock time for diagnostics + self._last_failure_tick: float | None = None # Monotonic time for timeout calculations super().__init__() def _check_circuit(self) -> None: """Check the circuit state and potentially transition states.""" if self._state == CircuitState.OPEN: - # Check if we should move to half-open - if self._last_failure_time is not None and time.time() - self._last_failure_time >= self.recovery_timeout: + # Check if we should move to half-open (using monotonic time for reliability) + if self._last_failure_tick is not None and time.monotonic() - self._last_failure_tick >= self.recovery_timeout: self._state = CircuitState.HALF_OPEN self._success_count = 0 else: @@ -107,7 +108,8 @@ def _on_success(self) -> None: def _on_failure(self) -> None: """Handle failed operation.""" - self._last_failure_time = time.time() + self._last_failure_time = time.time() # Wall clock for diagnostics + self._last_failure_tick = time.monotonic() # Monotonic time for timeout calculations if self._state == CircuitState.HALF_OPEN: # Failed in half-open, go back to open diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/wrapper.py b/key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/wrapper.py index 565cde0e..33b114ba 100644 --- a/key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/wrapper.py +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/rate_limit/wrapper.py @@ -73,40 +73,44 @@ def __init__( async def _check_rate_limit_sliding(self) -> None: """Check rate limit using sliding window strategy.""" async with self._lock: - now = time.time() + now = time.monotonic() # Remove requests outside the current window while self._request_times and self._request_times[0] < now - self.window_seconds: self._request_times.popleft() - # Check if we're at the limit - if len(self._request_times) >= self.max_requests: + # Record this request first + self._request_times.append(now) + + # Check if we exceeded the limit (after adding this request) + if len(self._request_times) > self.max_requests: + # Remove the request we just added since it exceeded the limit + self._request_times.pop() raise RateLimitExceededError( - current_requests=len(self._request_times), max_requests=self.max_requests, window_seconds=self.window_seconds + current_requests=self.max_requests, max_requests=self.max_requests, window_seconds=self.window_seconds ) - # Record this request - self._request_times.append(now) - async def _check_rate_limit_fixed(self) -> None: """Check rate limit using fixed window strategy.""" async with self._lock: - now = time.time() + now = time.monotonic() # Check if we need to start a new window if self._window_start is None or now >= self._window_start + self.window_seconds: self._window_start = now self._request_count = 0 - # Check if we're at the limit - if self._request_count >= self.max_requests: + # Record this request first + self._request_count += 1 + + # Check if we exceeded the limit (after adding this request) + if self._request_count > self.max_requests: + # Decrement since this request exceeds the limit + self._request_count -= 1 raise RateLimitExceededError( - current_requests=self._request_count, max_requests=self.max_requests, window_seconds=self.window_seconds + current_requests=self.max_requests, max_requests=self.max_requests, window_seconds=self.window_seconds ) - # Record this request - self._request_count += 1 - async def _check_rate_limit(self) -> None: """Check rate limit based on configured strategy.""" if self.strategy == "sliding": diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/versioning/wrapper.py b/key-value/key-value-aio/src/key_value/aio/wrappers/versioning/wrapper.py index c8a5977c..21cc887e 100644 --- a/key-value/key-value-aio/src/key_value/aio/wrappers/versioning/wrapper.py +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/versioning/wrapper.py @@ -61,8 +61,8 @@ def __init__( def _wrap_value(self, value: dict[str, Any]) -> dict[str, Any]: """Wrap a value with version information.""" - # If already versioned, don't double-wrap - if _VERSION_KEY in value: + # If already properly versioned, don't double-wrap + if _VERSION_KEY in value and _VERSIONED_DATA_KEY in value: return value return {_VERSION_KEY: self.version, _VERSIONED_DATA_KEY: value} @@ -81,8 +81,11 @@ def _unwrap_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None: # Version mismatch - auto-invalidate by returning None return None - # Extract the actual data - return value.get(_VERSIONED_DATA_KEY, value) + # Extract the actual data (must be present in properly wrapped data) + if _VERSIONED_DATA_KEY not in value: + # Malformed versioned data - treat as corruption + return None + return value[_VERSIONED_DATA_KEY] @override async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None: @@ -104,7 +107,8 @@ async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[st @override async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]: results = await self.key_value.ttl_many(keys=keys, collection=collection) - return [(self._unwrap_value(value), ttl if self._unwrap_value(value) is not None else None) for value, ttl in results] + unwrapped = [(self._unwrap_value(value), ttl) for value, ttl in results] + return [(value, ttl if value is not None else None) for value, ttl in unwrapped] @override async def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None: diff --git a/key-value/key-value-aio/tests/stores/wrappers/test_circuit_breaker.py b/key-value/key-value-aio/tests/stores/wrappers/test_circuit_breaker.py index 4f99f847..95e872ba 100644 --- a/key-value/key-value-aio/tests/stores/wrappers/test_circuit_breaker.py +++ b/key-value/key-value-aio/tests/stores/wrappers/test_circuit_breaker.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from key_value.shared.errors.wrappers.circuit_breaker import CircuitOpenError from typing_extensions import override @@ -16,13 +18,25 @@ def __init__(self, failures_before_success: int = 5): self.failures_before_success = failures_before_success self.attempt_count = 0 - async def get(self, key: str, *, collection: str | None = None): + def _check_and_maybe_fail(self): + """Check if we should fail this operation.""" self.attempt_count += 1 if self.attempt_count <= self.failures_before_success: msg = "Simulated connection error" raise ConnectionError(msg) + + async def get(self, key: str, *, collection: str | None = None): + self._check_and_maybe_fail() return await super().get(key=key, collection=collection) + async def put(self, key: str, value: dict[str, Any], *, collection: str | None = None, ttl: float | None = None): + self._check_and_maybe_fail() + return await super().put(key=key, value=value, collection=collection, ttl=ttl) + + async def delete(self, key: str, *, collection: str | None = None): + self._check_and_maybe_fail() + return await super().delete(key=key, collection=collection) + def reset_attempts(self): self.attempt_count = 0 @@ -101,7 +115,7 @@ async def test_circuit_closes_after_successful_recovery(self, memory_store: Memo ) # Store a value first (this will succeed after 3 failures) - await memory_store.put(collection="test", key="test", value={"test": "value"}) + await failing_store.put(collection="test", key="test", value={"test": "value"}) # Open the circuit with 3 failures for _ in range(3): diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/versioning/wrapper.py b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/versioning/wrapper.py index 9324a07a..16e45dd1 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/versioning/wrapper.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/versioning/wrapper.py @@ -60,8 +60,8 @@ def __init__(self, key_value: KeyValue, version: str | int) -> None: def _wrap_value(self, value: dict[str, Any]) -> dict[str, Any]: """Wrap a value with version information.""" - # If already versioned, don't double-wrap - if _VERSION_KEY in value: + # If already properly versioned, don't double-wrap + if _VERSION_KEY in value and _VERSIONED_DATA_KEY in value: return value return {_VERSION_KEY: self.version, _VERSIONED_DATA_KEY: value} @@ -80,8 +80,11 @@ def _unwrap_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None: # Version mismatch - auto-invalidate by returning None return None - # Extract the actual data - return value.get(_VERSIONED_DATA_KEY, value) + # Extract the actual data (must be present in properly wrapped data) + if _VERSIONED_DATA_KEY not in value: + # Malformed versioned data - treat as corruption + return None + return value[_VERSIONED_DATA_KEY] @override def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None: @@ -103,7 +106,8 @@ def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any @override def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]: results = self.key_value.ttl_many(keys=keys, collection=collection) - return [(self._unwrap_value(value), ttl if self._unwrap_value(value) is not None else None) for (value, ttl) in results] + unwrapped = [(self._unwrap_value(value), ttl) for (value, ttl) in results] + return [(value, ttl if value is not None else None) for (value, ttl) in unwrapped] @override def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None: