diff --git a/src/utils/web_search/.dockerignore b/src/utils/web_search/.dockerignore index 8cc8613c..ac4f32d4 100644 --- a/src/utils/web_search/.dockerignore +++ b/src/utils/web_search/.dockerignore @@ -3,6 +3,7 @@ !app.py !db.py !auth.py +!daily_usage.py !requirements-app.txt !requirements_app.in !Dockerfile diff --git a/src/utils/web_search/.env.example b/src/utils/web_search/.env.example index b8c232ac..b25766c7 100644 --- a/src/utils/web_search/.env.example +++ b/src/utils/web_search/.env.example @@ -3,8 +3,23 @@ FIRESTORE_PROJECT_ID=*** FIRESTORE_DATABASE_NAME=*** FIRESTORE_COLLECTION=apiKeys -PBKDF2_ITERATIONS=20000 -PBKDF2_SALT_BYTES=16 +GEMINI_MAX_ATTEMPTS=1 +GEMINI_MAX_BACKOFF_SECONDS=2 + +API_KEY_PBKDF2_ITERATIONS=20000 +API_KEY_PBKDF2_SALT_BYTES=16 API_KEY_CACHE_TTL=30 API_KEY_CACHE_MAX_ITEMS=1024 + +API_KEY_USAGE_MAX_RETRIES=8 +API_KEY_USAGE_BASE_DELAY=0.05 +API_KEY_USAGE_MAX_DELAY=1.0 + +DAILY_USAGE_COLLECTION=dailyUsageCounters +DAILY_USAGE_MAX_RETRIES=8 +DAILY_USAGE_BASE_DELAY=0.05 +DAILY_USAGE_MAX_DELAY=1.0 + +GEMINI_GROUNDING_FREE_LIMIT_PRO=1500 +GEMINI_GROUNDING_FREE_LIMIT_FLASH=1500 diff --git a/src/utils/web_search/Dockerfile b/src/utils/web_search/Dockerfile index a77935ba..5f3e4faf 100644 --- a/src/utils/web_search/Dockerfile +++ b/src/utils/web_search/Dockerfile @@ -7,7 +7,7 @@ RUN pip install --no-cache-dir -r requirements-app.txt RUN mkdir -p /app/src/utils/web_search RUN touch /app/src/utils/__init__.py -COPY __init__.py app.py auth.py db.py /app/src/utils/web_search/ +COPY __init__.py app.py auth.py db.py daily_usage.py /app/src/utils/web_search/ ENV PYTHONPATH=/app/src CMD ["uvicorn", "utils.web_search.app:app", "--host", "0.0.0.0", "--port", "8080"] diff --git a/src/utils/web_search/README.md b/src/utils/web_search/README.md index fa403862..90bb5e8d 100644 --- a/src/utils/web_search/README.md +++ b/src/utils/web_search/README.md @@ -49,6 +49,10 @@ gcloud auth configure-docker "$REGION-docker.pkg.dev" | `GEMINI_API_KEY` | Gemini API key 
used by the proxy | _(required)_ | | `GEMINI_MAX_ATTEMPTS`, `GEMINI_MAX_BACKOFF_SECONDS` | Retry tuning | `5`, `10` | | `API_KEY_CACHE_TTL`, `API_KEY_CACHE_MAX_ITEMS` | Auth cache tuning | `30`, `1024` | +| `DAILY_USAGE_COLLECTION` | Collection that stores per-day usage counters | `dailyUsageCounters` | +| `DAILY_USAGE_MAX_RETRIES`, `DAILY_USAGE_BASE_DELAY`, `DAILY_USAGE_MAX_DELAY` | Daily usage retry tuning | `8`, `0.05`, `1.0` | +| `GEMINI_GROUNDING_FREE_LIMIT_PRO` | Daily free allowance for `gemini-2.5-pro` | `1500` | +| `GEMINI_GROUNDING_FREE_LIMIT_FLASH` | Shared daily free allowance for Flash/Flash-Lite | `1500` | Keep `.env.example` up to date so teammates can copy it into their own `.env`. @@ -81,6 +85,8 @@ Keep `.env.example` up to date so teammates can copy it into their own `.env`. export FIRESTORE_DATABASE_NAME=grounding export FIRESTORE_EMULATOR_HOST=0.0.0.0:8922 export GEMINI_API_KEY="dev-placeholder" + export GEMINI_GROUNDING_FREE_LIMIT_PRO=1500 + export GEMINI_GROUNDING_FREE_LIMIT_FLASH=1500 ``` 4. 
**Install Python dependencies** diff --git a/src/utils/web_search/app.py b/src/utils/web_search/app.py index 63c83162..a6b8cdfa 100644 --- a/src/utils/web_search/app.py +++ b/src/utils/web_search/app.py @@ -21,6 +21,7 @@ InactiveAPIKeyError, InvalidAPIKeyError, ) +from .daily_usage import DailyUsageRepository from .db import APIKeyRecord, APIKeyRepository, UsageLimitExceededError @@ -37,6 +38,58 @@ FIRESTORE_COLLECTION = os.getenv("FIRESTORE_COLLECTION", "apiKeys") API_KEY_CACHE_TTL = int(os.getenv("API_KEY_CACHE_TTL", "30")) API_KEY_CACHE_MAX_ITEMS = int(os.getenv("API_KEY_CACHE_MAX_ITEMS", "1024")) +FREE_LIMIT_DEFAULT_PRO = 1500 +FREE_LIMIT_DEFAULT_FLASH = 1500 + + +def _parse_free_limit(env_var: str, default: int) -> int: + """Parse a non-negative integer from the environment with logging.""" + value = os.getenv(env_var) + if value is None or value == "": + return default + try: + parsed = int(value) + except ValueError: + logger.warning( + "Invalid value '%s' for %s; falling back to %d", + value, + env_var, + default, + ) + return default + if parsed < 0: + logger.warning( + "Negative value '%s' for %s; treating as 0", + value, + env_var, + ) + return 0 + return parsed + + +MODEL_TO_USAGE_BUCKET: dict[str, str] = { + "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-2.5-flash": "gemini-2.5-flash-family", + "gemini-2.5-flash-lite": "gemini-2.5-flash-family", +} + +BUCKET_FREE_LIMITS: dict[str, int] = { + "gemini-2.5-pro": _parse_free_limit( + "GEMINI_GROUNDING_FREE_LIMIT_PRO", + FREE_LIMIT_DEFAULT_PRO, + ), + "gemini-2.5-flash-family": _parse_free_limit( + "GEMINI_GROUNDING_FREE_LIMIT_FLASH", + FREE_LIMIT_DEFAULT_FLASH, + ), +} + + +def _resolve_usage_bucket(model: str) -> tuple[str, int]: + """Return the usage bucket and free allowance for the given model.""" + bucket = MODEL_TO_USAGE_BUCKET.get(model, model) + return bucket, BUCKET_FREE_LIMITS.get(bucket, 0) + RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] 
= ( google_exceptions.ResourceExhausted, @@ -189,6 +242,13 @@ async def startup_event() -> None: cache_ttl_seconds=API_KEY_CACHE_TTL, cache_max_items=API_KEY_CACHE_MAX_ITEMS, ) + app.state.daily_usage_repository = DailyUsageRepository( + firestore_client, + collection_name=os.getenv( + "DAILY_USAGE_COLLECTION", + "dailyUsageCounters", + ), + ) async def shutdown_event() -> None: @@ -235,6 +295,18 @@ def get_authenticator() -> APIKeyAuthenticator: return authenticator +def get_daily_usage_repository() -> DailyUsageRepository: + """Return the daily usage repository stored on the app state.""" + repository: DailyUsageRepository | None = getattr( + app.state, + "daily_usage_repository", + None, + ) + if repository is None: + raise RuntimeError("Daily usage repository has not been initialised") + return repository + + async def _authenticate_request( api_key_header: str, authenticator: APIKeyAuthenticator, @@ -269,37 +341,6 @@ async def _authenticate_request( ) from exc -async def require_api_key( - api_key_header: Annotated[str, Header(alias="X-API-Key")], - authenticator: Annotated[APIKeyAuthenticator, Depends(get_authenticator)], -) -> APIKeyRecord: - """Validate the user's API key and reserve a usage slot. - - Parameters - ---------- - api_key_header : str - API key supplied in the ``X-API-Key`` header. - authenticator : APIKeyAuthenticator - Authenticator responsible for validating and reserving usage. - - Returns - ------- - APIKeyRecord - Updated API key record that includes the latest usage counter. - - Raises - ------ - HTTPException - Raised when the API key is invalid, inactive, or has exhausted its - quota. 
- """ - return await _authenticate_request( - api_key_header, - authenticator, - consume_usage=True, - ) - - async def require_api_key_without_consumption( api_key_header: Annotated[str, Header(alias="X-API-Key")], authenticator: Annotated[APIKeyAuthenticator, Depends(get_authenticator)], @@ -405,7 +446,9 @@ async def call_gemini_with_retry(request: RequestBody) -> types.GenerateContentR ) except RETRYABLE_EXCEPTIONS as exc: if attempt >= MAX_GEMINI_ATTEMPTS: - logger.exception("Gemini request failed after retries") + logger.exception( + "Gemini request failed after %d retries", MAX_GEMINI_ATTEMPTS + ) raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail="Gemini is currently unavailable", @@ -444,7 +487,15 @@ async def health() -> dict[str, str]: @router.post("/v1/grounding_with_search") async def search( request: RequestBody, - _: Annotated[APIKeyRecord, Depends(require_api_key)], + record: Annotated[ + APIKeyRecord, + Depends(require_api_key_without_consumption), + ], + authenticator: Annotated[APIKeyAuthenticator, Depends(get_authenticator)], + daily_usage: Annotated[ + DailyUsageRepository, + Depends(get_daily_usage_repository), + ], ) -> dict[str, object]: """Proxy Gemini grounding requests with quota enforcement. @@ -452,17 +503,77 @@ async def search( ---------- request : RequestBody Payload describing the Gemini call. - _ : APIKeyRecord - API key record produced by ``require_api_key``. The underscore keeps - the dependency explicit without exposing it to callers. + record : APIKeyRecord + API key record produced by ``require_api_key_without_consumption``. + authenticator : APIKeyAuthenticator + Authenticator used to consume API-key usage and roll it back on error. Returns ------- - google.genai.types.GenerateContentResponse - Response returned by the Gemini model. + dict of str to object + JSON serialisable response returned by the Gemini model. 
""" - response = await call_gemini_with_retry(request) - logger.info("Gemini request completed for model %s", request.model) + bucket, free_limit = _resolve_usage_bucket(request.model) + consumed_api_quota = False + reservation = await daily_usage.reserve(bucket, free_limit) + + if not reservation.consumed_free: + try: + updated_record = await authenticator.consume_usage(record.lookup_hash) + except UsageLimitExceededError as exc: + await daily_usage.release(reservation) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="API key usage limit exceeded", + ) from exc + except InvalidAPIKeyError as exc: + await daily_usage.release(reservation) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key provided", + ) from exc + except InactiveAPIKeyError as exc: + await daily_usage.release(reservation) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="API key is inactive", + ) from exc + except ExpiredAPIKeyError as exc: + await daily_usage.release(reservation) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="API key has expired", + ) from exc + + record = updated_record + consumed_api_quota = True + + try: + response = await call_gemini_with_retry(request) + except Exception: + try: + await daily_usage.release(reservation) + except Exception: # pragma: no cover - defensive logging for rollbacks + logger.exception( + "Failed to roll back daily usage for bucket %s", + bucket, + ) + + if consumed_api_quota: + try: + await authenticator.release_usage(record.lookup_hash) + except Exception: # pragma: no cover - defensive logging for rollbacks + logger.exception( + "Failed to roll back usage for API key %s", record.lookup_hash + ) + raise + + logger.info( + "Gemini request completed for model %s (bucket=%s, consumed_free=%s)", + request.model, + bucket, + reservation.consumed_free if reservation else False, + ) return response.to_json_dict() diff --git 
a/src/utils/web_search/auth.py b/src/utils/web_search/auth.py index 4fd523ed..021c440d 100644 --- a/src/utils/web_search/auth.py +++ b/src/utils/web_search/auth.py @@ -286,6 +286,58 @@ async def reserve_usage( return record + async def consume_usage(self, lookup_hash: str) -> APIKeyRecord: + """Increment usage counter for a previously validated API key.""" + record = self._cache_lookup(lookup_hash) + + if not record: + try: + record = await self._repository.get_api_key(lookup_hash) + except APIKeyNotFoundError as exc: + raise InvalidAPIKeyError("API key not recognised") from exc + self._cache_store(record) + + if record.status != "active": + raise InactiveAPIKeyError("API key has been suspended") + + if record.expires_at and self._clock() >= record.expires_at: + self._cache.pop(lookup_hash, None) + raise ExpiredAPIKeyError("API key has expired") + + try: + updated_record = await self._repository.update_usage_counter(lookup_hash) + except APIKeyNotFoundError as exc: + self._cache.pop(lookup_hash, None) + raise InvalidAPIKeyError("API key not recognised") from exc + + self._cache_store(updated_record) + return updated_record + + async def release_usage(self, lookup_hash: str) -> APIKeyRecord: + """Rollback a previously reserved usage slot. + + Parameters + ---------- + lookup_hash : str + Lookup hash corresponding to the API key whose usage should be + decremented. + + Returns + ------- + APIKeyRecord + Updated record containing the decremented usage counter. 
+ """ + try: + updated_record = await self._repository.decrement_usage_counter( + lookup_hash, + ) + except APIKeyNotFoundError as exc: # pragma: no cover - defensive branch + self._cache.pop(lookup_hash, None) + raise InvalidAPIKeyError("API key not recognised") from exc + + self._cache_store(updated_record) + return updated_record + async def create_api_key( self, *, diff --git a/src/utils/web_search/daily_usage.py b/src/utils/web_search/daily_usage.py new file mode 100644 index 00000000..0195719f --- /dev/null +++ b/src/utils/web_search/daily_usage.py @@ -0,0 +1,193 @@ +"""Track daily usage for Gemini models to account for free-tier allowances.""" + +import asyncio +import os +import random +from dataclasses import dataclass +from datetime import date, datetime, timezone +from typing import Any, Callable, Optional + + +try: + from google.api_core.exceptions import Aborted + from google.cloud.firestore_v1 import ( + SERVER_TIMESTAMP, + AsyncClient, + AsyncDocumentReference, + AsyncTransaction, + async_transactional, + ) +except ImportError: # pragma: no cover - imported dynamically in production + Aborted = RuntimeError # type: ignore + AsyncClient = Any # type: ignore + AsyncDocumentReference = Any # type: ignore + AsyncTransaction = Any # type: ignore + SERVER_TIMESTAMP = None # type: ignore + + def async_transactional(func): # type: ignore + """Passthrough decorator used when Firestore is not installed.""" + return func + + +DAILY_USAGE_COLLECTION = os.getenv("DAILY_USAGE_COLLECTION", "dailyUsageCounters") +DAILY_USAGE_MAX_RETRIES = int(os.getenv("DAILY_USAGE_MAX_RETRIES", "8")) +DAILY_USAGE_BASE_DELAY = float(os.getenv("DAILY_USAGE_BASE_DELAY", "0.05")) +DAILY_USAGE_MAX_DELAY = float(os.getenv("DAILY_USAGE_MAX_DELAY", "1.0")) + + +def _now() -> datetime: + """Return the current UTC timestamp.""" + return datetime.now(tz=timezone.utc) + + +def _retry_delay(attempt: int) -> float: + """Calculate a jittered exponential backoff for retries.""" + delay = 
DAILY_USAGE_BASE_DELAY * (2**attempt) + capped = min(delay, DAILY_USAGE_MAX_DELAY) + if capped <= 0: + return 0.0 + jitter = random.uniform(0, capped / 2) + return capped + jitter + + +@dataclass(slots=True) +class UsageReservation: + """Represents a single reserved request in the daily counter.""" + + bucket: str + day: date + consumed_free: bool + + +def _ensure_utc(value: Optional[datetime]) -> Optional[datetime]: + """Return a timezone-aware UTC datetime.""" + if value is None: + return None + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + + +class DailyUsageRepository: + """Persist and update daily usage counters for Gemini buckets.""" + + def __init__( + self, + client: AsyncClient, + *, + collection_name: str = DAILY_USAGE_COLLECTION, + clock: Callable[[], datetime] = _now, + ) -> None: + """Initialise the repository with a Firestore client.""" + self._client = client + self._collection = collection_name + self._clock = clock + + def _document(self, bucket: str, day: date) -> AsyncDocumentReference: + """Return the document reference for the given bucket/day pair.""" + identifier = f"{bucket}:{day.isoformat()}" + return self._client.collection(self._collection).document(identifier) + + async def reserve(self, bucket: str, free_limit: int) -> UsageReservation: + """Reserve a usage slot for the given bucket. + + Parameters + ---------- + bucket : str + Logical identifier grouping the models that share a free allowance. + free_limit : int + Daily number of free requests for this bucket. ``0`` disables the + free tier so every call should fall through to API-key accounting. + + Returns + ------- + UsageReservation + Reservation describing whether the free allowance was consumed. 
+ """ + free_limit = max(free_limit, 0) + today = self._clock().date() + doc_ref = self._document(bucket, today) + + @async_transactional + async def _increment( + transaction: AsyncTransaction, + reference: AsyncDocumentReference, + ) -> UsageReservation: + snapshot = await reference.get(transaction=transaction) + current_total = 0 + if snapshot.exists: + data: dict[str, Any] = snapshot.to_dict() or {} + current_total = int(data.get("total_count", 0)) + + transaction.update( + reference, + { + "total_count": current_total + 1, + "updated_at": SERVER_TIMESTAMP or _ensure_utc(self._clock()), + }, + ) + else: + transaction.set( + reference, + { + "bucket": bucket, + "date": today.isoformat(), + "total_count": 1, + "created_at": SERVER_TIMESTAMP or _ensure_utc(self._clock()), + "updated_at": SERVER_TIMESTAMP or _ensure_utc(self._clock()), + }, + ) + + consumed_free = free_limit > 0 and current_total < free_limit + + return UsageReservation( + bucket=bucket, day=today, consumed_free=consumed_free + ) + + attempts = 0 + while True: + try: + transaction = self._client.transaction() + return await _increment(transaction, doc_ref) + except (Aborted, ValueError): + if attempts >= DAILY_USAGE_MAX_RETRIES - 1: + raise + await asyncio.sleep(_retry_delay(attempts)) + attempts += 1 + + async def release(self, reservation: UsageReservation) -> None: + """Rollback a reservation when the downstream call fails.""" + doc_ref = self._document(reservation.bucket, reservation.day) + + @async_transactional + async def _decrement( + transaction: AsyncTransaction, + reference: AsyncDocumentReference, + ) -> None: + snapshot = await reference.get(transaction=transaction) + if not snapshot.exists: + return + + data: dict[str, Any] = snapshot.to_dict() or {} + current_total = int(data.get("total_count", 0)) + new_total = max(current_total - 1, 0) + + transaction.update( + reference, + { + "total_count": new_total, + "updated_at": SERVER_TIMESTAMP or _ensure_utc(self._clock()), + }, + ) + + 
attempts = 0 + while True: + try: + transaction = self._client.transaction() + await _decrement(transaction, doc_ref) + return + except (Aborted, ValueError): + if attempts >= DAILY_USAGE_MAX_RETRIES - 1: + raise + await asyncio.sleep(_retry_delay(attempts)) + attempts += 1 diff --git a/src/utils/web_search/db.py b/src/utils/web_search/db.py index 11b65ff3..dd9214fa 100644 --- a/src/utils/web_search/db.py +++ b/src/utils/web_search/db.py @@ -1,11 +1,15 @@ """Helpers for interacting with the Firestore-backed API key store.""" +import asyncio +import os +import random from dataclasses import dataclass from datetime import datetime, timezone from typing import Any, Literal, Optional try: + from google.api_core.exceptions import Aborted from google.cloud.firestore_v1 import ( SERVER_TIMESTAMP, AsyncClient, @@ -15,6 +19,7 @@ async_transactional, ) except ImportError: # pragma: no cover - imported at runtime in production + Aborted = RuntimeError # type: ignore AsyncClient = Any # type: ignore AsyncDocumentReference = Any # type: ignore DocumentSnapshot = Any # type: ignore @@ -26,6 +31,21 @@ def async_transactional(func): # type: ignore return func +USAGE_TRANSACTION_MAX_RETRIES = int(os.getenv("API_KEY_USAGE_MAX_RETRIES", "8")) +USAGE_TRANSACTION_BASE_DELAY = float(os.getenv("API_KEY_USAGE_BASE_DELAY", "0.05")) +USAGE_TRANSACTION_MAX_DELAY = float(os.getenv("API_KEY_USAGE_MAX_DELAY", "1.0")) + + +def _usage_retry_delay(attempt: int) -> float: + """Calculate a jittered backoff delay for Firestore retries.""" + base_delay = USAGE_TRANSACTION_BASE_DELAY * (2**attempt) + capped_delay = min(base_delay, USAGE_TRANSACTION_MAX_DELAY) + if capped_delay <= 0: + return 0.0 + jitter = random.uniform(0, capped_delay / 2) + return capped_delay + jitter + + class APIKeyNotFoundError(Exception): """Raised when an API key document cannot be found.""" @@ -307,7 +327,75 @@ async def _increment( expires_at=_ensure_timezone(data.get("expires_at")), ) - return await 
_increment(self._client.transaction(), doc_ref) + attempts = 0 + while True: + try: + return await _increment(self._client.transaction(), doc_ref) + except (Aborted, ValueError): + if attempts >= USAGE_TRANSACTION_MAX_RETRIES - 1: + raise + await asyncio.sleep(_usage_retry_delay(attempts)) + attempts += 1 + + async def decrement_usage_counter(self, lookup_hash: str) -> APIKeyRecord: + """Rollback the usage counter when a request ultimately fails. + + Parameters + ---------- + lookup_hash : str + SHA-256 digest corresponding to the key to update. + + Returns + ------- + APIKeyRecord + The API key record containing the updated usage counter. + """ + doc_ref = self._document(lookup_hash) + + @async_transactional + async def _decrement( + transaction: AsyncTransaction, + reference: AsyncDocumentReference, + ) -> APIKeyRecord: + snapshot = await reference.get(transaction=transaction) + if not snapshot.exists: + raise APIKeyNotFoundError(lookup_hash) + + data: dict[str, Any] = snapshot.to_dict() or {} + usage_count = max(int(data.get("usage_count", 0)) - 1, 0) + + transaction.update( + reference, + {"usage_count": usage_count}, + ) + + return APIKeyRecord( + lookup_hash=lookup_hash, + hashed_key=data["hashed_key"], + salt=data["salt"], + display_prefix=data.get("display_prefix", ""), + role=data.get("role", "user"), + owner=data.get("owner"), + status=data.get("status", "active"), + usage_count=usage_count, + usage_limit=int(data.get("usage_limit", 0)), + last_used_at=_ensure_timezone(data.get("last_used_at")), + created_at=_ensure_timezone(data.get("created_at")) + or datetime.now(tz=timezone.utc), + created_by=data.get("created_by", "system"), + metadata=data.get("metadata", {}), + expires_at=_ensure_timezone(data.get("expires_at")), + ) + + attempts = 0 + while True: + try: + return await _decrement(self._client.transaction(), doc_ref) + except (Aborted, ValueError): + if attempts >= USAGE_TRANSACTION_MAX_RETRIES - 1: + raise + await 
asyncio.sleep(_usage_retry_delay(attempts)) + attempts += 1 async def set_status(self, lookup_hash: str, status: Status) -> None: """Update the ``status`` field for an API key record. diff --git a/tests/test_web_search_auth.py b/tests/test_web_search_auth.py index 31fc1ea0..7afb4b40 100644 --- a/tests/test_web_search_auth.py +++ b/tests/test_web_search_auth.py @@ -66,6 +66,17 @@ async def update_usage_counter(self, lookup_hash: str) -> APIKeyRecord: self.records[lookup_hash] = updated return updated + async def decrement_usage_counter(self, lookup_hash: str) -> APIKeyRecord: + """Decrement usage counter.""" + if lookup_hash not in self.records: + raise APIKeyNotFoundError(lookup_hash) + + record = self.records[lookup_hash] + new_count = max(record.usage_count - 1, 0) + updated = replace(record, usage_count=new_count) + self.records[lookup_hash] = updated + return updated + async def set_status(self, lookup_hash: str, status: Status) -> None: """Set status.""" record = self.records[lookup_hash] @@ -110,6 +121,25 @@ async def test_reserve_usage_increments_counter() -> None: assert repository.records[record.lookup_hash].usage_count == 1 +@pytest.mark.asyncio +async def test_release_usage_rolls_back_counter() -> None: + """Ensure usage reservations can be rolled back after failures.""" + repository = FakeRepository() + authenticator = auth.APIKeyAuthenticator(repository, clock=fixed_clock) + + api_key, record = await authenticator.create_api_key( + role="user", + owner="owner-rollback", + usage_limit=5, + created_by="admin", + ) + + await authenticator.reserve_usage(api_key) + await authenticator.release_usage(record.lookup_hash) + + assert repository.records[record.lookup_hash].usage_count == 0 + + @pytest.mark.asyncio async def test_reserve_usage_without_consuming() -> None: """Ensure validation can occur without incrementing the counter.""" @@ -132,6 +162,70 @@ async def test_reserve_usage_without_consuming() -> None: assert 
repository.records[record.lookup_hash].usage_count == 0 +@pytest.mark.asyncio +async def test_consume_usage_increments_counter_after_validation() -> None: + """Ensure usage can be consumed after a non-consuming validation.""" + repository = FakeRepository() + authenticator = auth.APIKeyAuthenticator(repository, clock=fixed_clock) + + api_key, record = await authenticator.create_api_key( + role="user", + owner="owner-consume", + usage_limit=5, + created_by="admin", + ) + + await authenticator.reserve_usage(api_key, consume_usage=False) + updated_record = await authenticator.consume_usage(record.lookup_hash) + + assert updated_record.usage_count == 1 + assert repository.records[record.lookup_hash].usage_count == 1 + + +@pytest.mark.asyncio +async def test_consume_usage_respects_usage_limit() -> None: + """Ensure ``consume_usage`` propagates usage limit errors.""" + repository = FakeRepository() + authenticator = auth.APIKeyAuthenticator(repository, clock=fixed_clock) + + api_key, record = await authenticator.create_api_key( + role="user", + owner="owner-limit", + usage_limit=1, + created_by="admin", + ) + + await authenticator.reserve_usage(api_key, consume_usage=False) + repository.records[record.lookup_hash] = replace( + repository.records[record.lookup_hash], + usage_count=1, + ) + + with pytest.raises(UsageLimitExceededError): + await authenticator.consume_usage(record.lookup_hash) + + +@pytest.mark.asyncio +async def test_consume_usage_rejects_inactive_records() -> None: + """Ensure suspended keys cannot be consumed after validation.""" + repository = FakeRepository() + authenticator = auth.APIKeyAuthenticator(repository, clock=fixed_clock) + + api_key, record = await authenticator.create_api_key( + role="user", + owner="owner-inactive", + usage_limit=5, + created_by="admin", + ) + + await authenticator.reserve_usage(api_key, consume_usage=False) + await repository.set_status(record.lookup_hash, "suspended") + authenticator._cache.pop(record.lookup_hash, None) + 
+ with pytest.raises(auth.InactiveAPIKeyError): + await authenticator.consume_usage(record.lookup_hash) + + @pytest.mark.asyncio async def test_reserve_usage_rejects_invalid_key() -> None: """Ensure invalid API keys raise ``InvalidAPIKeyError``."""