diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..7879ad0 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,92 @@ +# AGENTS.md + +## Local Workflow + +- Install/sync dependencies: `uv sync` +- Run all tests: `uv run pytest` +- Run a focused test file: `uv run pytest tests/path/to/test_file.py` +- Run lint checks: `uv run ruff check .` +- Auto-fix lint issues where possible: `uv run ruff check --fix .` +- Format code: `uv run ruff format .` + +## Before Opening a PR + +- Run `uv run ruff check .` +- Run `uv run ruff format --check .` (or `uv run ruff format .`) +- Run relevant tests for changed areas, then run `uv run pytest` if changes are broad + +## Commit Message Guidance + +- Keep commit messages short and specific. +- Use a title line of 50 characters or fewer. +- Wrap commit message body lines at 72 characters. +- Explain what changed and why. +- Do not list file-by-file changes that are obvious from the diff. +- Do not include any `Co-authored-by:` line. + +### Good examples + +- `Add shared HTTP transport helpers` +- `Move iter_coroutine to a dedicated module` +- `Fix async request hook header handling` + +## Iter-Coroutine + Base/Runtime Migration Pattern + +Use this as the default shape when refactoring sync+async modules to reduce +duplication. + +### Core principles + +- Keep public API stable: same exported names, signatures, return types, and + behavior. +- Make the internal core async-first. +- Make sync entrypoints thin wrappers over async core via `iter_coroutine(...)` + only when the wrapped coroutine is non-suspending in sync mode. +- Route HTTP through `vercel._http` clients/transports; avoid direct + `httpx.Client`/`httpx.AsyncClient` construction in refactored feature modules. + +### Recommended structure + +- Create a private async base class for shared logic: + - Example shape: `_BaseClient` with async methods for shared ops. + - Keep parsing/validation/result-shaping helpers in this layer. 
+- Add private sync/async concrete classes: + - Sync uses `SyncTransport(...)` and sync callbacks. + - Async uses `AsyncTransport(...)` and awaitable callbacks. +- Keep public sync functions as wrappers that call + `iter_coroutine(base_client.async_method(...))`. +- Keep public async functions as direct `await` on the same async methods. + +### Base + runtime split (for true runtime-specific behavior) + +When sync and async must differ materially (threading vs asyncio scheduling), +use a runtime contract and shared orchestration: + +- Define one runtime method name (for example, `upload(...)`). +- Implement two runtimes: + - blocking runtime: threadpool/locks/sync callback handling. + - async runtime: `asyncio.create_task`/`asyncio.wait`/awaitable callbacks. +- Keep common orchestration shared: + - validation + - chunk/part iteration helpers + - normalization/order of results + - final response shaping + +### `iter_coroutine` guardrails + +- Safe: sync wrappers around coroutines that complete without real suspension. +- Unsafe: coroutines that rely on event-loop scheduling (network awaits, + `asyncio.sleep`, task scheduling, etc.). +- For mixed callback paths, use explicit `inspect.isawaitable(...)` checks in + async code rather than forcing everything through `iter_coroutine`. + +### Testing expectations for migrations + +- Prefer integration-style tests with `respx` that verify real request flow and + sync/async parity. +- Do not rely only on monkeypatch tests that assert internal call shape. 
+- Validate before commit: + - `uv run ruff check .` + - `uv run ruff format --check .` + - targeted tests for changed modules + - `uv run pytest` when changes are broad diff --git a/README.md b/README.md index ed9af73..d6dd730 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,10 @@ Notes: Requires `BLOB_READ_WRITE_TOKEN` to be set as an env var or `token` to be set when constructing a client +`BlobClient` and `AsyncBlobClient` keep a long-lived HTTP transport for the life of +the client instance. Prefer `with BlobClient(...)` / `async with AsyncBlobClient(...)` +or call `close()` / `aclose()` explicitly when done. + #### Sync @@ -160,19 +164,17 @@ Requires `BLOB_READ_WRITE_TOKEN` to be set as an env var or `token` to be set wh ```python from vercel.blob import BlobClient -client = BlobClient() -# or BlobClient(token="...") - -# Create a folder entry, upload a local file, list, then download -client.create_folder("examples/assets", overwrite=True) -uploaded = client.upload_file( - "./README.md", - "examples/assets/readme-copy.txt", - access="public", - content_type="text/plain", -) -listing = client.list_objects(prefix="examples/assets/") -client.download_file(uploaded.url, "/tmp/readme-copy.txt", overwrite=True) +with BlobClient() as client: # or BlobClient(token="...") + # Create a folder entry, upload a local file, list, then download + client.create_folder("examples/assets", overwrite=True) + uploaded = client.upload_file( + "./README.md", + "examples/assets/readme-copy.txt", + access="public", + content_type="text/plain", + ) + listing = client.list_objects(prefix="examples/assets/") + client.download_file(uploaded.url, "/tmp/readme-copy.txt", overwrite=True) ``` Async usage: @@ -182,21 +184,20 @@ import asyncio from vercel.blob import AsyncBlobClient async def main(): - client = AsyncBlobClient() # uses BLOB_READ_WRITE_TOKEN from env - - # Upload bytes - uploaded = await client.put( - "examples/assets/hello.txt", - b"hello from python", - access="public", 
- content_type="text/plain", - ) - - # Inspect metadata, list, download bytes, then delete - meta = await client.head(uploaded.url) - listing = await client.list_objects(prefix="examples/assets/") - content = await client.get(uploaded.url) - await client.delete([b.url for b in listing.blobs]) + async with AsyncBlobClient() as client: # uses BLOB_READ_WRITE_TOKEN from env + # Upload bytes + uploaded = await client.put( + "examples/assets/hello.txt", + b"hello from python", + access="public", + content_type="text/plain", + ) + + # Inspect metadata, list, download bytes, then delete + meta = await client.head(uploaded.url) + listing = await client.list_objects(prefix="examples/assets/") + content = await client.get(uploaded.url) + await client.delete([b.url for b in listing.blobs]) asyncio.run(main()) ``` @@ -206,18 +207,17 @@ Synchronous usage: ```python from vercel.blob import BlobClient -client = BlobClient() # or BlobClient(token="...") - -# Create a folder entry, upload a local file, list, then download -client.create_folder("examples/assets", overwrite=True) -uploaded = client.upload_file( - "./README.md", - "examples/assets/readme-copy.txt", - access="public", - content_type="text/plain", -) -listing = client.list_objects(prefix="examples/assets/") -client.download_file(uploaded.url, "/tmp/readme-copy.txt", overwrite=True) +with BlobClient() as client: # or BlobClient(token="...") + # Create a folder entry, upload a local file, list, then download + client.create_folder("examples/assets", overwrite=True) + uploaded = client.upload_file( + "./README.md", + "examples/assets/readme-copy.txt", + access="public", + content_type="text/plain", + ) + listing = client.list_objects(prefix="examples/assets/") + client.download_file(uploaded.url, "/tmp/readme-copy.txt", overwrite=True) ``` #### Multipart Uploads @@ -253,17 +253,20 @@ A middle-ground that provides a clean API while giving you control over parts an from vercel.blob import BlobClient, 
create_multipart_uploader # Create the uploader (initializes the upload) -client = BlobClient() -uploader = client.create_multipart_uploader("large-file.bin", content_type="application/octet-stream") +with BlobClient() as client: + uploader = client.create_multipart_uploader( + "large-file.bin", + content_type="application/octet-stream", + ) -# Upload parts (you control when and how) -parts = [] -for i, chunk in enumerate(chunks, start=1): - part = uploader.upload_part(i, chunk) - parts.append(part) + # Upload parts (you control when and how) + parts = [] + for i, chunk in enumerate(chunks, start=1): + part = uploader.upload_part(i, chunk) + parts.append(part) -# Complete the upload -result = uploader.complete(parts) + # Complete the upload + result = uploader.complete(parts) ``` Async version with concurrent uploads: @@ -271,15 +274,15 @@ Async version with concurrent uploads: ```python from vercel.blob import AsyncBlobClient, create_multipart_uploader_async -client = AsyncBlobClient() -uploader = await client.create_multipart_uploader("large-file.bin") +async with AsyncBlobClient() as client: + uploader = await client.create_multipart_uploader("large-file.bin") -# Upload parts concurrently -tasks = [uploader.upload_part(i, chunk) for i, chunk in enumerate(chunks, start=1)] -parts = await asyncio.gather(*tasks) + # Upload parts concurrently + tasks = [uploader.upload_part(i, chunk) for i, chunk in enumerate(chunks, start=1)] + parts = await asyncio.gather(*tasks) -# Complete -result = await uploader.complete(parts) + # Complete + result = await uploader.complete(parts) ``` The uploader pattern is ideal when you: @@ -347,4 +350,4 @@ uv run ruff format --check && uv run ruff check . 
&& uv run mypy src && uv run p ## License -MIT \ No newline at end of file +MIT diff --git a/examples/blob_storage_multipart.py b/examples/blob_storage_multipart.py index ae7ab26..e5740f8 100644 --- a/examples/blob_storage_multipart.py +++ b/examples/blob_storage_multipart.py @@ -202,7 +202,10 @@ def comparison_example(): print("Please set it to run these examples.") exit(1) + async def async_main(): + await async_example() + await async_with_file_example() + comparison_example() sync_example() - asyncio.run(async_example()) - asyncio.run(async_with_file_example()) + asyncio.run(async_main()) diff --git a/pyproject.toml b/pyproject.toml index 870b989..853da58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dev = [ "pytest>=7.0.0", "pytest-asyncio", "pytest-xdist", + "respx>=0.21.0", "mypy", "build", "twine", @@ -53,9 +54,14 @@ vercel = ["py.typed"] testpaths = ["tests"] addopts = "-q" asyncio_mode = "auto" +markers = [ + "live: requires live API credentials (VERCEL_TOKEN, BLOB_READ_WRITE_TOKEN, etc.)", +] [tool.mypy] ignore_missing_imports = true +mypy_path = "src" +explicit_package_bases = true [tool.ruff] line-length = 100 diff --git a/src/vercel/_internal/__init__.py b/src/vercel/_internal/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/vercel/blob/utils.py b/src/vercel/_internal/blob/__init__.py similarity index 93% rename from src/vercel/blob/utils.py rename to src/vercel/_internal/blob/__init__.py index 851ff1a..16b4622 100644 --- a/src/vercel/blob/utils.py +++ b/src/vercel/_internal/blob/__init__.py @@ -4,14 +4,33 @@ import os import time import uuid -from collections.abc import Awaitable, Callable, Iterable -from dataclasses import dataclass +from collections.abc import Callable, Iterable from datetime import datetime -from typing import Any, Literal, Protocol, TypedDict +from typing import Any, Protocol, TypedDict +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse -from .errors import BlobError, 
BlobNoTokenProvidedError -Access = Literal["public", "private"] +def get_download_url(blob_url: str) -> str: + try: + parsed = urlparse(blob_url) + q = dict(parse_qsl(parsed.query)) + q["download"] = "1" + new_query = urlencode(q) + return urlunparse( + ( + parsed.scheme, + parsed.netloc, + parsed.path, + parsed.params, + new_query, + parsed.fragment, + ) + ) + except Exception: + # Fallback: naive append + sep = "&" if "?" in blob_url else "?" + return f"{blob_url}{sep}download=1" + DEFAULT_VERCEL_BLOB_API_URL = "https://vercel.com/api/blob" MAXIMUM_PATHNAME_LENGTH = 950 @@ -100,6 +119,8 @@ def extract_store_id_from_token(token: str) -> str: def validate_path(path: str) -> None: + from vercel._internal.blob.errors import BlobError + if not path: raise BlobError("path is required") if len(path) > MAXIMUM_PATHNAME_LENGTH: @@ -109,13 +130,15 @@ def validate_path(path: str) -> None: raise BlobError(f'path cannot contain "{invalid}", please encode it if needed') -def validate_access(access: Access) -> Access: +def validate_access(access: str) -> str: + from vercel._internal.blob.errors import BlobError + if access not in ("public", "private"): raise BlobError('access must be "public" or "private"') return access -def construct_blob_url(store_id: str, pathname: str, access: Access) -> str: +def construct_blob_url(store_id: str, pathname: str, access: str) -> str: """Construct a blob storage URL based on access type. Public: https://{storeId}.public.blob.vercel-storage.com/{pathname} @@ -150,19 +173,6 @@ def compute_body_length(body: Any) -> int: return 0 -# Progress -@dataclass -class UploadProgressEvent: - loaded: int - total: int - percentage: float - - -OnUploadProgressCallback = ( - Callable[[UploadProgressEvent], None] | Callable[[UploadProgressEvent], Awaitable[None]] -) - - class SupportsRead(Protocol): def read(self, size: int = -1) -> bytes: # pragma: no cover - Protocol ... 
@@ -178,7 +188,7 @@ class StreamingBodyWithProgress: def __init__( self, body: bytes | bytearray | memoryview | str | SupportsRead | Iterable[bytes], - on_progress: OnUploadProgressCallback | None, + on_progress: Callable | None, chunk_size: int = 64 * 1024, total: int | None = None, ) -> None: @@ -229,6 +239,8 @@ def _yield_bytes(self, data: bytes) -> Iterable[bytes]: def _emit_progress(self) -> None: if self._on_progress: + from vercel._internal.blob.types import UploadProgressEvent + total = self._total if self._total else self._loaded percentage = round((self._loaded / total) * 100, 2) if total else 0.0 self._on_progress( @@ -237,6 +249,8 @@ def _emit_progress(self) -> None: async def _emit_progress_async(self) -> None: if self._on_progress: + from vercel._internal.blob.types import UploadProgressEvent + total = self._total if self._total else self._loaded percentage = round((self._loaded / total) * 100, 2) if total else 0.0 result = self._on_progress( @@ -313,30 +327,6 @@ def parse_datetime(value: str) -> datetime: return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%fZ") -def get_download_url(blob_url: str) -> str: - try: - from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse - - parsed = urlparse(blob_url) - q = dict(parse_qsl(parsed.query)) - q["download"] = "1" - new_query = urlencode(q) - return urlunparse( - ( - parsed.scheme, - parsed.netloc, - parsed.path, - parsed.params, - new_query, - parsed.fragment, - ) - ) - except Exception: - # Fallback: naive append - sep = "&" if "?" in blob_url else "?" - return f"{blob_url}{sep}download=1" - - # TypedDict with real HTTP header keys. Use functional syntax to allow hyphens. 
PutHeaders = TypedDict( "PutHeaders", @@ -356,7 +346,7 @@ def create_put_headers( add_random_suffix: bool | None = None, allow_overwrite: bool | None = None, cache_control_max_age: int | None = None, - access: Access | None = None, + access: str | None = None, ) -> PutHeaders: headers: PutHeaders = {} if content_type: @@ -373,6 +363,8 @@ def create_put_headers( def ensure_token(token: str | None) -> str: + from vercel._internal.blob.errors import BlobNoTokenProvidedError + token = token or os.getenv("BLOB_READ_WRITE_TOKEN") or os.getenv("VERCEL_BLOB_READ_WRITE_TOKEN") if not token: raise BlobNoTokenProvidedError() diff --git a/src/vercel/_internal/blob/core.py b/src/vercel/_internal/blob/core.py new file mode 100644 index 0000000..f1900ab --- /dev/null +++ b/src/vercel/_internal/blob/core.py @@ -0,0 +1,1248 @@ +from __future__ import annotations + +import asyncio +import inspect +import os +import time +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime +from typing import Any, Literal, cast +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse + +import httpx + +from vercel._internal.blob import ( + PutHeaders, + StreamingBodyWithProgress, + compute_body_length, + construct_blob_url, + create_put_headers, + debug, + ensure_token, + extract_store_id_from_token, + get_api_url, + get_api_version, + get_download_url, + get_proxy_through_alternative_api_header_from_env, + get_retries, + is_url, + make_request_id, + parse_datetime, + parse_rfc7231_retry_after, + should_use_x_content_length, + validate_access, + validate_path, +) +from vercel._internal.blob.errors import ( + BlobAccessError, + BlobClientTokenExpiredError, + BlobContentTypeNotAllowedError, + BlobError, + BlobFileTooLargeError, + BlobInvalidResponseJSONError, + BlobNotFoundError, + BlobPathnameMismatchError, + BlobServiceNotAvailable, + BlobServiceRateLimited, + 
BlobStoreNotFoundError, + BlobStoreSuspendedError, + BlobUnexpectedResponseContentTypeError, + BlobUnknownError, +) +from vercel._internal.blob.multipart import ( + DEFAULT_PART_SIZE, + MultipartClient, + MultipartUploadSession, + create_async_multipart_upload_runtime, + create_sync_multipart_upload_runtime, + order_uploaded_parts, + prepare_upload_headers, + shape_complete_upload_result, + validate_part_size, +) +from vercel._internal.blob.types import ( + Access, + CreateFolderResult as CreateFolderResultType, + GetBlobResult as GetBlobResultType, + HeadBlobResult as HeadBlobResultType, + ListBlobItem, + ListBlobResult as ListBlobResultType, + PutBlobResult as PutBlobResultType, + UploadProgressEvent, +) +from vercel._internal.http import ( + AsyncTransport, + BaseTransport, + JSONBody, + RawBody, + SyncTransport, + create_base_async_client, + create_base_client, +) +from vercel._internal.iter_coroutine import iter_coroutine +from vercel._internal.telemetry.tracker import track + +BlobProgressCallback = ( + Callable[[UploadProgressEvent], None] | Callable[[UploadProgressEvent], Awaitable[None]] +) +DownloadProgressCallback = ( + Callable[[int, int | None], None] | Callable[[int, int | None], Awaitable[None]] +) +SleepFn = Callable[[float], Awaitable[None] | None] +PUT_BODY_OBJECT_ERROR = ( + "Body must be a string, buffer or stream. " + "You sent a plain object, double check what you're trying to upload." 
+) + + +def _sync_sleep(seconds: float) -> None: + time.sleep(seconds) + + +async def _await_if_necessary(value: Any) -> Any: + if inspect.isawaitable(value): + return await cast(Awaitable[Any], value) + return value + + +def map_blob_error(response: httpx.Response) -> tuple[str, BlobError]: + try: + data = response.json() + except Exception: + data = {} + + code = (data.get("error") or {}).get("code") or "unknown_error" + message = (data.get("error") or {}).get("message") or "" + + if "contentType" in message and "is not allowed" in message: + code = "content_type_not_allowed" + if '"pathname"' in message and "does not match the token payload" in message: + code = "client_token_pathname_mismatch" + if message == "Token expired": + code = "client_token_expired" + if "the file length cannot be greater than" in message: + code = "file_too_large" + + if code == "store_suspended": + return code, BlobStoreSuspendedError() + if code == "forbidden": + return code, BlobAccessError() + if code == "content_type_not_allowed": + return code, BlobContentTypeNotAllowedError(message or "") + if code == "client_token_pathname_mismatch": + return code, BlobPathnameMismatchError(message or "") + if code == "client_token_expired": + return code, BlobClientTokenExpiredError() + if code == "file_too_large": + return code, BlobFileTooLargeError(message or "") + if code == "not_found": + return code, BlobNotFoundError() + if code == "store_not_found": + return code, BlobStoreNotFoundError() + if code == "bad_request": + return code, BlobError(message or "Bad request") + if code == "service_unavailable": + return code, BlobServiceNotAvailable() + if code == "rate_limited": + seconds = parse_rfc7231_retry_after(response.headers.get("retry-after")) + return code, BlobServiceRateLimited(seconds) + + return code, BlobUnknownError() + + +def should_retry(code: str) -> bool: + return code in {"unknown_error", "service_unavailable", "internal_server_error"} + + +def is_network_error(exc: 
Exception) -> bool: + return isinstance(exc, httpx.TransportError) + + +def decode_blob_response(response: httpx.Response) -> Any: + try: + return response.json() + except Exception: + return response.text + + +def _is_json_content_type(content_type: str) -> bool: + media_type = content_type.split(";", 1)[0].strip().lower() + return media_type == "application/json" or media_type.endswith("+json") + + +def decode_blob_response_json(response: httpx.Response) -> Any: + content_type = response.headers.get("content-type", "") + if not _is_json_content_type(content_type): + raise BlobUnexpectedResponseContentTypeError(content_type or None) + + try: + return response.json() + except Exception as exc: + raise BlobInvalidResponseJSONError() from exc + + +async def _emit_progress( + callback: BlobProgressCallback | None, + event: UploadProgressEvent, + *, + await_callback: bool, +) -> None: + if callback is None: + return + + result = callback(event) + if await_callback and inspect.isawaitable(result): + await cast(Awaitable[None], result) + + +async def _emit_download_progress( + callback: DownloadProgressCallback | None, + loaded: int, + total: int | None, + *, + await_callback: bool, +) -> None: + if callback is None: + return + + result = callback(loaded, total) + if await_callback and inspect.isawaitable(result): + await cast(Awaitable[None], result) + + +async def _sleep_with_backoff( + sleep_fn: SleepFn, + attempt: int, +) -> None: + delay = min(2**attempt * 0.1, 2.0) + result = sleep_fn(delay) + if inspect.isawaitable(result): + await cast(Awaitable[None], result) + + +def _build_headers( + *, + token: str, + request_id: str, + attempt: int, + extra_headers: dict[str, str], + request_headers: dict[str, str], + send_body_length: bool, + total_length: int, + api_version: str, +) -> dict[str, str]: + final_headers = { + "authorization": f"Bearer {token}", + "x-api-blob-request-id": request_id, + "x-api-blob-request-attempt": str(attempt), + "x-api-version": api_version, 
+ **extra_headers, + } + if request_headers: + final_headers.update(request_headers) + if send_body_length and total_length: + final_headers["x-content-length"] = str(total_length) + return final_headers + + +def _build_request_body( + body: Any, + *, + on_upload_progress: BlobProgressCallback | None, + async_content: bool, +) -> JSONBody | RawBody | None: + if body is None: + return None + + if isinstance(body, (bytes, bytearray, memoryview, str)) or hasattr(body, "read"): + wrapped = StreamingBodyWithProgress( + cast(bytes | bytearray | memoryview | str | Any, body), + on_upload_progress, + ) + content = wrapped.__aiter__() if async_content else wrapped + return RawBody(content) + + return JSONBody(body) + + +def get_telemetry_size_bytes(body: Any) -> int | None: + if isinstance(body, (bytes, bytearray)): + return len(body) + if isinstance(body, str): + return len(body.encode()) + return None + + +def _validate_put_inputs(path: str, body: Any, access: str) -> None: + validate_path(path) + validate_access(access) + if body is None: + raise BlobError("body is required") + if isinstance(body, dict): + raise BlobError(PUT_BODY_OBJECT_ERROR) + + +def normalize_delete_urls(url_or_path: str | Iterable[str]) -> list[str]: + if isinstance(url_or_path, Iterable) and not isinstance(url_or_path, (str, bytes)): + return [str(url) for url in url_or_path] + return [str(url_or_path)] + + +def build_put_blob_result(raw: dict[str, Any]) -> PutBlobResultType: + return PutBlobResultType( + url=raw["url"], + download_url=raw["downloadUrl"], + pathname=raw["pathname"], + content_type=raw["contentType"], + content_disposition=raw["contentDisposition"], + ) + + +def build_head_blob_result(resp: dict[str, Any]) -> HeadBlobResultType: + uploaded_at = ( + parse_datetime(resp["uploadedAt"]) + if isinstance(resp.get("uploadedAt"), str) + else resp["uploadedAt"] + ) + return HeadBlobResultType( + size=resp["size"], + uploaded_at=uploaded_at, + pathname=resp["pathname"], + 
content_type=resp["contentType"], + content_disposition=resp["contentDisposition"], + url=resp["url"], + download_url=resp["downloadUrl"], + cache_control=resp["cacheControl"], + ) + + +def build_list_blob_result(resp: dict[str, Any]) -> ListBlobResultType: + blobs_list: list[ListBlobItem] = [] + for blob in resp.get("blobs", []): + uploaded_at = ( + parse_datetime(blob["uploadedAt"]) + if isinstance(blob.get("uploadedAt"), str) + else blob["uploadedAt"] + ) + blobs_list.append( + ListBlobItem( + url=blob["url"], + download_url=blob["downloadUrl"], + pathname=blob["pathname"], + size=blob["size"], + uploaded_at=uploaded_at, + ) + ) + return ListBlobResultType( + blobs=blobs_list, + cursor=resp.get("cursor"), + has_more=resp.get("hasMore", False), + folders=resp.get("folders"), + ) + + +def build_create_folder_result(raw: dict[str, Any]) -> CreateFolderResultType: + return CreateFolderResultType(pathname=raw["pathname"], url=raw["url"]) + + +def build_list_params( + *, + limit: int | None = None, + prefix: str | None = None, + cursor: str | None = None, + mode: str | None = None, +) -> dict[str, Any]: + params: dict[str, Any] = {} + if limit is not None: + params["limit"] = int(limit) + if prefix is not None: + params["prefix"] = prefix + if cursor is not None: + params["cursor"] = cursor + if mode is not None: + params["mode"] = mode + return params + + +def _resolve_page_limit( + *, + batch_size: int | None, + limit: int | None, + yielded_count: int, +) -> tuple[bool, int | None]: + page_limit = batch_size + if limit is None: + return False, page_limit + + remaining = limit - yielded_count + if remaining <= 0: + return True, None + if page_limit is None or page_limit > remaining: + page_limit = remaining + return False, page_limit + + +def _get_next_cursor(page: ListBlobResultType) -> str | None: + if not page.has_more or not page.cursor: + return None + return page.cursor + + +def _build_cache_bypass_url(blob_url: str) -> str: + parsed = urlparse(blob_url) + 
params = parse_qs(parsed.query) + params["cache"] = ["0"] + query = urlencode(params, doseq=True) + return urlunparse( + ( + parsed.scheme, + parsed.netloc, + parsed.path, + parsed.params, + query, + parsed.fragment, + ) + ) + + +def parse_last_modified(value: str | None) -> datetime: + if not value: + return datetime.now(tz=timezone.utc) + try: + return parsedate_to_datetime(value) + except (ValueError, TypeError): + pass + try: + return parse_datetime(value) + except (ValueError, TypeError): + return datetime.now(tz=timezone.utc) + + +class BlobRequestClient: + _transport: BaseTransport + _sleep_fn: SleepFn + _await_progress_callback: bool + _async_content: bool + + def __init__( + self, + *, + transport: BaseTransport, + sleep_fn: SleepFn = asyncio.sleep, + await_progress_callback: bool = True, + async_content: bool = True, + ) -> None: + self._transport = transport + self._sleep_fn = sleep_fn + self._await_progress_callback = await_progress_callback + self._async_content = async_content + + @property + def transport(self) -> BaseTransport: + return self._transport + + @property + def await_progress_callback(self) -> bool: + return self._await_progress_callback + + def close(self) -> None: + if isinstance(self._transport, SyncTransport): + self._transport.close() + + async def aclose(self) -> None: + if isinstance(self._transport, AsyncTransport): + await self._transport.aclose() + + async def request_api( + self, + pathname: str, + method: str, + *, + token: str | None = None, + headers: PutHeaders | dict[str, str] | None = None, + params: dict[str, Any] | None = None, + body: Any = None, + on_upload_progress: BlobProgressCallback | None = None, + timeout: float | None = None, + decode_mode: Literal["json", "any", "none"] = "json", + ) -> Any: + token = ensure_token(token) + store_id = extract_store_id_from_token(token) + request_id = make_request_id(store_id) + retries = get_retries() + api_version = get_api_version() + extra_headers = 
get_proxy_through_alternative_api_header_from_env() + request_headers = cast(dict[str, str], headers or {}) + + send_body_length = bool(on_upload_progress) or should_use_x_content_length() + total_length = compute_body_length(body) if send_body_length else 0 + + if on_upload_progress: + await _emit_progress( + on_upload_progress, + UploadProgressEvent(loaded=0, total=total_length, percentage=0.0), + await_callback=self._await_progress_callback, + ) + + url = get_api_url(pathname) + effective_timeout = timeout if timeout is not None else 30.0 + + for attempt in range(retries + 1): + try: + final_headers = _build_headers( + token=token, + request_id=request_id, + attempt=attempt, + extra_headers=extra_headers, + request_headers=request_headers, + send_body_length=send_body_length, + total_length=total_length, + api_version=api_version, + ) + request_body = _build_request_body( + body, + on_upload_progress=on_upload_progress, + async_content=self._async_content, + ) + resp = await self._transport.send( + method=method, + path=url, + headers=final_headers, + params=params, + body=request_body, + timeout=effective_timeout, + ) + + if 200 <= resp.status_code < 300: + if on_upload_progress: + await _emit_progress( + on_upload_progress, + UploadProgressEvent( + loaded=total_length or 0, + total=total_length or 0, + percentage=100.0, + ), + await_callback=self._await_progress_callback, + ) + if decode_mode == "none": + return None + if decode_mode == "json": + return decode_blob_response_json(resp) + return decode_blob_response(resp) + + code, mapped = map_blob_error(resp) + if should_retry(code) and attempt < retries: + debug(f"retrying API request to {pathname}", code) + await _sleep_with_backoff(self._sleep_fn, attempt) + continue + raise mapped + except Exception as exc: + if is_network_error(exc) and attempt < retries: + debug(f"retrying API request to {pathname}", str(exc)) + await _sleep_with_backoff(self._sleep_fn, attempt) + continue + if isinstance(exc, 
httpx.HTTPError): + raise BlobUnknownError() from exc + raise + + raise BlobUnknownError() + + +def create_sync_request_client(timeout: float = 30.0) -> BlobRequestClient: + transport = SyncTransport(create_base_client(timeout=timeout)) + return BlobRequestClient( + transport=transport, + sleep_fn=_sync_sleep, + await_progress_callback=False, + async_content=False, + ) + + +def create_async_request_client(timeout: float = 30.0) -> BlobRequestClient: + transport = AsyncTransport(create_base_async_client(timeout=timeout)) + return BlobRequestClient( + transport=transport, + ) + + +class BaseBlobOpsClient: + def __init__( + self, + *, + request_client: BlobRequestClient, + multipart_client: MultipartClient, + multipart_runtime: Any, + ) -> None: + self._request_client = request_client + self._multipart_client = multipart_client + self._multipart_runtime = multipart_runtime + + def _stream_download_chunks(self, response: httpx.Response) -> AsyncIterator[bytes]: + raise NotImplementedError + + async def _close_response(self, response: httpx.Response) -> None: + raise NotImplementedError + + async def _close_download_response(self, response: httpx.Response) -> None: + await self._close_response(response) + + def _make_upload_part_fn(self) -> Any: + raise NotImplementedError + + async def _multipart_upload( + self, + path: str, + body: Any, + *, + access: Access, + content_type: str | None = None, + add_random_suffix: bool = False, + overwrite: bool = False, + cache_control_max_age: int | None = None, + token: str | None = None, + on_upload_progress: BlobProgressCallback | None = None, + ) -> dict[str, Any]: + headers = prepare_upload_headers( + access=access, + content_type=content_type, + add_random_suffix=add_random_suffix, + overwrite=overwrite, + cache_control_max_age=cache_control_max_age, + ) + part_size = validate_part_size(DEFAULT_PART_SIZE) + + create_response = await self._multipart_client.create_multipart_upload( + path, + headers, + token=token, + ) + session 
= MultipartUploadSession( + upload_id=create_response["uploadId"], + key=create_response["key"], + path=path, + headers=headers, + token=token, + ) + + total = compute_body_length(body) + parts = cast( + list[dict[str, Any]], + await _await_if_necessary( + self._multipart_runtime.upload( + session=session, + body=body, + part_size=part_size, + total=total, + on_upload_progress=on_upload_progress, + upload_part_fn=self._make_upload_part_fn(), + ) + ), + ) + ordered_parts = order_uploaded_parts(parts) + + complete_response = await self._multipart_client.complete_multipart_upload( + upload_id=session.upload_id, + key=session.key, + path=session.path, + headers=session.headers, + token=session.token, + parts=ordered_parts, + ) + return shape_complete_upload_result(complete_response) + + async def put_blob( + self, + path: str, + body: Any, + *, + access: Access, + content_type: str | None, + add_random_suffix: bool, + overwrite: bool, + cache_control_max_age: int | None, + token: str | None, + multipart: bool, + on_upload_progress: BlobProgressCallback | None, + ) -> tuple[PutBlobResultType, bool]: + token = ensure_token(token) + _validate_put_inputs(path, body, access) + + headers = create_put_headers( + content_type=content_type, + add_random_suffix=add_random_suffix, + allow_overwrite=overwrite, + cache_control_max_age=cache_control_max_age, + access=access, + ) + + if multipart: + raw = await self._multipart_upload( + path, + body, + access=access, + content_type=content_type, + add_random_suffix=add_random_suffix, + overwrite=overwrite, + cache_control_max_age=cache_control_max_age, + token=token, + on_upload_progress=on_upload_progress, + ) + result = build_put_blob_result(raw) + track( + "blob_put", + token=token, + access=access, + content_type=content_type, + multipart=True, + size_bytes=get_telemetry_size_bytes(body), + ) + return result, True + + raw = cast( + dict[str, Any], + await self._request_client.request_api( + "", + "PUT", + token=token, + 
headers=headers, + params={"pathname": path}, + body=body, + on_upload_progress=on_upload_progress, + ), + ) + result = build_put_blob_result(raw) + track( + "blob_put", + token=token, + access=access, + content_type=content_type, + multipart=False, + size_bytes=get_telemetry_size_bytes(body), + ) + return result, False + + async def delete_blob( + self, + urls: list[str], + *, + token: str, + ) -> int: + await self._request_client.request_api( + "/delete", + "POST", + token=token, + headers={"content-type": "application/json"}, + body={"urls": urls}, + decode_mode="none", + ) + track("blob_delete", token=token, count=len(urls)) + return len(urls) + + async def head_blob( + self, + url_or_path: str, + *, + token: str | None, + ) -> HeadBlobResultType: + token = ensure_token(token) + resp = cast( + dict[str, Any], + await self._request_client.request_api( + "", + "GET", + token=token, + params={"url": url_or_path}, + ), + ) + return build_head_blob_result(resp) + + async def get_blob( + self, + url_or_path: str, + *, + access: Access, + token: str | None, + timeout: float | None, + use_cache: bool, + if_none_match: str | None, + default_timeout: float, + ) -> GetBlobResultType: + token = ensure_token(token) + validate_access(access) + target_url = url_or_path + pathname: str + download_url: str | None = None + if not is_url(target_url): + pathname = target_url.lstrip("/") + store_id = extract_store_id_from_token(token) + if store_id: + target_url = construct_blob_url(store_id, pathname, access) + else: + head_result = await self.head_blob(target_url, token=token) + target_url = head_result.url + pathname = head_result.pathname + download_url = head_result.download_url + else: + pathname = urlparse(target_url).path.lstrip("/") + if download_url is None: + download_url = get_download_url(target_url) + if not use_cache: + target_url = _build_cache_bypass_url(target_url) + + effective_timeout = timeout or default_timeout + headers: dict[str, str] = {} + if access == 
"private": + headers["authorization"] = f"Bearer {token}" + if if_none_match: + headers["if-none-match"] = if_none_match + response: httpx.Response | None = None + + try: + response = await self._request_client.transport.send( + "GET", + target_url, + headers=headers, + timeout=effective_timeout, + follow_redirects=True, + ) + if response.status_code == 404: + raise BlobNotFoundError() + if response.status_code == 304: + return GetBlobResultType( + url=target_url, + download_url=download_url, + pathname=pathname, + content_type=None, + size=None, + content_disposition=response.headers.get("content-disposition", ""), + cache_control=response.headers.get("cache-control", ""), + uploaded_at=parse_last_modified(response.headers.get("last-modified")), + etag=response.headers.get("etag", ""), + content=b"", + status_code=304, + ) + response.raise_for_status() + content_length = response.headers.get("content-length") + return GetBlobResultType( + url=target_url, + download_url=download_url, + pathname=pathname, + content_type=response.headers.get("content-type", "application/octet-stream"), + size=int(content_length) if content_length else len(response.content), + content_disposition=response.headers.get("content-disposition", ""), + cache_control=response.headers.get("cache-control", ""), + uploaded_at=parse_last_modified(response.headers.get("last-modified")), + etag=response.headers.get("etag", ""), + content=response.content, + status_code=response.status_code, + ) + except httpx.HTTPStatusError as exc: + if exc.response is not None and exc.response.status_code == 404: + raise BlobNotFoundError() from exc + raise + finally: + if response is not None: + await self._close_response(response) + + async def copy_blob( + self, + src_path: str, + dst_path: str, + *, + access: Access, + content_type: str | None, + add_random_suffix: bool, + overwrite: bool, + cache_control_max_age: int | None, + token: str | None, + ) -> PutBlobResultType: + token = ensure_token(token) + 
validate_path(dst_path) + validate_access(access) + + src_url = src_path + if not is_url(src_url): + src_url = (await self.head_blob(src_url, token=token)).url + + headers = create_put_headers( + content_type=content_type, + add_random_suffix=add_random_suffix, + allow_overwrite=overwrite, + cache_control_max_age=cache_control_max_age, + access=access, + ) + raw = cast( + dict[str, Any], + await self._request_client.request_api( + "", + "PUT", + token=token, + headers=headers, + params={"pathname": str(dst_path), "fromUrl": src_url}, + ), + ) + return build_put_blob_result(raw) + + async def create_folder( + self, + path: str, + *, + token: str | None, + overwrite: bool, + ) -> CreateFolderResultType: + token = ensure_token(token) + folder_path = path if path.endswith("/") else f"{path}/" + headers = create_put_headers( + add_random_suffix=False, + allow_overwrite=overwrite, + ) + raw = cast( + dict[str, Any], + await self._request_client.request_api( + "", + "PUT", + token=token, + headers=headers, + params={"pathname": folder_path}, + ), + ) + return build_create_folder_result(raw) + + async def upload_file( + self, + local_path: str | os.PathLike, + path: str, + *, + access: Access, + content_type: str | None, + add_random_suffix: bool, + overwrite: bool, + cache_control_max_age: int | None, + token: str | None, + multipart: bool, + on_upload_progress: BlobProgressCallback | None, + missing_local_path_error: str, + ) -> PutBlobResultType: + token = ensure_token(token) + if not local_path: + raise BlobError(missing_local_path_error) + if not path: + raise BlobError("path is required") + + source_path = os.fspath(local_path) + if not os.path.exists(source_path): + raise BlobError("local_path does not exist") + if not os.path.isfile(source_path): + raise BlobError("local_path is not a file") + + size_bytes = os.path.getsize(source_path) + use_multipart = multipart or (size_bytes > 5 * 1024 * 1024) + + with open(source_path, "rb") as f: + result, _ = await 
self.put_blob( + path, + f, + access=access, + content_type=content_type, + add_random_suffix=add_random_suffix, + overwrite=overwrite, + cache_control_max_age=cache_control_max_age, + token=token, + multipart=use_multipart, + on_upload_progress=on_upload_progress, + ) + return result + + async def download_file( + self, + url_or_path: str, + local_path: str | os.PathLike, + *, + access: Access, + token: str | None, + timeout: float | None, + overwrite: bool, + create_parents: bool, + progress: DownloadProgressCallback | None, + ) -> str: + token = ensure_token(token) + validate_access(access) + if is_url(url_or_path): + target_url = get_download_url(url_or_path) + elif store_id := extract_store_id_from_token(token): + blob_url = construct_blob_url(store_id, url_or_path.lstrip("/"), access) + target_url = get_download_url(blob_url) + else: + meta = await self.head_blob(url_or_path, token=token) + target_url = meta.download_url or meta.url + + dst = os.fspath(local_path) + if not overwrite and os.path.exists(dst): + raise BlobError("destination exists; pass overwrite=True to replace it") + if create_parents: + os.makedirs(os.path.dirname(dst) or ".", exist_ok=True) + + tmp = dst + ".part" + bytes_read = 0 + effective_timeout = timeout or 120.0 + headers: dict[str, str] = {} + if access == "private": + headers["authorization"] = f"Bearer {token}" + response: httpx.Response | None = None + + try: + response = await self._request_client.transport.send( + "GET", + target_url, + headers=headers, + timeout=effective_timeout, + follow_redirects=True, + stream=True, + ) + if response.status_code == 404: + raise BlobNotFoundError() + response.raise_for_status() + total = int(response.headers.get("Content-Length", "0")) or None + + with open(tmp, "wb") as f: + async for chunk in self._stream_download_chunks(response): + if not chunk: + continue + f.write(chunk) + bytes_read += len(chunk) + await _emit_download_progress( + progress, + bytes_read, + total, + 
await_callback=self._request_client.await_progress_callback, + ) + + os.replace(tmp, dst) + except Exception: + if os.path.exists(tmp): + os.remove(tmp) + raise + finally: + if response is not None: + await self._close_download_response(response) + + return dst + + +class SyncBlobOpsClient(BaseBlobOpsClient): + def __init__(self, *, timeout: float = 30.0) -> None: + request_client = create_sync_request_client(timeout) + multipart_client = MultipartClient(request_client) + super().__init__( + request_client=request_client, + multipart_client=multipart_client, + multipart_runtime=create_sync_multipart_upload_runtime(), + ) + + def close(self) -> None: + self._request_client.close() + + def _make_upload_part_fn(self) -> Any: + return lambda **kw: iter_coroutine(self._multipart_client.upload_part(**kw)) + + def list_objects( + self, + *, + limit: int | None, + prefix: str | None, + cursor: str | None, + mode: str | None, + token: str | None, + ) -> ListBlobResultType: + token = ensure_token(token) + resp = cast( + dict[str, Any], + iter_coroutine( + self._request_client.request_api( + "", + "GET", + token=token, + params=build_list_params(limit=limit, prefix=prefix, cursor=cursor, mode=mode), + ) + ), + ) + return build_list_blob_result(resp) + + def iter_objects( + self, + *, + prefix: str | None, + mode: str | None, + token: str | None, + batch_size: int | None, + limit: int | None, + cursor: str | None, + ) -> Iterator[ListBlobItem]: + next_cursor = cursor + yielded_count = 0 + + while True: + done, effective_limit = _resolve_page_limit( + batch_size=batch_size, + limit=limit, + yielded_count=yielded_count, + ) + if done: + break + + page = self.list_objects( + limit=effective_limit, + prefix=prefix, + cursor=next_cursor, + mode=mode, + token=token, + ) + + for item in page.blobs: + yield item + if limit is not None: + yielded_count += 1 + if yielded_count >= limit: + return + + next_cursor = _get_next_cursor(page) + if next_cursor is None: + break + + def 
_stream_download_chunks(self, response: httpx.Response) -> AsyncIterator[bytes]: + async def _iterate() -> AsyncIterator[bytes]: + for chunk in response.iter_bytes(): + yield chunk + + return _iterate() + + async def _close_download_response(self, response: httpx.Response) -> None: + await self._close_response(response) + + async def _close_response(self, response: httpx.Response) -> None: + response.close() + + def __enter__(self) -> SyncBlobOpsClient: + return self + + def __exit__(self, *args: object) -> None: + self.close() + + +class AsyncBlobOpsClient(BaseBlobOpsClient): + def __init__(self, *, timeout: float = 30.0) -> None: + request_client = create_async_request_client(timeout) + multipart_client = MultipartClient(request_client) + super().__init__( + request_client=request_client, + multipart_client=multipart_client, + multipart_runtime=create_async_multipart_upload_runtime(), + ) + + async def aclose(self) -> None: + await self._request_client.aclose() + + async def __aenter__(self) -> AsyncBlobOpsClient: + return self + + async def __aexit__(self, *args: object) -> None: + await self.aclose() + + def _make_upload_part_fn(self) -> Any: + return self._multipart_client.upload_part + + def _stream_download_chunks(self, response: httpx.Response) -> AsyncIterator[bytes]: + async def _iterate() -> AsyncIterator[bytes]: + async for chunk in response.aiter_bytes(): + yield chunk + + return _iterate() + + async def _close_download_response(self, response: httpx.Response) -> None: + await self._close_response(response) + + async def _close_response(self, response: httpx.Response) -> None: + await response.aclose() + + async def list_objects( + self, + *, + limit: int | None, + prefix: str | None, + cursor: str | None, + mode: str | None, + token: str | None, + ) -> ListBlobResultType: + token = ensure_token(token) + resp = cast( + dict[str, Any], + await self._request_client.request_api( + "", + "GET", + token=token, + params=build_list_params(limit=limit, 
prefix=prefix, cursor=cursor, mode=mode), + ), + ) + return build_list_blob_result(resp) + + async def iter_objects( + self, + *, + prefix: str | None, + mode: str | None, + token: str | None, + batch_size: int | None, + limit: int | None, + cursor: str | None, + ) -> AsyncIterator[ListBlobItem]: + next_cursor = cursor + yielded_count = 0 + + while True: + done, effective_limit = _resolve_page_limit( + batch_size=batch_size, + limit=limit, + yielded_count=yielded_count, + ) + if done: + break + + page = await self.list_objects( + limit=effective_limit, + prefix=prefix, + cursor=next_cursor, + mode=mode, + token=token, + ) + + for item in page.blobs: + yield item + if limit is not None: + yielded_count += 1 + if yielded_count >= limit: + return + + next_cursor = _get_next_cursor(page) + if next_cursor is None: + break + + +__all__ = [ + "AsyncBlobOpsClient", + "SyncBlobOpsClient", + "BlobRequestClient", + "create_sync_request_client", + "create_async_request_client", + "build_create_folder_result", + "build_head_blob_result", + "build_list_blob_result", + "build_list_params", + "build_put_blob_result", + "decode_blob_response", + "get_telemetry_size_bytes", + "is_network_error", + "map_blob_error", + "normalize_delete_urls", + "should_retry", +] diff --git a/src/vercel/_internal/blob/errors.py b/src/vercel/_internal/blob/errors.py new file mode 100644 index 0000000..1bdc68f --- /dev/null +++ b/src/vercel/_internal/blob/errors.py @@ -0,0 +1,93 @@ +class BlobError(Exception): + """Base class for Vercel Blob SDK errors.""" + + def __init__(self, message: str) -> None: + super().__init__(f"Vercel Blob: {message}") + + +class BlobAccessError(BlobError): + def __init__(self) -> None: + super().__init__("Access denied, please provide a valid token for this resource.") + + +class BlobNoTokenProvidedError(BlobError): + def __init__(self) -> None: + super().__init__( + "No token found. 
Either configure the `BLOB_READ_WRITE_TOKEN` " + "or `VERCEL_BLOB_READ_WRITE_TOKEN` environment variable, " + "or pass a `token` option to your calls." + ) + + +class BlobContentTypeNotAllowedError(BlobError): + def __init__(self, message: str) -> None: + super().__init__(f"Content type mismatch, {message}.") + + +class BlobPathnameMismatchError(BlobError): + def __init__(self, message: str) -> None: + super().__init__( + f"Pathname mismatch, {message}. " + "Check the pathname used in upload() or put() " + "matches the one from the client token." + ) + + +class BlobClientTokenExpiredError(BlobError): + def __init__(self) -> None: + super().__init__("Client token has expired.") + + +class BlobFileTooLargeError(BlobError): + def __init__(self, message: str) -> None: + super().__init__(f"File is too large, {message}.") + + +class BlobStoreNotFoundError(BlobError): + def __init__(self) -> None: + super().__init__("This store does not exist.") + + +class BlobStoreSuspendedError(BlobError): + def __init__(self) -> None: + super().__init__("This store has been suspended.") + + +class BlobUnknownError(BlobError): + def __init__(self) -> None: + super().__init__("Unknown error, please visit https://vercel.com/help.") + + +class BlobNotFoundError(BlobError): + def __init__(self) -> None: + super().__init__("The requested blob does not exist") + + +class BlobServiceNotAvailable(BlobError): + def __init__(self) -> None: + super().__init__("The blob service is currently not available. Please try again.") + + +class BlobServiceRateLimited(BlobError): + def __init__(self, seconds: int | None = None) -> None: + retry = f" - try again in {seconds} seconds" if seconds else "" + super().__init__( + f"Too many requests please lower the number of concurrent requests{retry}." 
+ ) + self.retry_after: int = seconds or 0 + + +class BlobRequestAbortedError(BlobError): + def __init__(self) -> None: + super().__init__("The request was aborted.") + + +class BlobUnexpectedResponseContentTypeError(BlobError): + def __init__(self, content_type: str | None) -> None: + value = content_type or "" + super().__init__(f"Unexpected response content type: {value}.") + + +class BlobInvalidResponseJSONError(BlobError): + def __init__(self) -> None: + super().__init__("Failed to parse JSON response body.") diff --git a/src/vercel/_internal/blob/multipart.py b/src/vercel/_internal/blob/multipart.py new file mode 100644 index 0000000..4555146 --- /dev/null +++ b/src/vercel/_internal/blob/multipart.py @@ -0,0 +1,461 @@ +from __future__ import annotations + +import inspect +import threading +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator +from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast +from urllib.parse import quote + +import anyio + +from vercel._internal.blob import ( + PutHeaders, + create_put_headers, + validate_access, +) +from vercel._internal.blob.errors import BlobError +from vercel._internal.blob.types import Access, UploadProgressEvent + +if TYPE_CHECKING: + from vercel._internal.blob.core import BlobRequestClient + +AsyncProgressCallback = ( + Callable[[UploadProgressEvent], None] | Callable[[UploadProgressEvent], Awaitable[None]] +) +SyncProgressCallback = Callable[[UploadProgressEvent], None] +SyncPartUploadFn = Callable[..., dict[str, Any]] +AsyncPartUploadFn = Callable[..., Awaitable[dict[str, Any]]] + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +DEFAULT_PART_SIZE = 8 * 1024 * 1024 # 8MB +MIN_PART_SIZE = 5 * 1024 * 1024 # 5 MiB minimum for most backends; last part may be smaller 
+MAX_CONCURRENCY = 6 + +# --------------------------------------------------------------------------- +# Multipart upload session dataclass +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class MultipartUploadSession: + upload_id: str + key: str + path: str + headers: dict[str, str] + token: str | None + + +# --------------------------------------------------------------------------- +# Helper functions used by both uploader.py and _internal/blob/core.py +# --------------------------------------------------------------------------- + + +def validate_part_size(part_size: int) -> int: + ps = int(part_size) + if ps < MIN_PART_SIZE: + raise BlobError(f"part_size must be at least {MIN_PART_SIZE} bytes (5 MiB)") + return ps + + +def prepare_upload_headers( + *, + access: Access, + content_type: str | None, + add_random_suffix: bool, + overwrite: bool, + cache_control_max_age: int | None, +) -> dict[str, str]: + validate_access(access) + return cast( + dict[str, str], + create_put_headers( + content_type=content_type, + add_random_suffix=add_random_suffix, + allow_overwrite=overwrite, + cache_control_max_age=cache_control_max_age, + access=access, + ), + ) + + +def _normalize_part_upload_result(part_number: int, response: dict[str, Any]) -> dict[str, Any]: + return {"partNumber": part_number, "etag": response["etag"]} + + +def order_uploaded_parts(parts: list[dict[str, Any]]) -> list[dict[str, Any]]: + ordered_parts = list(parts) + ordered_parts.sort(key=lambda part: int(part["partNumber"])) + return ordered_parts + + +def shape_complete_upload_result(response: dict[str, Any]) -> dict[str, Any]: + shaped = { + "url": response["url"], + "downloadUrl": response["downloadUrl"], + "pathname": response["pathname"], + "contentType": response["contentType"], + "contentDisposition": response["contentDisposition"], + } + for key, value in response.items(): + if key not in shaped: + shaped[key] = value + return shaped + + 
+def _aggregate_progress_event(loaded: int, total: int) -> UploadProgressEvent: + percentage = round((loaded / total) * 100, 2) if total else 0.0 + return UploadProgressEvent(loaded=loaded, total=total, percentage=percentage) + + +# --------------------------------------------------------------------------- +# Part-byte iterators +# --------------------------------------------------------------------------- + + +def _iter_part_bytes(body: Any, part_size: int) -> Iterator[bytes]: + # bytes-like + if isinstance(body, (bytes, bytearray, memoryview)): + view = memoryview(body) + offset = 0 + while offset < len(view): + end = min(offset + part_size, len(view)) + yield bytes(view[offset:end]) + offset = end + return + # str + if isinstance(body, str): + data = body.encode("utf-8") + view = memoryview(data) + offset = 0 + while offset < len(view): + end = min(offset + part_size, len(view)) + yield bytes(view[offset:end]) + offset = end + return + # file-like object + if hasattr(body, "read"): + while True: + chunk = body.read(part_size) # type: ignore[attr-defined] + if not chunk: + break + if not isinstance(chunk, (bytes, bytearray, memoryview)): + chunk = bytes(chunk) + yield bytes(chunk) + return + # Iterable[bytes] + if isinstance(body, Iterable): # type: ignore[arg-type] + buffer = bytearray() + for ch in body: # type: ignore[assignment] + if not isinstance(ch, (bytes, bytearray, memoryview)): + ch = bytes(ch) + buffer.extend(ch) + while len(buffer) >= part_size: + yield bytes(buffer[:part_size]) + del buffer[:part_size] + if buffer: + yield bytes(buffer) + return + # Fallback: coerce to bytes and slice + data = bytes(body) + view = memoryview(data) + offset = 0 + while offset < len(view): + end = min(offset + part_size, len(view)) + yield bytes(view[offset:end]) + offset = end + + +async def _aiter_part_bytes(body: Any, part_size: int) -> AsyncIterator[bytes]: + # AsyncIterable[bytes] + if hasattr(body, "__aiter__"): + buffer = bytearray() + async for ch in body: # 
type: ignore[misc] + if not isinstance(ch, (bytes, bytearray, memoryview)): + ch = bytes(ch) + buffer.extend(ch) + while len(buffer) >= part_size: + yield bytes(buffer[:part_size]) + del buffer[:part_size] + if buffer: + yield bytes(buffer) + return + # Delegate to sync iterator for other cases + for chunk in _iter_part_bytes(body, part_size): + yield chunk + + +# --------------------------------------------------------------------------- +# Upload runtime classes +# --------------------------------------------------------------------------- + + +class _SyncMultipartUploadRuntime: + def upload( + self, + *, + session: MultipartUploadSession, + body: Any, + part_size: int, + total: int, + on_upload_progress: SyncProgressCallback | None, + upload_part_fn: SyncPartUploadFn, + ) -> list[dict[str, Any]]: + loaded_per_part: dict[int, int] = {} + loaded_lock = threading.Lock() + results: list[dict[str, Any]] = [] + + def upload_one(part_number: int, content: bytes) -> dict[str, Any]: + def progress(evt: UploadProgressEvent) -> None: + with loaded_lock: + loaded_per_part[part_number] = int(evt.loaded) + if on_upload_progress: + loaded = sum(loaded_per_part.values()) + on_upload_progress(_aggregate_progress_event(loaded=loaded, total=total)) + + response = upload_part_fn( + upload_id=session.upload_id, + key=session.key, + path=session.path, + headers=session.headers, + token=session.token, + part_number=part_number, + body=content, + on_upload_progress=progress, + ) + return _normalize_part_upload_result(part_number, response) + + with ThreadPoolExecutor(max_workers=MAX_CONCURRENCY) as executor: + inflight = set() + part_number = 1 + for chunk in _iter_part_bytes(body, part_size): + task = executor.submit(upload_one, part_number, chunk) + inflight.add(task) + part_number += 1 + if len(inflight) >= MAX_CONCURRENCY: + done, inflight = wait(inflight, return_when=FIRST_COMPLETED) + for completed in done: + results.append(completed.result()) + + if inflight: + done, _ = 
wait(inflight) + for completed in done: + results.append(completed.result()) + + if on_upload_progress: + on_upload_progress(UploadProgressEvent(loaded=total, total=total, percentage=100.0)) + + return results + + +class _AsyncMultipartUploadRuntime: + async def upload( + self, + *, + session: MultipartUploadSession, + body: Any, + part_size: int, + total: int, + on_upload_progress: AsyncProgressCallback | None, + upload_part_fn: AsyncPartUploadFn, + ) -> list[dict[str, Any]]: + loaded_per_part: dict[int, int] = {} + results: list[dict[str, Any]] = [] + + async def emit_progress(part_number: int, event: UploadProgressEvent) -> None: + loaded_per_part[part_number] = int(event.loaded) + if on_upload_progress: + loaded = sum(loaded_per_part.values()) + callback_result = on_upload_progress( + _aggregate_progress_event(loaded=loaded, total=total) + ) + if inspect.isawaitable(callback_result): + await cast(Awaitable[None], callback_result) + + def part_progress_callback( + part_number: int, + ) -> Callable[[UploadProgressEvent], Awaitable[None]]: + async def callback(event: UploadProgressEvent) -> None: + await emit_progress(part_number, event) + + return callback + + async def upload_one(part_number: int, content: bytes) -> dict[str, Any]: + response = await upload_part_fn( + upload_id=session.upload_id, + key=session.key, + path=session.path, + headers=session.headers, + part_number=part_number, + body=content, + on_upload_progress=part_progress_callback(part_number), + token=session.token, + ) + return _normalize_part_upload_result(part_number, response) + + semaphore = anyio.Semaphore(MAX_CONCURRENCY) + results_by_part: dict[int, dict[str, Any]] = {} + + async def run_limited_upload(part_number: int, content: bytes) -> None: + await semaphore.acquire() + try: + results_by_part[part_number] = await upload_one(part_number, content) + finally: + semaphore.release() + + part_number = 1 + async with anyio.create_task_group() as task_group: + async for chunk in 
_aiter_part_bytes(body, part_size): + task_group.start_soon(run_limited_upload, part_number, chunk) + part_number += 1 + + for ordered_part_number in sorted(results_by_part): + results.append(results_by_part[ordered_part_number]) + + if on_upload_progress: + loaded = sum(loaded_per_part.values()) + percentage = round((loaded / total) * 100, 2) if total else 100.0 + callback_result = on_upload_progress( + UploadProgressEvent(loaded=loaded, total=total, percentage=percentage) + ) + if inspect.isawaitable(callback_result): + await cast(Awaitable[None], callback_result) + + return results + + +def create_sync_multipart_upload_runtime() -> _SyncMultipartUploadRuntime: + return _SyncMultipartUploadRuntime() + + +def create_async_multipart_upload_runtime() -> _AsyncMultipartUploadRuntime: + return _AsyncMultipartUploadRuntime() + + +# --------------------------------------------------------------------------- +# Low-level multipart HTTP helpers (MPU header building & client classes) +# --------------------------------------------------------------------------- + + +def _build_headers( + headers: PutHeaders | dict[str, str], + *, + action: str, + key: str | None = None, + upload_id: str | None = None, + part_number: int | None = None, + set_json_content_type: bool = False, +) -> dict[str, str]: + request_headers = cast(dict[str, str], headers).copy() + if set_json_content_type: + request_headers["content-type"] = "application/json" + + request_headers["x-mpu-action"] = action + if key is not None: + request_headers["x-mpu-key"] = quote(key, safe="") + if upload_id is not None: + request_headers["x-mpu-upload-id"] = upload_id + if part_number is not None: + request_headers["x-mpu-part-number"] = str(part_number) + + return request_headers + + +class MultipartClient: + def __init__( + self, + request_client: BlobRequestClient, + ) -> None: + self._request_client = request_client + + async def _request_api(self, **kwargs: Any) -> Any: + return await 
self._request_client.request_api(**kwargs) + + async def create_multipart_upload( + self, + path: str, + headers: PutHeaders | dict[str, str], + *, + token: str | None = None, + ) -> dict[str, str]: + response = await self._request_api( + pathname="/mpu", + method="POST", + token=token, + headers=_build_headers(headers, action="create"), + params={"pathname": path}, + ) + return cast(dict[str, str], response) + + async def upload_part( + self, + *, + upload_id: str, + key: str, + path: str, + headers: PutHeaders | dict[str, str], + part_number: int, + body: Any, + on_upload_progress: AsyncProgressCallback | None = None, + token: str | None = None, + ) -> dict[str, Any]: + response = await self._request_api( + pathname="/mpu", + method="POST", + token=token, + headers=_build_headers( + headers, + action="upload", + key=key, + upload_id=upload_id, + part_number=part_number, + ), + params={"pathname": path}, + body=body, + on_upload_progress=on_upload_progress, + ) + return cast(dict[str, Any], response) + + async def complete_multipart_upload( + self, + *, + upload_id: str, + key: str, + path: str, + headers: PutHeaders | dict[str, str], + parts: list[dict[str, Any]], + token: str | None = None, + ) -> dict[str, Any]: + response = await self._request_api( + pathname="/mpu", + method="POST", + token=token, + headers=_build_headers( + headers, + action="complete", + key=key, + upload_id=upload_id, + set_json_content_type=True, + ), + params={"pathname": path}, + body=parts, + ) + return cast(dict[str, Any], response) + + +class SyncMultipartClient(MultipartClient): + def __init__(self) -> None: + from vercel._internal.blob.core import create_sync_request_client + + super().__init__(create_sync_request_client()) + + +class AsyncMultipartClient(MultipartClient): + def __init__(self) -> None: + from vercel._internal.blob.core import create_async_request_client + + super().__init__(create_async_request_client()) diff --git a/src/vercel/_internal/blob/types.py 
b/src/vercel/_internal/blob/types.py new file mode 100644 index 0000000..04feda6 --- /dev/null +++ b/src/vercel/_internal/blob/types.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from datetime import datetime +from typing import Literal + + +@dataclass(slots=True) +class PutBlobResult: + url: str + download_url: str + pathname: str + content_type: str + content_disposition: str + + +@dataclass(slots=True) +class HeadBlobResult: + size: int + uploaded_at: datetime + pathname: str + content_type: str + content_disposition: str + url: str + download_url: str + cache_control: str + + +@dataclass(slots=True) +class ListBlobItem: + url: str + download_url: str + pathname: str + size: int + uploaded_at: datetime + + +@dataclass(slots=True) +class ListBlobResult: + blobs: list[ListBlobItem] + cursor: str | None + has_more: bool + folders: list[str] | None = None + + +@dataclass(slots=True) +class CreateFolderResult: + pathname: str + url: str + + +@dataclass(slots=True) +class MultipartCreateResult: + upload_id: str + key: str + + +@dataclass(slots=True) +class GetBlobResult: + url: str + download_url: str + pathname: str + content_type: str | None + size: int | None + content_disposition: str + cache_control: str + uploaded_at: datetime + etag: str + content: bytes + status_code: int + + +@dataclass(slots=True) +class MultipartPart: + part_number: int + etag: str + + +Access = Literal["public", "private"] + + +@dataclass +class UploadProgressEvent: + loaded: int + total: int + percentage: float + + +OnUploadProgressCallback = ( + Callable[[UploadProgressEvent], None] | Callable[[UploadProgressEvent], Awaitable[None]] +) diff --git a/src/vercel/_internal/http/__init__.py b/src/vercel/_internal/http/__init__.py new file mode 100644 index 0000000..ebdb2d1 --- /dev/null +++ b/src/vercel/_internal/http/__init__.py @@ -0,0 +1,38 @@ +"""Shared HTTP infrastructure for Vercel API 
clients.""" + +from vercel._internal.http.clients import ( + create_base_async_client, + create_base_client, + create_headers_async_client, + create_headers_client, + create_vercel_async_client, + create_vercel_client, +) +from vercel._internal.http.config import DEFAULT_API_BASE_URL, DEFAULT_TIMEOUT +from vercel._internal.http.transport import ( + AsyncTransport, + BaseTransport, + BytesBody, + JSONBody, + RawBody, + RequestBody, + SyncTransport, +) + +__all__ = [ + "DEFAULT_API_BASE_URL", + "DEFAULT_TIMEOUT", + "BaseTransport", + "SyncTransport", + "AsyncTransport", + "JSONBody", + "BytesBody", + "RawBody", + "RequestBody", + "create_vercel_client", + "create_vercel_async_client", + "create_headers_client", + "create_headers_async_client", + "create_base_client", + "create_base_async_client", +] diff --git a/src/vercel/_internal/http/clients.py b/src/vercel/_internal/http/clients.py new file mode 100644 index 0000000..b654274 --- /dev/null +++ b/src/vercel/_internal/http/clients.py @@ -0,0 +1,209 @@ +"""Client factory functions for creating pre-configured httpx clients.""" + +import os +from collections.abc import Callable, Coroutine, Mapping, Sequence +from typing import Any + +import httpx + +from vercel._internal.http.config import DEFAULT_TIMEOUT + + +def _normalize_base_url(base_url: str) -> str: + return base_url.rstrip("/") + "/" + + +def _require_token(token: str | None) -> str: + env_token = os.getenv("VERCEL_TOKEN") + resolved = token or env_token + if not resolved: + raise RuntimeError("Missing Vercel API token. Pass token=... 
or set VERCEL_TOKEN.") + return resolved + + +def _create_vercel_auth_hook( + token: str, +) -> Callable[[httpx.Request], httpx.Request]: + def hook(request: httpx.Request) -> httpx.Request: + request.headers.setdefault("authorization", f"Bearer {token}") + request.headers.setdefault("accept", "application/json") + request.headers.setdefault("content-type", "application/json") + return request + + return hook + + +def _create_vercel_auth_hook_async( + token: str, +) -> Callable[[httpx.Request], Coroutine[Any, Any, None]]: + async def hook(request: httpx.Request) -> None: + request.headers.setdefault("authorization", f"Bearer {token}") + request.headers.setdefault("accept", "application/json") + request.headers.setdefault("content-type", "application/json") + + return hook + + +def _create_static_headers_hook( + headers: Mapping[str, str], +) -> Callable[[httpx.Request], httpx.Request]: + def hook(request: httpx.Request) -> httpx.Request: + for key, value in headers.items(): + request.headers.setdefault(key, value) + return request + + return hook + + +def _create_static_headers_hook_async( + headers: Mapping[str, str], +) -> Callable[[httpx.Request], Coroutine[Any, Any, None]]: + async def hook(request: httpx.Request) -> None: + for key, value in headers.items(): + request.headers.setdefault(key, value) + + return hook + + +SyncRequestHook = Callable[[httpx.Request], httpx.Request] +AsyncRequestHook = Callable[[httpx.Request], Coroutine[Any, Any, None]] + + +def _prepend_request_hooks( + client: httpx.Client | httpx.AsyncClient, + hooks: Sequence[SyncRequestHook | AsyncRequestHook], +) -> None: + existing_hooks = list(client.event_hooks.get("request", [])) + client.event_hooks["request"] = list(hooks) + existing_hooks + + +def create_vercel_client( + token: str | None = None, + timeout: float | None = None, + base_url: str | None = None, + *, + client: httpx.Client | None = None, +) -> httpx.Client: + """Create or configure a sync httpx client for Vercel API.""" + 
resolved_token = _require_token(token) + auth_hook = _create_vercel_auth_hook(resolved_token) + + if client is not None: + _prepend_request_hooks(client, [auth_hook]) + return client + + effective_timeout = timeout if timeout is not None else DEFAULT_TIMEOUT + kwargs: dict = { + "timeout": httpx.Timeout(effective_timeout), + "event_hooks": {"request": [auth_hook]}, + } + if base_url is not None: + kwargs["base_url"] = _normalize_base_url(base_url) + return httpx.Client(**kwargs) + + +def create_vercel_async_client( + token: str | None = None, + timeout: float | None = None, + base_url: str | None = None, + *, + client: httpx.AsyncClient | None = None, +) -> httpx.AsyncClient: + """Create or configure an async httpx client for Vercel API.""" + resolved_token = _require_token(token) + auth_hook = _create_vercel_auth_hook_async(resolved_token) + + if client is not None: + _prepend_request_hooks(client, [auth_hook]) + return client + + effective_timeout = timeout if timeout is not None else DEFAULT_TIMEOUT + kwargs: dict = { + "timeout": httpx.Timeout(effective_timeout), + "event_hooks": {"request": [auth_hook]}, + } + if base_url is not None: + kwargs["base_url"] = _normalize_base_url(base_url) + return httpx.AsyncClient(**kwargs) + + +def create_headers_client( + headers: Mapping[str, str], + timeout: float | None = None, + base_url: str | None = None, + *, + client: httpx.Client | None = None, +) -> httpx.Client: + """Create or configure a sync httpx client with static headers.""" + headers_hook = _create_static_headers_hook(headers) + + if client is not None: + _prepend_request_hooks(client, [headers_hook]) + return client + + effective_timeout = timeout if timeout is not None else DEFAULT_TIMEOUT + kwargs: dict = { + "timeout": httpx.Timeout(effective_timeout), + "event_hooks": {"request": [headers_hook]}, + } + if base_url is not None: + kwargs["base_url"] = _normalize_base_url(base_url) + return httpx.Client(**kwargs) + + +def create_headers_async_client( + 
headers: Mapping[str, str], + timeout: float | None = None, + base_url: str | None = None, + *, + client: httpx.AsyncClient | None = None, +) -> httpx.AsyncClient: + """Create or configure an async httpx client with static headers.""" + headers_hook = _create_static_headers_hook_async(headers) + + if client is not None: + _prepend_request_hooks(client, [headers_hook]) + return client + + effective_timeout = timeout if timeout is not None else DEFAULT_TIMEOUT + kwargs: dict = { + "timeout": httpx.Timeout(effective_timeout), + "event_hooks": {"request": [headers_hook]}, + } + if base_url is not None: + kwargs["base_url"] = _normalize_base_url(base_url) + return httpx.AsyncClient(**kwargs) + + +def create_base_client( + timeout: float | None = None, + base_url: str | None = None, +) -> httpx.Client: + """Create a sync httpx client without auth hooks.""" + effective_timeout = timeout if timeout is not None else DEFAULT_TIMEOUT + kwargs: dict = {"timeout": httpx.Timeout(effective_timeout)} + if base_url is not None: + kwargs["base_url"] = _normalize_base_url(base_url) + return httpx.Client(**kwargs) + + +def create_base_async_client( + timeout: float | None = None, + base_url: str | None = None, +) -> httpx.AsyncClient: + """Create an async httpx client without auth hooks.""" + effective_timeout = timeout if timeout is not None else DEFAULT_TIMEOUT + kwargs: dict = {"timeout": httpx.Timeout(effective_timeout)} + if base_url is not None: + kwargs["base_url"] = _normalize_base_url(base_url) + return httpx.AsyncClient(**kwargs) + + +__all__ = [ + "create_vercel_client", + "create_vercel_async_client", + "create_headers_client", + "create_headers_async_client", + "create_base_client", + "create_base_async_client", +] diff --git a/src/vercel/_internal/http/config.py b/src/vercel/_internal/http/config.py new file mode 100644 index 0000000..264e547 --- /dev/null +++ b/src/vercel/_internal/http/config.py @@ -0,0 +1,7 @@ +"""HTTP configuration constants for Vercel API 
clients.""" + +DEFAULT_API_BASE_URL = "https://api.vercel.com" +DEFAULT_TIMEOUT = 60.0 + + +__all__ = ["DEFAULT_API_BASE_URL", "DEFAULT_TIMEOUT"] diff --git a/src/vercel/_internal/http/transport.py b/src/vercel/_internal/http/transport.py new file mode 100644 index 0000000..d394b74 --- /dev/null +++ b/src/vercel/_internal/http/transport.py @@ -0,0 +1,163 @@ +"""HTTP transport implementations for sync and async clients.""" + +from __future__ import annotations + +import abc +from dataclasses import dataclass +from typing import Any + +import httpx + + +def _normalize_path(path: str) -> str: + return path.lstrip("/") + + +@dataclass(frozen=True, slots=True) +class JSONBody: + data: Any + + +@dataclass(frozen=True, slots=True) +class BytesBody: + data: bytes + content_type: str = "application/octet-stream" + + +@dataclass(frozen=True, slots=True) +class RawBody: + """Unmodified request content (bytes, iterables, async iterables, file-like, etc.).""" + + data: Any + + +RequestBody = JSONBody | BytesBody | RawBody | None + + +def _build_request_kwargs( + *, + params: dict[str, Any] | None, + body: RequestBody, + headers: dict[str, str] | None, +) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + + if params: + kwargs["params"] = params + + request_headers: dict[str, str] = {} + if headers: + request_headers.update(headers) + + if isinstance(body, JSONBody): + kwargs["json"] = body.data + elif isinstance(body, BytesBody): + kwargs["content"] = body.data + request_headers["Content-Type"] = body.content_type + elif isinstance(body, RawBody): + kwargs["content"] = body.data + + if request_headers: + kwargs["headers"] = request_headers + + return kwargs + + +class BaseTransport(abc.ABC): + @abc.abstractmethod + async def send( + self, + method: str, + path: str, + *, + params: dict[str, Any] | None = None, + body: RequestBody = None, + headers: dict[str, str] | None = None, + timeout: float | None = None, + follow_redirects: bool | None = None, + stream: bool = False, + ) 
-> httpx.Response: + raise NotImplementedError + + +class SyncTransport(BaseTransport): + """Sync transport with async interface for use with iter_coroutine().""" + + def __init__(self, client: httpx.Client) -> None: + self._client = client + + async def send( + self, + method: str, + path: str, + *, + params: dict[str, Any] | None = None, + body: RequestBody = None, + headers: dict[str, str] | None = None, + timeout: float | None = None, + follow_redirects: bool | None = None, + stream: bool = False, + ) -> httpx.Response: + kwargs = _build_request_kwargs( + params=params, + body=body, + headers=headers, + ) + + if timeout is not None: + kwargs["timeout"] = httpx.Timeout(timeout) + + request = self._client.build_request(method, _normalize_path(path), **kwargs) + send_kwargs: dict[str, Any] = {"stream": stream} + if follow_redirects is not None: + send_kwargs["follow_redirects"] = follow_redirects + return self._client.send(request, **send_kwargs) + + def close(self) -> None: + self._client.close() + + +class AsyncTransport(BaseTransport): + def __init__(self, client: httpx.AsyncClient) -> None: + self._client = client + + async def send( + self, + method: str, + path: str, + *, + params: dict[str, Any] | None = None, + body: RequestBody = None, + headers: dict[str, str] | None = None, + timeout: float | None = None, + follow_redirects: bool | None = None, + stream: bool = False, + ) -> httpx.Response: + kwargs = _build_request_kwargs( + params=params, + body=body, + headers=headers, + ) + + if timeout is not None: + kwargs["timeout"] = httpx.Timeout(timeout) + + request = self._client.build_request(method, _normalize_path(path), **kwargs) + send_kwargs: dict[str, Any] = {"stream": stream} + if follow_redirects is not None: + send_kwargs["follow_redirects"] = follow_redirects + return await self._client.send(request, **send_kwargs) + + async def aclose(self) -> None: + await self._client.aclose() + + +__all__ = [ + "BaseTransport", + "SyncTransport", + 
"AsyncTransport", + "JSONBody", + "BytesBody", + "RawBody", + "RequestBody", +] diff --git a/src/vercel/_internal/iter_coroutine.py b/src/vercel/_internal/iter_coroutine.py new file mode 100644 index 0000000..624a734 --- /dev/null +++ b/src/vercel/_internal/iter_coroutine.py @@ -0,0 +1,20 @@ +"""iter_coroutine - Run simple coroutines synchronously.""" + +import typing + +_T = typing.TypeVar("_T") + + +def iter_coroutine(coro: typing.Coroutine[None, None, _T]) -> _T: + """Execute a non-suspending coroutine synchronously.""" + try: + coro.send(None) + except StopIteration as ex: + return ex.value # type: ignore [no-any-return] + else: + raise RuntimeError(f"coroutine {coro!r} did not stop after one iteration!") + finally: + coro.close() + + +__all__ = ["iter_coroutine"] diff --git a/src/vercel/_internal/telemetry/__init__.py b/src/vercel/_internal/telemetry/__init__.py new file mode 100644 index 0000000..b1a1543 --- /dev/null +++ b/src/vercel/_internal/telemetry/__init__.py @@ -0,0 +1,6 @@ +"""Telemetry functionality for Vercel Python SDK (internal use).""" + +from vercel._internal.telemetry.client import TelemetryClient +from vercel._internal.telemetry.tracker import track + +__all__ = ["TelemetryClient", "track"] diff --git a/src/vercel/_telemetry/client.py b/src/vercel/_internal/telemetry/client.py similarity index 98% rename from src/vercel/_telemetry/client.py rename to src/vercel/_internal/telemetry/client.py index e9c1ea9..397eb24 100644 --- a/src/vercel/_telemetry/client.py +++ b/src/vercel/_internal/telemetry/client.py @@ -8,7 +8,7 @@ import httpx -from .credentials import extract_credentials +from vercel._internal.telemetry.credentials import extract_credentials _TELEMETRY_ENABLED = os.getenv("VERCEL_TELEMETRY_DISABLED") != "1" _TELEMETRY_BRIDGE_URL = os.getenv( diff --git a/src/vercel/_telemetry/credentials.py b/src/vercel/_internal/telemetry/credentials.py similarity index 94% rename from src/vercel/_telemetry/credentials.py rename to 
src/vercel/_internal/telemetry/credentials.py index 6d7672f..cfc1098 100644 --- a/src/vercel/_telemetry/credentials.py +++ b/src/vercel/_internal/telemetry/credentials.py @@ -40,7 +40,7 @@ def extract_credentials( # Try to extract from OIDC token if available if token: try: - from ..oidc.token import decode_oidc_payload + from vercel.oidc.token import decode_oidc_payload payload = decode_oidc_payload(token) if not resolved_project_id: @@ -56,7 +56,7 @@ def extract_credentials( if not resolved_project_id or not resolved_team_id: try: # Import lazily to avoid hard dependency in all environments - from ..oidc.utils import find_project_info as _find_project_info # type: ignore + from vercel.oidc.utils import find_project_info as _find_project_info # type: ignore project_info = _find_project_info() if not resolved_project_id and project_info.get("projectId"): diff --git a/src/vercel/_telemetry/tracker.py b/src/vercel/_internal/telemetry/tracker.py similarity index 98% rename from src/vercel/_telemetry/tracker.py rename to src/vercel/_internal/telemetry/tracker.py index d5c5e8f..bb94d13 100644 --- a/src/vercel/_telemetry/tracker.py +++ b/src/vercel/_internal/telemetry/tracker.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, TypeVar if TYPE_CHECKING: - from .client import TelemetryClient + from vercel._internal.telemetry.client import TelemetryClient # Singleton telemetry client instance with thread-safe initialization _telemetry_client = None @@ -34,7 +34,7 @@ def get_client() -> TelemetryClient | None: client = _telemetry_client if client is None: try: - from .client import TelemetryClient + from vercel._internal.telemetry.client import TelemetryClient _telemetry_client = TelemetryClient() except Exception: diff --git a/src/vercel/_telemetry/__init__.py b/src/vercel/_telemetry/__init__.py deleted file mode 100644 index 17978e2..0000000 --- a/src/vercel/_telemetry/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Telemetry functionality for Vercel Python 
SDK (internal use).""" - -from .client import TelemetryClient -from .tracker import track - -__all__ = ["TelemetryClient", "track"] diff --git a/src/vercel/blob/__init__.py b/src/vercel/blob/__init__.py index 3f78bf2..3e2c583 100644 --- a/src/vercel/blob/__init__.py +++ b/src/vercel/blob/__init__.py @@ -1,3 +1,5 @@ +from vercel._internal.blob import get_download_url + from . import aio as aioblob from .client import ( AsyncBlobClient, @@ -9,6 +11,7 @@ BlobContentTypeNotAllowedError, BlobError, BlobFileTooLargeError, + BlobInvalidResponseJSONError, BlobNotFoundError, BlobPathnameMismatchError, BlobRequestAbortedError, @@ -16,6 +19,7 @@ BlobServiceRateLimited, BlobStoreNotFoundError, BlobStoreSuspendedError, + BlobUnexpectedResponseContentTypeError, BlobUnknownError, ) from .multipart import ( @@ -55,6 +59,7 @@ upload_file_async, ) from .types import ( + Access, CreateFolderResult, GetBlobResult, HeadBlobResult, @@ -62,9 +67,10 @@ ListBlobResult, MultipartCreateResult, MultipartPart, + OnUploadProgressCallback, PutBlobResult, + UploadProgressEvent, ) -from .utils import Access, OnUploadProgressCallback, UploadProgressEvent, get_download_url __all__ = [ # errors @@ -78,6 +84,8 @@ "BlobStoreSuspendedError", "BlobUnknownError", "BlobNotFoundError", + "BlobUnexpectedResponseContentTypeError", + "BlobInvalidResponseJSONError", "BlobServiceNotAvailable", "BlobServiceRateLimited", "BlobRequestAbortedError", diff --git a/src/vercel/blob/aio.py b/src/vercel/blob/aio.py index 9868b84..0b8a49e 100644 --- a/src/vercel/blob/aio.py +++ b/src/vercel/blob/aio.py @@ -1,3 +1,5 @@ +from vercel._internal.blob import get_download_url + from .client import ( AsyncBlobClient, ) @@ -35,14 +37,16 @@ upload_file_async as upload_file, ) from .types import ( + Access, CreateFolderResult, GetBlobResult, HeadBlobResult, ListBlobItem, ListBlobResult, + OnUploadProgressCallback, PutBlobResult, + UploadProgressEvent, ) -from .utils import Access, OnUploadProgressCallback, UploadProgressEvent, 
get_download_url __all__ = [ # errors diff --git a/src/vercel/blob/api.py b/src/vercel/blob/api.py deleted file mode 100644 index 50beabd..0000000 --- a/src/vercel/blob/api.py +++ /dev/null @@ -1,325 +0,0 @@ -from __future__ import annotations - -import asyncio -import time -from collections.abc import Awaitable, Callable -from typing import Any, cast - -import httpx - -from .errors import ( - BlobAccessError, - BlobClientTokenExpiredError, - BlobContentTypeNotAllowedError, - BlobError, - BlobFileTooLargeError, - BlobNotFoundError, - BlobPathnameMismatchError, - BlobServiceNotAvailable, - BlobServiceRateLimited, - BlobStoreNotFoundError, - BlobStoreSuspendedError, - BlobUnknownError, -) -from .utils import ( - PutHeaders, - StreamingBodyWithProgress, - UploadProgressEvent, - compute_body_length, - debug, - ensure_token, - extract_store_id_from_token, - get_api_url, - get_api_version, - get_proxy_through_alternative_api_header_from_env, - get_retries, - make_request_id, - parse_rfc7231_retry_after, - should_use_x_content_length, -) - - -def _map_error(response: httpx.Response) -> tuple[str, BlobError]: - try: - data = response.json() - except Exception: - data = {} - - code = (data.get("error") or {}).get("code") or "unknown_error" - message = (data.get("error") or {}).get("message") or "" - - # Heuristics mirroring TS SDK: https://github.com/vercel/storage/blob/main/packages/blob/src/api.ts - if "contentType" in message and "is not allowed" in message: - code = "content_type_not_allowed" - if '"pathname"' in message and "does not match the token payload" in message: - code = "client_token_pathname_mismatch" - if message == "Token expired": - code = "client_token_expired" - if "the file length cannot be greater than" in message: - code = "file_too_large" - - if code == "store_suspended": - return code, BlobStoreSuspendedError() - if code == "forbidden": - return code, BlobAccessError() - if code == "content_type_not_allowed": - return code, 
BlobContentTypeNotAllowedError(message or "") - if code == "client_token_pathname_mismatch": - return code, BlobPathnameMismatchError(message or "") - if code == "client_token_expired": - return code, BlobClientTokenExpiredError() - if code == "file_too_large": - return code, BlobFileTooLargeError(message or "") - if code == "not_found": - return code, BlobNotFoundError() - if code == "store_not_found": - return code, BlobStoreNotFoundError() - if code == "bad_request": - return code, BlobError(message or "Bad request") - if code == "service_unavailable": - return code, BlobServiceNotAvailable() - if code == "rate_limited": - seconds = parse_rfc7231_retry_after(response.headers.get("retry-after")) - return code, BlobServiceRateLimited(seconds) - - return code, BlobUnknownError() - - -def _should_retry(code: str) -> bool: - return code in {"unknown_error", "service_unavailable", "internal_server_error"} - - -def _is_network_error(exc: Exception) -> bool: - return isinstance(exc, httpx.TransportError) - - -def request_api( - pathname: str, - method: str, - *, - token: str | None = None, - headers: PutHeaders | dict[str, str] | None = None, - params: dict[str, Any] | None = None, - body: Any = None, - on_upload_progress: Callable[[UploadProgressEvent], None] | None = None, - timeout: float | None = None, -) -> Any: - """Synchronous HTTP caller with retries, headers, progress and error mapping.""" - token = ensure_token(token) - store_id = extract_store_id_from_token(token) - request_id = make_request_id(store_id) - attempt = 0 - retries = get_retries() - api_version = get_api_version() - extra_headers = get_proxy_through_alternative_api_header_from_env() - headers = cast(dict[str, str], headers or {}) - - send_body_length = bool(on_upload_progress) or should_use_x_content_length() - total_length = compute_body_length(body) if send_body_length else 0 - - if on_upload_progress: - on_upload_progress(UploadProgressEvent(loaded=0, total=total_length, percentage=0.0)) - - 
url = get_api_url(pathname) - timeout_conf = httpx.Timeout(timeout) if timeout is not None else httpx.Timeout(30.0) - - with httpx.Client(timeout=timeout_conf) as client: - for attempt in range(retries + 1): - try: - final_headers = { - "authorization": f"Bearer {token}", - "x-api-blob-request-id": request_id, - "x-api-blob-request-attempt": str(attempt), - "x-api-version": api_version, - **extra_headers, - } - if headers: - final_headers.update(headers) - if send_body_length and total_length: - final_headers["x-content-length"] = str(total_length) - - json_body = None - content = None - - if body is not None: - if isinstance(body, (bytes, bytearray, memoryview, str)) or hasattr( - body, "read" - ): - content = StreamingBodyWithProgress(body, on_upload_progress) - else: - json_body = body - - resp = client.request( - method=method, - url=url, - headers=final_headers, - params=params, - content=content, - json=json_body, - ) - - if 200 <= resp.status_code < 300: - if on_upload_progress: - on_upload_progress( - UploadProgressEvent( - loaded=total_length or 0, - total=total_length or 0, - percentage=100.0, - ) - ) - content_type = resp.headers.get("content-type", "") - if "application/json" in content_type or (resp.text or "").startswith("{"): - try: - return resp.json() - except Exception: - return resp.text - try: - return resp.json() - except Exception: - return resp.text - - code, mapped = _map_error(resp) - if _should_retry(code) and attempt < retries: - debug(f"retrying API request to {pathname}", f"{code}") - time.sleep(min(2**attempt * 0.1, 2.0)) - continue - raise mapped - - except Exception as exc: - if _is_network_error(exc) and attempt < retries: - debug(f"retrying API request to {pathname}", str(exc)) - time.sleep(min(2**attempt * 0.1, 2.0)) - continue - if isinstance(exc, httpx.HTTPError): - raise BlobUnknownError() from exc - raise - - raise BlobUnknownError() - - -async def request_api_async( - pathname: str, - method: str, - *, - token: str | None = 
None, - headers: PutHeaders | dict[str, str] | None = None, - params: dict[str, Any] | None = None, - body: Any = None, - on_upload_progress: ( - Callable[[UploadProgressEvent], None] - | Callable[[UploadProgressEvent], Awaitable[None]] - | None - ) = None, - timeout: float | None = None, -) -> Any: - """Core HTTP caller with retries, headers, progress and error mapping.""" - token = ensure_token(token) - store_id = extract_store_id_from_token(token) - request_id = make_request_id(store_id) - attempt = 0 - retries = get_retries() - api_version = get_api_version() - extra_headers = get_proxy_through_alternative_api_header_from_env() - headers = cast(dict[str, str], headers or {}) - send_body_length = bool(on_upload_progress) or should_use_x_content_length() - total_length = compute_body_length(body) if send_body_length else 0 - - if on_upload_progress: - result = on_upload_progress( - UploadProgressEvent(loaded=0, total=total_length, percentage=0.0) - ) - # If callback is async, await it - if asyncio.iscoroutine(result): - await result - - url = get_api_url(pathname) - timeout_conf = httpx.Timeout(timeout) if timeout is not None else httpx.Timeout(30.0) - async with httpx.AsyncClient(timeout=timeout_conf) as client: - for attempt in range(retries + 1): - try: - final_headers = { - "authorization": f"Bearer {token}", - "x-api-blob-request-id": request_id, - "x-api-blob-request-attempt": str(attempt), - "x-api-version": api_version, - **extra_headers, - } - if headers: - final_headers.update(headers) - if send_body_length and total_length: - final_headers["x-content-length"] = str(total_length) - - json_body = None - content = None - - # Wrap body for progress when possible - if body is not None: - if isinstance(body, (bytes, bytearray, memoryview, str)) or hasattr( - body, "read" - ): - wrapped = StreamingBodyWithProgress(body, on_upload_progress) - # For AsyncClient, ensure async streaming content to avoid sync-body error - content = wrapped.__aiter__() - else: - # 
For objects meant to be JSON - json_body = body - - if content is not None: - resp = await client.request( - method=method, - url=url, - headers=final_headers, - params=params, - content=content, - json=json_body, - ) - else: - resp = await client.request( - method=method, - url=url, - headers=final_headers, - params=params, - json=json_body, - ) - - if 200 <= resp.status_code < 300: - if on_upload_progress: - result = on_upload_progress( - UploadProgressEvent( - loaded=total_length or 0, - total=total_length or 0, - percentage=100.0, - ) - ) - # If callback is async, await it - if asyncio.iscoroutine(result): - await result - content_type = resp.headers.get("content-type", "") - if "application/json" in content_type or resp.text.startswith("{"): - try: - return resp.json() - except Exception: - return resp.text - try: - return resp.json() - except Exception: - return resp.text - - code, mapped = _map_error(resp) - if _should_retry(code) and attempt < retries: - debug(f"retrying API request to {pathname}", f"{code}") - await asyncio.sleep(min(2**attempt * 0.1, 2.0)) - continue - raise mapped - - except Exception as exc: - # If it's an httpx transport error, treat as network and retry; else raise - if _is_network_error(exc) and attempt < retries: - debug(f"retrying API request to {pathname}", str(exc)) - await asyncio.sleep(min(2**attempt * 0.1, 2.0)) - continue - if isinstance(exc, httpx.HTTPError): - raise BlobUnknownError() from exc - raise - - raise BlobUnknownError() diff --git a/src/vercel/blob/client.py b/src/vercel/blob/client.py index 037a87e..85e3a9e 100644 --- a/src/vercel/blob/client.py +++ b/src/vercel/blob/client.py @@ -1,52 +1,65 @@ from __future__ import annotations import os -from collections.abc import Awaitable, Callable, Iterable, Iterator +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator from os import PathLike from typing import Any -from .errors import BlobNoTokenProvidedError -from .multipart.api import 
create_multipart_uploader, create_multipart_uploader_async -from .ops import ( - copy, - copy_async, - create_folder, - create_folder_async, - delete, - delete_async, - download_file, - download_file_async, - get, - get_async, - head, - head_async, - iter_objects, - iter_objects_async, - list_objects, - list_objects_async, - put, - put_async, - upload_file, - upload_file_async, +from vercel._internal.blob import ensure_token +from vercel._internal.blob.core import ( + AsyncBlobOpsClient, + SyncBlobOpsClient, + normalize_delete_urls, ) -from .types import ( +from vercel._internal.blob.multipart import MultipartClient +from vercel._internal.iter_coroutine import iter_coroutine +from vercel.blob.errors import BlobError, BlobNoTokenProvidedError +from vercel.blob.multipart.api import ( + AsyncMultipartUploader, + MultipartUploader, + create_multipart_uploader, + create_multipart_uploader_async, +) +from vercel.blob.types import ( + Access, CreateFolderResult as CreateFolderResultType, GetBlobResult as GetBlobResultType, HeadBlobResult as HeadBlobResultType, ListBlobItem, ListBlobResult as ListBlobResultType, PutBlobResult as PutBlobResultType, + UploadProgressEvent, ) -from .utils import Access, UploadProgressEvent class BlobClient: def __init__(self, token: str | None = None): - self.token = ( + resolved_token = ( token or os.getenv("BLOB_READ_WRITE_TOKEN") or os.getenv("VERCEL_BLOB_READ_WRITE_TOKEN") ) - if not self.token: + if not resolved_token: raise BlobNoTokenProvidedError() + self.token = ensure_token(resolved_token) + + self._ops_client = SyncBlobOpsClient() + self._closed = False + + def _ensure_open(self) -> None: + if self._closed: + raise BlobError("Client is closed") + + def close(self) -> None: + if self._closed: + return + self._ops_client.close() + self._closed = True + + def __enter__(self) -> BlobClient: + self._ensure_open() + return self + + def __exit__(self, *args: object) -> None: + self.close() def put( self, @@ -61,18 +74,22 @@ def put( 
multipart: bool = False, on_upload_progress: Callable[[UploadProgressEvent], None] | None = None, ) -> PutBlobResultType: - return put( - path=path, - body=body, - access=access, - content_type=content_type, - add_random_suffix=add_random_suffix, - overwrite=overwrite, - cache_control_max_age=cache_control_max_age, - token=self.token, - multipart=multipart, - on_upload_progress=on_upload_progress, + self._ensure_open() + result, _ = iter_coroutine( + self._ops_client.put_blob( + path, + body, + access=access, + content_type=content_type, + add_random_suffix=add_random_suffix, + overwrite=overwrite, + cache_control_max_age=cache_control_max_age, + token=self.token, + multipart=multipart, + on_upload_progress=on_upload_progress, + ) ) + return result def get( self, @@ -83,20 +100,37 @@ def get( use_cache: bool = True, if_none_match: str | None = None, ) -> GetBlobResultType: - return get( - url_or_path, - access=access, - token=self.token, - timeout=timeout, - use_cache=use_cache, - if_none_match=if_none_match, + self._ensure_open() + return iter_coroutine( + self._ops_client.get_blob( + url_or_path, + access=access, + token=self.token, + timeout=timeout, + use_cache=use_cache, + if_none_match=if_none_match, + default_timeout=30.0, + ) ) def head(self, url_or_path: str) -> HeadBlobResultType: - return head(url_or_path, token=self.token) + self._ensure_open() + return iter_coroutine( + self._ops_client.head_blob( + url_or_path, + token=self.token, + ) + ) def delete(self, url_or_path: str | Iterable[str]) -> None: - return delete(url_or_path, token=self.token) + self._ensure_open() + normalized_urls = normalize_delete_urls(url_or_path) + iter_coroutine( + self._ops_client.delete_blob( + normalized_urls, + token=self.token, + ) + ) def list_objects( self, @@ -106,7 +140,14 @@ def list_objects( cursor: str | None = None, mode: str | None = None, ) -> ListBlobResultType: - return list_objects(limit=limit, prefix=prefix, cursor=cursor, mode=mode, token=self.token) + 
self._ensure_open() + return self._ops_client.list_objects( + limit=limit, + prefix=prefix, + cursor=cursor, + mode=mode, + token=self.token, + ) def iter_objects( self, @@ -117,7 +158,8 @@ def iter_objects( limit: int | None = None, cursor: str | None = None, ) -> Iterator[ListBlobItem]: - return iter_objects( + self._ensure_open() + return self._ops_client.iter_objects( prefix=prefix, mode=mode, token=self.token, @@ -137,19 +179,29 @@ def copy( overwrite: bool = False, cache_control_max_age: int | None = None, ) -> PutBlobResultType: - return copy( - src_path, - dst_path, - access=access, - content_type=content_type, - add_random_suffix=add_random_suffix, - overwrite=overwrite, - cache_control_max_age=cache_control_max_age, - token=self.token, + self._ensure_open() + return iter_coroutine( + self._ops_client.copy_blob( + src_path, + dst_path, + access=access, + content_type=content_type, + add_random_suffix=add_random_suffix, + overwrite=overwrite, + cache_control_max_age=cache_control_max_age, + token=self.token, + ) ) def create_folder(self, path: str, *, overwrite: bool = False) -> CreateFolderResultType: - return create_folder(path, token=self.token, overwrite=overwrite) + self._ensure_open() + return iter_coroutine( + self._ops_client.create_folder( + path, + token=self.token, + overwrite=overwrite, + ) + ) def download_file( self, @@ -162,15 +214,18 @@ def download_file( create_parents: bool = True, progress: Callable[[int, int | None], None] | None = None, ) -> str: - return download_file( - url_or_path, - local_path, - access=access, - token=self.token, - timeout=timeout, - overwrite=overwrite, - create_parents=create_parents, - progress=progress, + self._ensure_open() + return iter_coroutine( + self._ops_client.download_file( + url_or_path, + local_path, + access=access, + token=self.token, + timeout=timeout, + overwrite=overwrite, + create_parents=create_parents, + progress=progress, + ) ) def upload_file( @@ -186,17 +241,21 @@ def upload_file( 
multipart: bool = False, on_upload_progress: Callable[[UploadProgressEvent], None] | None = None, ) -> PutBlobResultType: - return upload_file( - local_path, - path, - access=access, - content_type=content_type, - add_random_suffix=add_random_suffix, - overwrite=overwrite, - cache_control_max_age=cache_control_max_age, - token=self.token, - multipart=multipart, - on_upload_progress=on_upload_progress, + self._ensure_open() + return iter_coroutine( + self._ops_client.upload_file( + local_path, + path, + access=access, + content_type=content_type, + add_random_suffix=add_random_suffix, + overwrite=overwrite, + cache_control_max_age=cache_control_max_age, + token=self.token, + multipart=multipart, + on_upload_progress=on_upload_progress, + missing_local_path_error="src_path is required", + ) ) def create_multipart_uploader( @@ -208,8 +267,9 @@ def create_multipart_uploader( add_random_suffix: bool = True, overwrite: bool = False, cache_control_max_age: int | None = None, - ): + ) -> MultipartUploader: """Create a multipart uploader bound to this client's token.""" + self._ensure_open() return create_multipart_uploader( path, access=access, @@ -218,16 +278,38 @@ def create_multipart_uploader( overwrite=overwrite, cache_control_max_age=cache_control_max_age, token=self.token, + multipart_client=MultipartClient(self._ops_client._request_client), ) class AsyncBlobClient: def __init__(self, token: str | None = None): - self.token = ( + resolved_token = ( token or os.getenv("BLOB_READ_WRITE_TOKEN") or os.getenv("VERCEL_BLOB_READ_WRITE_TOKEN") ) - if not self.token: + if not resolved_token: raise BlobNoTokenProvidedError() + self.token = ensure_token(resolved_token) + + self._ops_client = AsyncBlobOpsClient() + self._closed = False + + def _ensure_open(self) -> None: + if self._closed: + raise BlobError("Client is closed") + + async def aclose(self) -> None: + if self._closed: + return + await self._ops_client.aclose() + self._closed = True + + async def __aenter__(self) -> 
AsyncBlobClient: + self._ensure_open() + return self + + async def __aexit__(self, *args: object) -> None: + await self.aclose() async def put( self, @@ -246,9 +328,10 @@ async def put( | None ) = None, ) -> PutBlobResultType: - return await put_async( - path=path, - body=body, + self._ensure_open() + result, _ = await self._ops_client.put_blob( + path, + body, access=access, content_type=content_type, add_random_suffix=add_random_suffix, @@ -258,6 +341,7 @@ async def put( multipart=multipart, on_upload_progress=on_upload_progress, ) + return result async def get( self, @@ -268,20 +352,31 @@ async def get( use_cache: bool = True, if_none_match: str | None = None, ) -> GetBlobResultType: - return await get_async( + self._ensure_open() + return await self._ops_client.get_blob( url_or_path, access=access, token=self.token, timeout=timeout, use_cache=use_cache, if_none_match=if_none_match, + default_timeout=30.0, ) async def head(self, url_or_path: str) -> HeadBlobResultType: - return await head_async(url_or_path, token=self.token) + self._ensure_open() + return await self._ops_client.head_blob( + url_or_path, + token=self.token, + ) async def delete(self, url_or_path: str | Iterable[str]) -> None: - return await delete_async(url_or_path, token=self.token) + self._ensure_open() + normalized_urls = normalize_delete_urls(url_or_path) + await self._ops_client.delete_blob( + normalized_urls, + token=self.token, + ) async def iter_objects( self, @@ -291,8 +386,9 @@ async def iter_objects( batch_size: int | None = None, limit: int | None = None, cursor: str | None = None, - ): - return iter_objects_async( + ) -> AsyncIterator[ListBlobItem]: + self._ensure_open() + return self._ops_client.iter_objects( prefix=prefix, mode=mode, token=self.token, @@ -309,12 +405,22 @@ async def list_objects( cursor: str | None = None, mode: str | None = None, ) -> ListBlobResultType: - return await list_objects_async( - limit=limit, prefix=prefix, cursor=cursor, mode=mode, token=self.token + 
self._ensure_open() + return await self._ops_client.list_objects( + limit=limit, + prefix=prefix, + cursor=cursor, + mode=mode, + token=self.token, ) async def create_folder(self, path: str, *, overwrite: bool = False) -> CreateFolderResultType: - return await create_folder_async(path, token=self.token, overwrite=overwrite) + self._ensure_open() + return await self._ops_client.create_folder( + path, + token=self.token, + overwrite=overwrite, + ) async def copy( self, @@ -327,7 +433,8 @@ async def copy( overwrite: bool = False, cache_control_max_age: int | None = None, ) -> PutBlobResultType: - return await copy_async( + self._ensure_open() + return await self._ops_client.copy_blob( src_path, dst_path, access=access, @@ -351,7 +458,8 @@ async def download_file( Callable[[int, int | None], None] | Callable[[int, int | None], Awaitable[None]] | None ) = None, ) -> str: - return await download_file_async( + self._ensure_open() + return await self._ops_client.download_file( url_or_path, local_path, access=access, @@ -379,7 +487,8 @@ async def upload_file( | None ) = None, ) -> PutBlobResultType: - return await upload_file_async( + self._ensure_open() + return await self._ops_client.upload_file( local_path, path, access=access, @@ -390,6 +499,7 @@ async def upload_file( token=self.token, multipart=multipart, on_upload_progress=on_upload_progress, + missing_local_path_error="local_path is required", ) async def create_multipart_uploader( @@ -401,8 +511,9 @@ async def create_multipart_uploader( add_random_suffix: bool = True, overwrite: bool = False, cache_control_max_age: int | None = None, - ): + ) -> AsyncMultipartUploader: """Create an async multipart uploader bound to this client's token.""" + self._ensure_open() return await create_multipart_uploader_async( path, access=access, @@ -411,4 +522,5 @@ async def create_multipart_uploader( overwrite=overwrite, cache_control_max_age=cache_control_max_age, token=self.token, + 
multipart_client=MultipartClient(self._ops_client._request_client), ) diff --git a/src/vercel/blob/errors.py b/src/vercel/blob/errors.py index b7bb3e2..41fa0cc 100644 --- a/src/vercel/blob/errors.py +++ b/src/vercel/blob/errors.py @@ -1,82 +1,37 @@ -class BlobError(Exception): - """Base class for Vercel Blob SDK errors.""" - - def __init__(self, message: str) -> None: - super().__init__(f"Vercel Blob: {message}") - - -class BlobAccessError(BlobError): - def __init__(self) -> None: - super().__init__("Access denied, please provide a valid token for this resource.") - - -class BlobNoTokenProvidedError(BlobError): - def __init__(self) -> None: - super().__init__( - "No token found. Either configure the `BLOB_READ_WRITE_TOKEN` " - "or `VERCEL_BLOB_READ_WRITE_TOKEN` environment variable, " - "or pass a `token` option to your calls." - ) - - -class BlobContentTypeNotAllowedError(BlobError): - def __init__(self, message: str) -> None: - super().__init__(f"Content type mismatch, {message}.") - - -class BlobPathnameMismatchError(BlobError): - def __init__(self, message: str) -> None: - super().__init__( - f"Pathname mismatch, {message}. " - "Check the pathname used in upload() or put() " - "matches the one from the client token." 
- ) - - -class BlobClientTokenExpiredError(BlobError): - def __init__(self) -> None: - super().__init__("Client token has expired.") - - -class BlobFileTooLargeError(BlobError): - def __init__(self, message: str) -> None: - super().__init__(f"File is too large, {message}.") - - -class BlobStoreNotFoundError(BlobError): - def __init__(self) -> None: - super().__init__("This store does not exist.") - - -class BlobStoreSuspendedError(BlobError): - def __init__(self) -> None: - super().__init__("This store has been suspended.") - - -class BlobUnknownError(BlobError): - def __init__(self) -> None: - super().__init__("Unknown error, please visit https://vercel.com/help.") - - -class BlobNotFoundError(BlobError): - def __init__(self) -> None: - super().__init__("The requested blob does not exist") - - -class BlobServiceNotAvailable(BlobError): - def __init__(self) -> None: - super().__init__("The blob service is currently not available. Please try again.") - - -class BlobServiceRateLimited(BlobError): - def __init__(self, seconds: int | None = None) -> None: - retry = f" - try again in {seconds} seconds" if seconds else "" - super().__init__( - f"Too many requests please lower the number of concurrent requests{retry}." 
- ) - self.retry_after: int = seconds or 0 - - -class BlobRequestAbortedError(BlobError): - def __init__(self) -> None: - super().__init__("The request was aborted.") +from vercel._internal.blob.errors import ( + BlobAccessError, + BlobClientTokenExpiredError, + BlobContentTypeNotAllowedError, + BlobError, + BlobFileTooLargeError, + BlobInvalidResponseJSONError, + BlobNotFoundError, + BlobNoTokenProvidedError, + BlobPathnameMismatchError, + BlobRequestAbortedError, + BlobServiceNotAvailable, + BlobServiceRateLimited, + BlobStoreNotFoundError, + BlobStoreSuspendedError, + BlobUnexpectedResponseContentTypeError, + BlobUnknownError, +) + +__all__ = [ + "BlobAccessError", + "BlobClientTokenExpiredError", + "BlobContentTypeNotAllowedError", + "BlobError", + "BlobFileTooLargeError", + "BlobInvalidResponseJSONError", + "BlobNoTokenProvidedError", + "BlobNotFoundError", + "BlobPathnameMismatchError", + "BlobRequestAbortedError", + "BlobServiceNotAvailable", + "BlobServiceRateLimited", + "BlobStoreNotFoundError", + "BlobStoreSuspendedError", + "BlobUnexpectedResponseContentTypeError", + "BlobUnknownError", +] diff --git a/src/vercel/blob/multipart/api.py b/src/vercel/blob/multipart/api.py index a5a9db5..259407c 100644 --- a/src/vercel/blob/multipart/api.py +++ b/src/vercel/blob/multipart/api.py @@ -4,27 +4,89 @@ from collections.abc import Awaitable, Callable from typing import Any, cast -from ..errors import BlobError -from ..types import MultipartCreateResult, MultipartPart, PutBlobResult -from ..utils import ( - Access, +from vercel._internal.blob import ( PutHeaders, - UploadProgressEvent, create_put_headers, ensure_token, validate_access, validate_path, ) -from .core import ( - call_complete_multipart_upload, - call_complete_multipart_upload_async, - call_create_multipart_upload, - call_create_multipart_upload_async, - call_upload_part, - call_upload_part_async, +from vercel._internal.blob.multipart import ( + AsyncMultipartClient, + MultipartClient, + 
SyncMultipartClient, +) +from vercel._internal.iter_coroutine import iter_coroutine +from vercel.blob.errors import BlobError +from vercel.blob.types import ( + Access, + MultipartCreateResult, + MultipartPart, + PutBlobResult, + UploadProgressEvent, ) +def _validate_multipart_context(path: str, access: Access, token: str | None) -> str: + resolved_token = ensure_token(token) + validate_path(path) + validate_access(access) + return resolved_token + + +def _build_put_headers( + *, + access: Access, + content_type: str | None = None, + add_random_suffix: bool = False, + overwrite: bool = False, + cache_control_max_age: int | None = None, +) -> dict[str, str]: + return cast( + dict[str, str], + create_put_headers( + content_type=content_type, + add_random_suffix=add_random_suffix, + allow_overwrite=overwrite, + cache_control_max_age=cache_control_max_age, + access=access, + ), + ) + + +def _build_multipart_create_result(response: dict[str, Any]) -> MultipartCreateResult: + return MultipartCreateResult(upload_id=response["uploadId"], key=response["key"]) + + +def _build_multipart_part_result(part_number: int, response: dict[str, Any]) -> MultipartPart: + return MultipartPart(part_number=part_number, etag=response["etag"]) + + +def _build_put_blob_result(response: dict[str, Any]) -> PutBlobResult: + return PutBlobResult( + url=response["url"], + download_url=response["downloadUrl"], + pathname=response["pathname"], + content_type=response["contentType"], + content_disposition=response["contentDisposition"], + ) + + +def _normalize_complete_parts(parts: list[MultipartPart]) -> list[dict[str, Any]]: + return [{"partNumber": part.part_number, "etag": part.etag} for part in parts] + + +def _validate_part_upload_inputs(part_number: int, body: Any) -> None: + if part_number < 1 or part_number > 10000: + raise BlobError("part_number must be between 1 and 10,000") + + if isinstance(body, dict) and not hasattr(body, "read"): + raise BlobError( + "Body must be a string, bytes, or 
file-like object. " + "You sent a plain dictionary, double check what you're trying to upload." + ) + + def create_multipart_upload( path: str, *, @@ -35,20 +97,18 @@ def create_multipart_upload( cache_control_max_age: int | None = None, token: str | None = None, ) -> MultipartCreateResult: - token = ensure_token(token) - validate_path(path) - validate_access(access) - - headers = create_put_headers( + resolved_token = _validate_multipart_context(path, access, token) + headers = _build_put_headers( + access=access, content_type=content_type, add_random_suffix=add_random_suffix, - allow_overwrite=overwrite, + overwrite=overwrite, cache_control_max_age=cache_control_max_age, - access=access, ) - - resp = call_create_multipart_upload(path, headers, token=token) - return MultipartCreateResult(upload_id=resp["uploadId"], key=resp["key"]) + response = iter_coroutine( + SyncMultipartClient().create_multipart_upload(path, headers, token=resolved_token) + ) + return _build_multipart_create_result(response) async def create_multipart_upload_async( @@ -61,18 +121,18 @@ async def create_multipart_upload_async( cache_control_max_age: int | None = None, token: str | None = None, ) -> MultipartCreateResult: - token = ensure_token(token) - validate_path(path) - validate_access(access) - headers = create_put_headers( + resolved_token = _validate_multipart_context(path, access, token) + headers = _build_put_headers( + access=access, content_type=content_type, add_random_suffix=add_random_suffix, - allow_overwrite=overwrite, + overwrite=overwrite, cache_control_max_age=cache_control_max_age, - access=access, ) - resp = await call_create_multipart_upload_async(path, headers, token=token) - return MultipartCreateResult(upload_id=resp["uploadId"], key=resp["key"]) + response = await AsyncMultipartClient().create_multipart_upload( + path, headers, token=resolved_token + ) + return _build_multipart_create_result(response) def upload_part( @@ -87,22 +147,22 @@ def upload_part( 
content_type: str | None = None, on_upload_progress: Callable[[UploadProgressEvent], None] | None = None, ) -> MultipartPart: - token = ensure_token(token) - validate_path(path) - validate_access(access) - - headers = create_put_headers(content_type=content_type, access=access) - resp = call_upload_part( - upload_id=upload_id, - key=key, - path=path, - headers=headers, - token=token, - part_number=part_number, - body=body, - on_upload_progress=on_upload_progress, + resolved_token = _validate_multipart_context(path, access, token) + _validate_part_upload_inputs(part_number, body) + headers = _build_put_headers(access=access, content_type=content_type) + response = iter_coroutine( + SyncMultipartClient().upload_part( + upload_id=upload_id, + key=key, + path=path, + headers=headers, + part_number=part_number, + body=body, + on_upload_progress=on_upload_progress, + token=resolved_token, + ), ) - return MultipartPart(part_number=part_number, etag=resp["etag"]) + return _build_multipart_part_result(part_number, response) async def upload_part_async( @@ -121,22 +181,20 @@ async def upload_part_async( | None ) = None, ) -> MultipartPart: - token = ensure_token(token) - validate_path(path) - validate_access(access) - - headers = create_put_headers(content_type=content_type, access=access) - resp = await call_upload_part_async( + resolved_token = _validate_multipart_context(path, access, token) + _validate_part_upload_inputs(part_number, body) + headers = _build_put_headers(access=access, content_type=content_type) + response = await AsyncMultipartClient().upload_part( upload_id=upload_id, key=key, path=path, headers=headers, - token=token, part_number=part_number, body=body, on_upload_progress=on_upload_progress, + token=resolved_token, ) - return MultipartPart(part_number=part_number, etag=resp["etag"]) + return _build_multipart_part_result(part_number, response) def complete_multipart_upload( @@ -149,26 +207,19 @@ def complete_multipart_upload( upload_id: str, key: str, ) 
-> PutBlobResult: - token = ensure_token(token) - validate_path(path) - validate_access(access) - headers = create_put_headers(content_type=content_type, access=access) - - resp = call_complete_multipart_upload( - upload_id=upload_id, - key=key, - path=path, - headers=headers, - token=token, - parts=[{"partNumber": p.part_number, "etag": p.etag} for p in parts], - ) - return PutBlobResult( - url=resp["url"], - download_url=resp["downloadUrl"], - pathname=resp["pathname"], - content_type=resp["contentType"], - content_disposition=resp["contentDisposition"], + resolved_token = _validate_multipart_context(path, access, token) + headers = _build_put_headers(access=access, content_type=content_type) + response = iter_coroutine( + SyncMultipartClient().complete_multipart_upload( + upload_id=upload_id, + key=key, + path=path, + headers=headers, + parts=_normalize_complete_parts(parts), + token=resolved_token, + ), ) + return _build_put_blob_result(response) async def complete_multipart_upload_async( @@ -181,43 +232,20 @@ async def complete_multipart_upload_async( upload_id: str, key: str, ) -> PutBlobResult: - token = ensure_token(token) - validate_path(path) - validate_access(access) - - headers = create_put_headers(content_type=content_type, access=access) - resp = await call_complete_multipart_upload_async( + resolved_token = _validate_multipart_context(path, access, token) + headers = _build_put_headers(access=access, content_type=content_type) + response = await AsyncMultipartClient().complete_multipart_upload( upload_id=upload_id, key=key, path=path, headers=headers, - token=token, - parts=[{"partNumber": p.part_number, "etag": p.etag} for p in parts], + parts=_normalize_complete_parts(parts), + token=resolved_token, ) - return PutBlobResult( - url=resp["url"], - download_url=resp["downloadUrl"], - pathname=resp["pathname"], - content_type=resp["contentType"], - content_disposition=resp["contentDisposition"], - ) - - -class MultipartUploader: - """ - A convenience 
wrapper for multipart uploads that encapsulates the upload context. - - This provides a cleaner API than the manual approach where you have to pass - upload_id, key, pathname, etc. to every function call, while still giving you - control over when and how parts are uploaded (unlike the automatic flow). + return _build_put_blob_result(response) - Example: - >>> uploader = create_multipart_uploader("path/to/file.bin") - >>> part1 = uploader.upload_part(1, b"data chunk 1") - >>> part2 = uploader.upload_part(2, b"data chunk 2") - >>> result = uploader.complete([part1, part2]) - """ +class _BaseMultipartUploader: def __init__( self, path: str, @@ -225,12 +253,14 @@ def __init__( key: str, headers: PutHeaders | dict[str, str], token: str | None, + multipart_client: MultipartClient, ): self._path = path self._upload_id = upload_id self._key = key self._headers: dict[str, str] = cast(dict[str, str], headers) self._token = token + self._multipart_client = multipart_client @property def upload_id(self) -> str: @@ -242,6 +272,8 @@ def key(self) -> str: """The key (blob identifier) for this multipart upload.""" return self._key + +class MultipartUploader(_BaseMultipartUploader): def upload_part( self, part_number: int, @@ -250,116 +282,43 @@ def upload_part( on_upload_progress: Callable[[UploadProgressEvent], None] | None = None, per_part_progress: Callable[[int, UploadProgressEvent], None] | None = None, ) -> MultipartPart: - """ - Upload a single part of the multipart upload. 
- - Args: - part_number: The part number (must be between 1 and 10,000) - body: The content to upload for this part (bytes, str, or file-like object) - on_upload_progress: Optional callback for upload progress tracking - - Returns: - A dict with 'partNumber' and 'etag' fields to pass to complete() - - Raises: - BlobError: If body is a plain dict/object - """ - if part_number < 1 or part_number > 10000: - raise BlobError("part_number must be between 1 and 10,000") - - if isinstance(body, dict) and not hasattr(body, "read"): - raise BlobError( - "Body must be a string, bytes, or file-like object. " - "You sent a plain dictionary, double check what you're trying to upload." - ) + _validate_part_upload_inputs(part_number, body) - # Compose per-part progress if provided effective = on_upload_progress if per_part_progress is not None and on_upload_progress is None: def effective(evt: UploadProgressEvent) -> None: per_part_progress(part_number, evt) - result = call_upload_part( - upload_id=self._upload_id, - key=self._key, - path=self._path, - headers=self._headers, - part_number=part_number, - body=body, - on_upload_progress=effective, - token=self._token, + result = iter_coroutine( + self._multipart_client.upload_part( + upload_id=self._upload_id, + key=self._key, + path=self._path, + headers=self._headers, + part_number=part_number, + body=body, + on_upload_progress=effective, + token=self._token, + ) ) - - return MultipartPart(part_number=part_number, etag=result["etag"]) + return _build_multipart_part_result(part_number, result) def complete(self, parts: list[MultipartPart]) -> PutBlobResult: - """ - Complete the multipart upload by assembling the uploaded parts. - - Args: - parts: List of parts returned from upload_part() calls. - Each part should have 'partNumber' and 'etag' fields. 
- - Returns: - The result of the completed upload with URL and metadata - """ - resp = call_complete_multipart_upload( - upload_id=self._upload_id, - key=self._key, - path=self._path, - headers=self._headers, - parts=[{"partNumber": p.part_number, "etag": p.etag} for p in parts], - token=self._token, - ) - return PutBlobResult( - url=resp["url"], - download_url=resp["downloadUrl"], - pathname=resp["pathname"], - content_type=resp["contentType"], - content_disposition=resp["contentDisposition"], + response = iter_coroutine( + self._multipart_client.complete_multipart_upload( + upload_id=self._upload_id, + key=self._key, + path=self._path, + headers=self._headers, + parts=_normalize_complete_parts(parts), + token=self._token, + ) ) + return _build_put_blob_result(response) -class AsyncMultipartUploader: - """ - An async convenience wrapper for multipart uploads that encapsulates the upload context. - - This provides a cleaner API than the manual approach where you have to pass - upload_id, key, pathname, etc. to every function call, while still giving you - control over when and how parts are uploaded (unlike the automatic flow). 
- - Example: - >>> uploader = await create_multipart_uploader_async("path/to/file.bin") - >>> part1 = await uploader.upload_part(1, b"data chunk 1") - >>> part2 = await uploader.upload_part(2, b"data chunk 2") - >>> result = await uploader.complete([part1, part2]) - """ - - def __init__( - self, - path: str, - upload_id: str, - key: str, - headers: PutHeaders | dict[str, str], - token: str | None, - ): - self._path = path - self._upload_id = upload_id - self._key = key - self._headers: dict[str, str] = cast(dict[str, str], headers) - self._token = token - - @property - def upload_id(self) -> str: - """The upload ID for this multipart upload.""" - return self._upload_id - - @property - def key(self) -> str: - """The key (blob identifier) for this multipart upload.""" - return self._key - +class AsyncMultipartUploader(_BaseMultipartUploader): async def upload_part( self, part_number: int, @@ -376,39 +335,17 @@ async def upload_part( | None ) = None, ) -> MultipartPart: - """ - Upload a single part of the multipart upload. - - Args: - part_number: The part number (must be between 1 and 10,000) - body: The content to upload for this part (bytes, str, or file-like object) - on_upload_progress: Optional callback for upload progress tracking - - Returns: - A dict with 'partNumber' and 'etag' fields to pass to complete() - - Raises: - BlobError: If body is a plain dict/object - """ - if part_number < 1 or part_number > 10000: - raise BlobError("part_number must be between 1 and 10,000") - - if isinstance(body, dict) and not hasattr(body, "read"): - raise BlobError( - "Body must be a string, bytes, or file-like object. " - "You sent a plain dictionary, double check what you're trying to upload." 
- ) + _validate_part_upload_inputs(part_number, body) - # Compose per-part progress if provided effective_progress = on_upload_progress if per_part_progress is not None and on_upload_progress is None: async def effective_progress(evt: UploadProgressEvent): - res = per_part_progress(part_number, evt) - if inspect.isawaitable(res): - await res + result = per_part_progress(part_number, evt) + if inspect.isawaitable(result): + await result - result = await call_upload_part_async( + response = await self._multipart_client.upload_part( upload_id=self._upload_id, key=self._key, path=self._path, @@ -418,35 +355,18 @@ async def effective_progress(evt: UploadProgressEvent): on_upload_progress=effective_progress, token=self._token, ) - - return MultipartPart(part_number=part_number, etag=result["etag"]) + return _build_multipart_part_result(part_number, response) async def complete(self, parts: list[MultipartPart]) -> PutBlobResult: - """ - Complete the multipart upload by assembling the uploaded parts. - - Args: - parts: List of parts returned from upload_part() calls. - Each part should have 'partNumber' and 'etag' fields. 
- - Returns: - The result of the completed upload with URL and metadata - """ - resp = await call_complete_multipart_upload_async( + response = await self._multipart_client.complete_multipart_upload( upload_id=self._upload_id, key=self._key, path=self._path, headers=self._headers, - parts=[{"partNumber": p.part_number, "etag": p.etag} for p in parts], + parts=_normalize_complete_parts(parts), token=self._token, ) - return PutBlobResult( - url=resp["url"], - download_url=resp["downloadUrl"], - pathname=resp["pathname"], - content_type=resp["contentType"], - content_disposition=resp["contentDisposition"], - ) + return _build_put_blob_result(response) def create_multipart_uploader( @@ -458,54 +378,28 @@ def create_multipart_uploader( overwrite: bool = False, cache_control_max_age: int | None = None, token: str | None = None, + multipart_client: MultipartClient | None = None, ) -> MultipartUploader: - """ - Create a multipart uploader with a cleaner API than the manual approach. - - It provides more control than the automatic approach (you control part creation - and concurrency) while being cleaner than the manual approach (no need to pass - upload_id, key, pathname to every call). - - Args: - path: The path inside the blob store (includes filename and extension) - access: Access level, defaults to "public" - content_type: The media type for the file (auto-detected from extension if not provided) - add_random_suffix: Whether to add a random suffix to the pathname (default: True) - overwrite: Whether to allow overwriting existing files (default: False) - cache_control_max_age: Cache duration in seconds (default: one year) - token: Authentication token (defaults to BLOB_READ_WRITE_TOKEN or - VERCEL_BLOB_READ_WRITE_TOKEN env var) - - Returns: - A MultipartUploader instance with upload_part() and complete() methods - - Example: - >>> uploader = create_multipart_uploader("large-file.bin") - >>> parts = [] - >>> for i, chunk in enumerate(chunks, start=1): - ... 
part = uploader.upload_part(i, chunk) - ... parts.append(part) - >>> result = uploader.complete(parts) - """ - token = ensure_token(token) - validate_path(path) - validate_access(access) - - headers = create_put_headers( + resolved_token = _validate_multipart_context(path, access, token) + headers = _build_put_headers( + access=access, content_type=content_type, add_random_suffix=add_random_suffix, - allow_overwrite=overwrite, + overwrite=overwrite, cache_control_max_age=cache_control_max_age, - access=access, ) - create_resp = call_create_multipart_upload(path, headers, token=token) + effective_multipart_client = multipart_client or SyncMultipartClient() + create_response = iter_coroutine( + effective_multipart_client.create_multipart_upload(path, headers, token=resolved_token) + ) return MultipartUploader( path=path, - upload_id=create_resp["uploadId"], - key=create_resp["key"], + upload_id=create_response["uploadId"], + key=create_response["key"], headers=headers, - token=token, + token=resolved_token, + multipart_client=effective_multipart_client, ) @@ -518,52 +412,26 @@ async def create_multipart_uploader_async( overwrite: bool = False, cache_control_max_age: int | None = None, token: str | None = None, + multipart_client: MultipartClient | None = None, ) -> AsyncMultipartUploader: - """ - Create an async multipart uploader with a cleaner API than the manual approach. - - It provides more control than the automatic approach (you control part creation - and concurrency) while being cleaner than the manual approach (no need to pass - upload_id, key, pathname to every call). 
- - Args: - path: The path inside the blob store (includes filename and extension) - access: Access level, defaults to "public" - content_type: The media type for the file (auto-detected from extension if not provided) - add_random_suffix: Whether to add a random suffix to the pathname (default: True) - overwrite: Whether to allow overwriting existing files (default: False) - cache_control_max_age: Cache duration in seconds (default: one year) - token: Authentication token (defaults to BLOB_READ_WRITE_TOKEN or - VERCEL_BLOB_READ_WRITE_TOKEN env var) - - Returns: - An AsyncMultipartUploader instance with upload_part() and complete() methods - - Example: - >>> uploader = await create_multipart_uploader_async("large-file.bin") - >>> parts = [] - >>> for i, chunk in enumerate(chunks, start=1): - ... part = await uploader.upload_part(i, chunk) - ... parts.append(part) - >>> result = await uploader.complete(parts) - """ - token = ensure_token(token) - validate_path(path) - validate_access(access) - headers = create_put_headers( + resolved_token = _validate_multipart_context(path, access, token) + headers = _build_put_headers( + access=access, content_type=content_type, add_random_suffix=add_random_suffix, - allow_overwrite=overwrite, + overwrite=overwrite, cache_control_max_age=cache_control_max_age, - access=access, ) - - create_resp = await call_create_multipart_upload_async(path, headers, token=token) + effective_multipart_client = multipart_client or AsyncMultipartClient() + create_response = await effective_multipart_client.create_multipart_upload( + path, headers, token=resolved_token + ) return AsyncMultipartUploader( path=path, - upload_id=create_resp["uploadId"], - key=create_resp["key"], + upload_id=create_response["uploadId"], + key=create_response["key"], headers=headers, - token=token, + token=resolved_token, + multipart_client=effective_multipart_client, ) diff --git a/src/vercel/blob/multipart/core.py b/src/vercel/blob/multipart/core.py deleted file mode 
100644 index c9f8b75..0000000 --- a/src/vercel/blob/multipart/core.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import annotations - -from collections.abc import Awaitable, Callable -from typing import Any, cast -from urllib.parse import quote - -from ..api import request_api, request_api_async -from ..utils import PutHeaders, UploadProgressEvent - - -def call_create_multipart_upload( - path: str, headers: PutHeaders | dict[str, str], *, token: str | None = None -) -> dict[str, str]: - params = {"pathname": path} - request_headers = cast(dict[str, str], headers).copy() - request_headers["x-mpu-action"] = "create" - return request_api( - "/mpu", - "POST", - token=token, - headers=request_headers, - params=params, - ) - - -async def call_create_multipart_upload_async( - path: str, headers: PutHeaders | dict[str, str], *, token: str | None = None -) -> dict[str, str]: - params = {"pathname": path} - request_headers = cast(dict[str, str], headers).copy() - request_headers["x-mpu-action"] = "create" - return await request_api_async( - "/mpu", - "POST", - token=token, - headers=request_headers, - params=params, - ) - - -def call_upload_part( - *, - upload_id: str, - key: str, - path: str, - headers: PutHeaders | dict[str, str], - part_number: int, - body: Any, - on_upload_progress: Callable[[UploadProgressEvent], None] | None = None, - token: str | None = None, -): - params = {"pathname": path} - request_headers = cast(dict[str, str], headers).copy() - request_headers["x-mpu-action"] = "upload" - request_headers["x-mpu-key"] = quote(key, safe="") - request_headers["x-mpu-upload-id"] = upload_id - request_headers["x-mpu-part-number"] = str(part_number) - return request_api( - "/mpu", - "POST", - token=token, - headers=request_headers, - params=params, - body=body, - on_upload_progress=on_upload_progress, - ) - - -async def call_upload_part_async( - *, - upload_id: str, - key: str, - path: str, - headers: PutHeaders | dict[str, str], - part_number: int, - body: Any, 
- on_upload_progress: ( - Callable[[UploadProgressEvent], None] - | Callable[[UploadProgressEvent], Awaitable[None]] - | None - ) = None, - token: str | None = None, -): - params = {"pathname": path} - request_headers = cast(dict[str, str], headers).copy() - request_headers["x-mpu-action"] = "upload" - request_headers["x-mpu-key"] = quote(key, safe="") - request_headers["x-mpu-upload-id"] = upload_id - request_headers["x-mpu-part-number"] = str(part_number) - return await request_api_async( - "/mpu", - "POST", - token=token, - headers=request_headers, - params=params, - body=body, - on_upload_progress=on_upload_progress, - ) - - -def call_complete_multipart_upload( - *, - upload_id: str, - key: str, - path: str, - headers: PutHeaders | dict[str, str], - parts: list[dict[str, Any]], - token: str | None = None, -) -> dict[str, Any]: - params = {"pathname": path} - request_headers = cast(dict[str, str], headers).copy() - request_headers["content-type"] = "application/json" - request_headers["x-mpu-action"] = "complete" - request_headers["x-mpu-upload-id"] = upload_id - request_headers["x-mpu-key"] = quote(key, safe="") - return request_api( - "/mpu", - "POST", - token=token, - headers=request_headers, - params=params, - body=parts, - ) - - -async def call_complete_multipart_upload_async( - *, - upload_id: str, - key: str, - path: str, - headers: PutHeaders | dict[str, str], - parts: list[dict[str, Any]], - token: str | None = None, -) -> dict[str, Any]: - params = {"pathname": path} - request_headers = cast(dict[str, str], headers).copy() - request_headers["content-type"] = "application/json" - request_headers["x-mpu-action"] = "complete" - request_headers["x-mpu-upload-id"] = upload_id - request_headers["x-mpu-key"] = quote(key, safe="") - return await request_api_async( - "/mpu", - "POST", - token=token, - headers=request_headers, - params=params, - body=parts, - ) diff --git a/src/vercel/blob/multipart/uploader.py b/src/vercel/blob/multipart/uploader.py index 
033bff6..b70c8b1 100644 --- a/src/vercel/blob/multipart/uploader.py +++ b/src/vercel/blob/multipart/uploader.py @@ -1,110 +1,23 @@ from __future__ import annotations -import asyncio -import threading -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator -from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait +from collections.abc import Awaitable, Callable from typing import Any -from ..errors import BlobError -from ..utils import ( - Access, - UploadProgressEvent, - compute_body_length, - create_put_headers, - validate_access, +from vercel._internal.blob import compute_body_length +from vercel._internal.blob.multipart import ( + DEFAULT_PART_SIZE, + AsyncMultipartClient, + MultipartUploadSession, + SyncMultipartClient, + create_async_multipart_upload_runtime, + create_sync_multipart_upload_runtime, + order_uploaded_parts, + prepare_upload_headers, + shape_complete_upload_result, + validate_part_size, ) -from .core import ( - call_complete_multipart_upload, - call_complete_multipart_upload_async, - call_create_multipart_upload, - call_create_multipart_upload_async, - call_upload_part, - call_upload_part_async, -) - -DEFAULT_PART_SIZE = 8 * 1024 * 1024 # 8MB -MIN_PART_SIZE = 5 * 1024 * 1024 # 5 MiB minimum for most backends; last part may be smaller -MAX_CONCURRENCY = 6 - - -def _validate_part_size(part_size: int) -> int: - ps = int(part_size) - if ps < MIN_PART_SIZE: - raise BlobError(f"part_size must be at least {MIN_PART_SIZE} bytes (5 MiB)") - return ps - - -def _iter_part_bytes(body: Any, part_size: int) -> Iterator[bytes]: - # bytes-like - if isinstance(body, (bytes, bytearray, memoryview)): - view = memoryview(body) - offset = 0 - while offset < len(view): - end = min(offset + part_size, len(view)) - yield bytes(view[offset:end]) - offset = end - return - # str - if isinstance(body, str): - data = body.encode("utf-8") - view = memoryview(data) - offset = 0 - while offset < len(view): - end = min(offset + 
part_size, len(view)) - yield bytes(view[offset:end]) - offset = end - return - # file-like object - if hasattr(body, "read"): - while True: - chunk = body.read(part_size) # type: ignore[attr-defined] - if not chunk: - break - if not isinstance(chunk, (bytes, bytearray, memoryview)): - chunk = bytes(chunk) - yield bytes(chunk) - return - # Iterable[bytes] - if isinstance(body, Iterable): # type: ignore[arg-type] - buffer = bytearray() - for ch in body: # type: ignore[assignment] - if not isinstance(ch, (bytes, bytearray, memoryview)): - ch = bytes(ch) - buffer.extend(ch) - while len(buffer) >= part_size: - yield bytes(buffer[:part_size]) - del buffer[:part_size] - if buffer: - yield bytes(buffer) - return - # Fallback: coerce to bytes and slice - data = bytes(body) - view = memoryview(data) - offset = 0 - while offset < len(view): - end = min(offset + part_size, len(view)) - yield bytes(view[offset:end]) - offset = end - - -async def _aiter_part_bytes(body: Any, part_size: int) -> AsyncIterator[bytes]: - # AsyncIterable[bytes] - if hasattr(body, "__aiter__"): - buffer = bytearray() - async for ch in body: # type: ignore[misc] - if not isinstance(ch, (bytes, bytearray, memoryview)): - ch = bytes(ch) - buffer.extend(ch) - while len(buffer) >= part_size: - yield bytes(buffer[:part_size]) - del buffer[:part_size] - if buffer: - yield bytes(buffer) - return - # Delegate to sync iterator for other cases - for chunk in _iter_part_bytes(body, part_size): - yield chunk +from vercel._internal.iter_coroutine import iter_coroutine +from vercel.blob.types import Access, UploadProgressEvent def auto_multipart_upload( @@ -120,81 +33,49 @@ def auto_multipart_upload( on_upload_progress: Callable[[UploadProgressEvent], None] | None = None, part_size: int = DEFAULT_PART_SIZE, ) -> dict[str, Any]: - validate_access(access) - headers = create_put_headers( + client = SyncMultipartClient() + headers = prepare_upload_headers( + access=access, content_type=content_type, 
add_random_suffix=add_random_suffix, - allow_overwrite=overwrite, + overwrite=overwrite, cache_control_max_age=cache_control_max_age, - access=access, ) + part_size = validate_part_size(part_size) - part_size = _validate_part_size(part_size) - - create_resp = call_create_multipart_upload(path, headers, token=token) - upload_id = create_resp["uploadId"] - key = create_resp["key"] - - total = compute_body_length(body) - loaded_per_part: dict[int, int] = {} - loaded_lock = threading.Lock() - results: list[dict] = [] - - def upload_one(part_number: int, content: bytes) -> dict: - def progress(evt: UploadProgressEvent) -> None: - with loaded_lock: - loaded_per_part[part_number] = int(evt.loaded) - if on_upload_progress: - loaded = sum(loaded_per_part.values()) - pct = round((loaded / total) * 100, 2) if total else 0.0 - on_upload_progress( - UploadProgressEvent(loaded=loaded, total=total, percentage=pct) - ) - - resp = call_upload_part( - upload_id=upload_id, - key=key, - path=path, - headers=headers, - token=token, - part_number=part_number, - body=content, - on_upload_progress=progress, - ) - return {"partNumber": part_number, "etag": resp["etag"]} - - with ThreadPoolExecutor(max_workers=MAX_CONCURRENCY) as executor: - inflight = set() - part_no = 1 - for chunk in _iter_part_bytes(body, part_size): - fut = executor.submit(upload_one, part_no, chunk) - inflight.add(fut) - part_no += 1 - if len(inflight) >= MAX_CONCURRENCY: - done, inflight = wait(inflight, return_when=FIRST_COMPLETED) - for d in done: - results.append(d.result()) - - if inflight: - done, _ = wait(inflight) - for d in done: - results.append(d.result()) - - # Ensure parts are ordered by partNumber - results.sort(key=lambda p: int(p["partNumber"])) - - if on_upload_progress: - on_upload_progress(UploadProgressEvent(loaded=total, total=total, percentage=100.0)) - - return call_complete_multipart_upload( - upload_id=upload_id, - key=key, + create_response = 
iter_coroutine(client.create_multipart_upload(path, headers, token=token)) + session = MultipartUploadSession( + upload_id=create_response["uploadId"], + key=create_response["key"], path=path, headers=headers, token=token, - parts=results, ) + runtime = create_sync_multipart_upload_runtime() + total = compute_body_length(body) + parts = runtime.upload( + session=session, + body=body, + part_size=part_size, + total=total, + on_upload_progress=on_upload_progress, + upload_part_fn=lambda **kwargs: iter_coroutine(client.upload_part(**kwargs)), + ) + ordered_parts = order_uploaded_parts(parts) + + complete_response = iter_coroutine( + client.complete_multipart_upload( + upload_id=session.upload_id, + key=session.key, + path=session.path, + headers=session.headers, + token=session.token, + parts=ordered_parts, + ) + ) + return shape_complete_upload_result(complete_response) + async def auto_multipart_upload_async( path: str, @@ -213,93 +94,43 @@ async def auto_multipart_upload_async( ) = None, part_size: int = DEFAULT_PART_SIZE, ) -> dict[str, Any]: - validate_access(access) - headers = create_put_headers( + client = AsyncMultipartClient() + headers = prepare_upload_headers( + access=access, content_type=content_type, add_random_suffix=add_random_suffix, - allow_overwrite=overwrite, + overwrite=overwrite, cache_control_max_age=cache_control_max_age, - access=access, ) + part_size = validate_part_size(part_size) - part_size = _validate_part_size(part_size) - - create_resp = await call_create_multipart_upload_async(path, headers, token=token) - upload_id = create_resp["uploadId"] - key = create_resp["key"] - - total = compute_body_length(body) - loaded_per_part: dict[int, int] = {} - results: list[dict] = [] - - def _make_progress(part_number: int): - if on_upload_progress and asyncio.iscoroutinefunction(on_upload_progress): - - async def progress_async(evt: UploadProgressEvent): - loaded_per_part[part_number] = int(evt.loaded) - loaded = sum(loaded_per_part.values()) - 
pct = round((loaded / total) * 100, 2) if total else 0.0 - await on_upload_progress( - UploadProgressEvent(loaded=loaded, total=total, percentage=pct) - ) - - return progress_async - else: - - def progress(evt: UploadProgressEvent) -> None: - loaded_per_part[part_number] = int(evt.loaded) - if on_upload_progress: - loaded = sum(loaded_per_part.values()) - pct = round((loaded / total) * 100, 2) if total else 0.0 - on_upload_progress( - UploadProgressEvent(loaded=loaded, total=total, percentage=pct) - ) - - return progress - - async def upload_one(part_number: int, content: bytes) -> dict: - resp = await call_upload_part_async( - upload_id=upload_id, - key=key, - path=path, - headers=headers, - part_number=part_number, - body=content, - on_upload_progress=_make_progress(part_number), - token=token, - ) - return {"partNumber": part_number, "etag": resp["etag"]} - - inflight: set[asyncio.Task] = set() - part_no = 1 - async for chunk in _aiter_part_bytes(body, part_size): - t = asyncio.create_task(upload_one(part_no, chunk)) - inflight.add(t) - part_no += 1 - if len(inflight) >= MAX_CONCURRENCY: - done, inflight = await asyncio.wait(inflight, return_when=asyncio.FIRST_COMPLETED) - for d in done: - results.append(d.result()) - - if inflight: - done, _ = await asyncio.wait(inflight, return_when=asyncio.ALL_COMPLETED) - for d in done: - results.append(d.result()) - - results.sort(key=lambda p: int(p["partNumber"])) - - if on_upload_progress: - loaded = sum(loaded_per_part.values()) - pct = round((loaded / total) * 100, 2) if total else 100.0 - result = on_upload_progress(UploadProgressEvent(loaded=loaded, total=total, percentage=pct)) - if asyncio.iscoroutine(result): - await result - - return await call_complete_multipart_upload_async( - upload_id=upload_id, - key=key, + create_response = await client.create_multipart_upload(path, headers, token=token) + session = MultipartUploadSession( + upload_id=create_response["uploadId"], + key=create_response["key"], path=path, 
headers=headers, token=token, - parts=results, ) + + runtime = create_async_multipart_upload_runtime() + total = compute_body_length(body) + parts = await runtime.upload( + session=session, + body=body, + part_size=part_size, + total=total, + on_upload_progress=on_upload_progress, + upload_part_fn=client.upload_part, + ) + ordered_parts = order_uploaded_parts(parts) + + complete_response = await client.complete_multipart_upload( + upload_id=session.upload_id, + key=session.key, + path=session.path, + headers=session.headers, + token=session.token, + parts=ordered_parts, + ) + return shape_complete_upload_result(complete_response) diff --git a/src/vercel/blob/ops.py b/src/vercel/blob/ops.py index 8ef7fc0..aafac52 100644 --- a/src/vercel/blob/ops.py +++ b/src/vercel/blob/ops.py @@ -1,49 +1,39 @@ from __future__ import annotations -import contextvars -import inspect -import os -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator -from datetime import datetime, timezone -from email.utils import parsedate_to_datetime +from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine, Iterable, Iterator from os import PathLike -from typing import Any -from urllib.parse import parse_qs, urlencode, urlparse, urlunparse +from typing import Any, TypeVar -import httpx - -from .._telemetry.tracker import telemetry, track -from .api import request_api, request_api_async -from .errors import BlobError, BlobNotFoundError -from .multipart import auto_multipart_upload, auto_multipart_upload_async -from .types import ( +from vercel._internal.blob import ( + ensure_token, + validate_access, +) +from vercel._internal.blob.core import ( + AsyncBlobOpsClient, + SyncBlobOpsClient, + normalize_delete_urls, +) +from vercel._internal.iter_coroutine import iter_coroutine +from vercel.blob.types import ( + Access, CreateFolderResult as CreateFolderResultType, GetBlobResult as GetBlobResultType, HeadBlobResult as HeadBlobResultType, ListBlobItem, 
ListBlobResult as ListBlobResultType, PutBlobResult as PutBlobResultType, -) -from .utils import ( - Access, - PutHeaders, UploadProgressEvent, - construct_blob_url, - create_put_headers, - ensure_token, - extract_store_id_from_token, - get_download_url, - is_url, - parse_datetime, - validate_access, - validate_path, ) -# Context variable to store the delete count for telemetry -# This allows the derive function to access the count after the iterable is consumed -_delete_count_context: contextvars.ContextVar[int | None] = contextvars.ContextVar( - "_delete_count", default=None -) +_T = TypeVar("_T") + + +def _run_sync_blob_operation( + operation: Callable[[SyncBlobOpsClient], Coroutine[None, None, _T]], +) -> _T: + with SyncBlobOpsClient() as client: + # Keep exactly one sync bridge at the wrapper boundary. + return iter_coroutine(operation(client)) def put( @@ -60,27 +50,8 @@ def put( on_upload_progress: Callable[[UploadProgressEvent], None] | None = None, ) -> PutBlobResultType: token = ensure_token(token) - validate_path(path) - validate_access(access) - - if body is None: - raise BlobError("body is required") - if isinstance(body, dict): - raise BlobError( - "Body must be a string, buffer or stream. " - "You sent a plain object, double check what you're trying to upload." 
- ) - - headers = create_put_headers( - content_type=content_type, - add_random_suffix=add_random_suffix, - allow_overwrite=overwrite, - cache_control_max_age=cache_control_max_age, - access=access, - ) - - if multipart is True: - raw = auto_multipart_upload( + result, _ = _run_sync_blob_operation( + lambda client: client.put_blob( path, body, access=access, @@ -89,61 +60,11 @@ def put( overwrite=overwrite, cache_control_max_age=cache_control_max_age, token=token, + multipart=multipart, on_upload_progress=on_upload_progress, ) - # Track telemetry (best-effort) - size_bytes = None - if isinstance(body, (bytes, bytearray)): - size_bytes = len(body) - elif isinstance(body, str): - size_bytes = len(body.encode()) - track( - "blob_put", - token=token, - access=access, - content_type=content_type, - multipart=True, - size_bytes=size_bytes, - ) - return PutBlobResultType( - url=raw["url"], - download_url=raw["downloadUrl"], - pathname=raw["pathname"], - content_type=raw["contentType"], - content_disposition=raw["contentDisposition"], - ) - - params = {"pathname": path} - raw = request_api( - "", - "PUT", - token=token, - headers=headers, - params=params, - body=body, - on_upload_progress=on_upload_progress, - ) - # Track telemetry (best-effort) - size_bytes = None - if isinstance(body, (bytes, bytearray)): - size_bytes = len(body) - elif isinstance(body, str): - size_bytes = len(body.encode()) - track( - "blob_put", - token=token, - access=access, - content_type=content_type, - multipart=False, - size_bytes=size_bytes, - ) - return PutBlobResultType( - url=raw["url"], - download_url=raw["downloadUrl"], - pathname=raw["pathname"], - content_type=raw["contentType"], - content_disposition=raw["contentDisposition"], ) + return result async def put_async( @@ -164,29 +85,8 @@ async def put_async( ) = None, ) -> PutBlobResultType: token = ensure_token(token) - validate_path(path) - validate_access(access) - - if body is None: - raise BlobError("body is required") - # Reject 
plain dict (JS plain object equivalent) to match TS error semantics - if isinstance(body, dict): - raise BlobError( - "Body must be a string, buffer or stream. " - "You sent a plain object, double check what you're trying to upload." - ) - - headers: PutHeaders = create_put_headers( - content_type=content_type, - add_random_suffix=add_random_suffix, - allow_overwrite=overwrite, - cache_control_max_age=cache_control_max_age, - access=access, - ) - - # Multipart auto support - if multipart is True: - raw = await auto_multipart_upload_async( + async with AsyncBlobOpsClient() as client: + result, _ = await client.put_blob( path, body, access=access, @@ -195,301 +95,59 @@ async def put_async( overwrite=overwrite, cache_control_max_age=cache_control_max_age, token=token, + multipart=multipart, on_upload_progress=on_upload_progress, ) - # Track telemetry - size_bytes = None - if isinstance(body, (bytes, bytearray)): - size_bytes = len(body) - elif isinstance(body, str): - size_bytes = len(body.encode()) - track( - "blob_put", - token=token, - access=access, - content_type=content_type, - multipart=True, - size_bytes=size_bytes, - ) - return PutBlobResultType( - url=raw["url"], - download_url=raw["downloadUrl"], - pathname=raw["pathname"], - content_type=raw["contentType"], - content_disposition=raw["contentDisposition"], - ) - - params = {"pathname": path} - raw = await request_api_async( - "", - "PUT", - token=token, - headers=headers, - params=params, - body=body, - on_upload_progress=on_upload_progress, - ) - # Track telemetry - size_bytes = None - if isinstance(body, (bytes, bytearray)): - size_bytes = len(body) - elif isinstance(body, str): - size_bytes = len(body.encode()) - track( - "blob_put", - token=token, - access=access, - content_type=content_type, - multipart=False, - size_bytes=size_bytes, - ) - return PutBlobResultType( - url=raw["url"], - download_url=raw["downloadUrl"], - pathname=raw["pathname"], - content_type=raw["contentType"], - 
content_disposition=raw["contentDisposition"], - ) - - -class _CountedIterable: - """Wrapper for iterables that preserves count even after consumption. + return result - This is used to handle generators and other single-use iterables - passed to delete/delete_async. The wrapper converts the iterable - to a list once, so that the count is preserved even after the - iterable is fully consumed by the function. - """ - def __init__(self, iterable: Iterable[str]) -> None: - """Convert the iterable to a list to preserve it for later counting.""" - self.items = [str(item) for item in iterable] - - def __iter__(self) -> Iterator[str]: - """Allow iteration over the preserved items.""" - return iter(self.items) - - def __len__(self) -> int: - """Return the count of items.""" - return len(self.items) - - -def _derive_delete_count(args: tuple, kwargs: dict, result: Any) -> int: - """Derive the count of URLs being deleted.""" - # First, check if the count was stored in the context variable - count = _delete_count_context.get() - if count is not None: - _delete_count_context.set(None) # Clear it for the next call - return count - - # Fallback: try to derive from the argument - url_or_path = kwargs.get("url_or_path", args[0] if args else None) - if url_or_path is None: - return 1 - # Check if it's a _CountedIterable (which preserves count after consumption) - if isinstance(url_or_path, _CountedIterable): - return len(url_or_path) - # For other iterables, try to count them (though they may be exhausted) - if isinstance(url_or_path, Iterable) and not isinstance(url_or_path, (str, bytes)): - try: - return len(list(url_or_path)) - except Exception: - return 1 - return 1 - - -@telemetry( - event="blob_delete", - capture=["token"], - derive={"count": _derive_delete_count}, - when="after", -) def delete( url_or_path: str | Iterable[str], *, token: str | None = None, ) -> None: token = ensure_token(token) - # Convert iterables to a list and store the count for telemetry - if 
isinstance(url_or_path, Iterable) and not isinstance(url_or_path, (str, bytes)): - urls = [str(u) for u in url_or_path] - _delete_count_context.set(len(urls)) - else: - urls = [str(url_or_path)] - _delete_count_context.set(1) - - request_api( - "/delete", - "POST", - token=token, - headers={"content-type": "application/json"}, - body={"urls": urls}, + normalized_urls = normalize_delete_urls(url_or_path) + _run_sync_blob_operation( + lambda client: client.delete_blob( + normalized_urls, + token=token, + ) ) -@telemetry( - event="blob_delete", - capture=["token"], - derive={"count": _derive_delete_count}, - when="after", -) async def delete_async( url_or_path: str | Iterable[str], *, token: str | None = None, ) -> None: token = ensure_token(token) - # Convert iterables to a list and store the count for telemetry - if isinstance(url_or_path, Iterable) and not isinstance(url_or_path, (str, bytes)): - urls = [str(u) for u in url_or_path] - _delete_count_context.set(len(urls)) - else: - urls = [str(url_or_path)] - _delete_count_context.set(1) - - await request_api_async( - "/delete", - "POST", - token=token, - headers={"content-type": "application/json"}, - body={"urls": urls}, - ) + normalized_urls = normalize_delete_urls(url_or_path) + async with AsyncBlobOpsClient() as client: + await client.delete_blob( + normalized_urls, + token=token, + ) def head(url_or_path: str, *, token: str | None = None) -> HeadBlobResultType: token = ensure_token(token) - params = {"url": url_or_path} - resp = request_api( - "", - "GET", - token=token, - params=params, - ) - uploaded_at = ( - parse_datetime(resp["uploadedAt"]) - if isinstance(resp.get("uploadedAt"), str) - else resp["uploadedAt"] - ) - return HeadBlobResultType( - size=resp["size"], - uploaded_at=uploaded_at, - pathname=resp["pathname"], - content_type=resp["contentType"], - content_disposition=resp["contentDisposition"], - url=resp["url"], - download_url=resp["downloadUrl"], - cache_control=resp["cacheControl"], + return 
_run_sync_blob_operation( + lambda client: client.head_blob( + url_or_path, + token=token, + ) ) async def head_async(url_or_path: str, *, token: str | None = None) -> HeadBlobResultType: token = ensure_token(token) - params = {"url": url_or_path} - resp = await request_api_async( - "", - "GET", - token=token, - params=params, - ) - uploaded_at = ( - parse_datetime(resp["uploadedAt"]) - if isinstance(resp.get("uploadedAt"), str) - else resp["uploadedAt"] - ) - return HeadBlobResultType( - size=resp["size"], - uploaded_at=uploaded_at, - pathname=resp["pathname"], - content_type=resp["contentType"], - content_disposition=resp["contentDisposition"], - url=resp["url"], - download_url=resp["downloadUrl"], - cache_control=resp["cacheControl"], - ) - - -def _resolve_blob_url(url_or_path: str, token: str, access: Access) -> tuple[str, str]: - """Resolve a URL or pathname to a blob URL and extract the pathname. - - Returns (blob_url, pathname). - """ - if is_url(url_or_path): - parsed = urlparse(url_or_path) - pathname = parsed.path.lstrip("/") - return url_or_path, pathname - - # It's a pathname - construct the URL from token store ID - store_id = extract_store_id_from_token(token) - if not store_id: - raise BlobError( - "Unable to extract store ID from token. " - "When using a pathname instead of a full URL, " - "a valid token with an embedded store ID is required." 
- ) - pathname = url_or_path.lstrip("/") - blob_url = construct_blob_url(store_id, pathname, access) - return blob_url, pathname - - -def _parse_last_modified(value: str | None) -> datetime: - """Parse a Last-Modified header (RFC 7231 or ISO 8601).""" - if not value: - return datetime.now(tz=timezone.utc) - try: - return parsedate_to_datetime(value) - except (ValueError, TypeError): - pass - try: - return parse_datetime(value) - except (ValueError, TypeError): - return datetime.now(tz=timezone.utc) - - -def _build_get_result( - resp: httpx.Response, blob_url: str, pathname: str -) -> GetBlobResultType: - """Build a GetBlobResult from an httpx response.""" - if resp.status_code == 304: - return GetBlobResultType( - url=blob_url, - download_url=get_download_url(blob_url), - pathname=pathname, - content_type=None, - size=None, - content_disposition=resp.headers.get("content-disposition", ""), - cache_control=resp.headers.get("cache-control", ""), - uploaded_at=_parse_last_modified(resp.headers.get("last-modified")), - etag=resp.headers.get("etag", ""), - content=b"", - status_code=304, + async with AsyncBlobOpsClient() as client: + return await client.head_blob( + url_or_path, + token=token, ) - content_length = resp.headers.get("content-length") - return GetBlobResultType( - url=blob_url, - download_url=get_download_url(blob_url), - pathname=pathname, - content_type=resp.headers.get("content-type", "application/octet-stream"), - size=int(content_length) if content_length else len(resp.content), - content_disposition=resp.headers.get("content-disposition", ""), - cache_control=resp.headers.get("cache-control", ""), - uploaded_at=_parse_last_modified(resp.headers.get("last-modified")), - etag=resp.headers.get("etag", ""), - content=resp.content, - status_code=resp.status_code, - ) - - -def _build_cache_bypass_url(blob_url: str) -> str: - parsed = urlparse(blob_url) - params = parse_qs(parsed.query) - params["cache"] = ["0"] - query = urlencode(params, doseq=True) - 
return urlunparse(( - parsed.scheme, parsed.netloc, parsed.path, - parsed.params, query, parsed.fragment, - )) - def get( url_or_path: str, @@ -502,29 +160,17 @@ def get( ) -> GetBlobResultType: token = ensure_token(token) validate_access(access) - blob_url, pathname = _resolve_blob_url(url_or_path, token, access) - - headers: dict[str, str] = {} - if access == "private": - headers["authorization"] = f"Bearer {token}" - if if_none_match: - headers["if-none-match"] = if_none_match - - fetch_url = _build_cache_bypass_url(blob_url) if not use_cache else blob_url - - try: - with httpx.Client(follow_redirects=True, timeout=httpx.Timeout(timeout or 30.0)) as client: - resp = client.get(fetch_url, headers=headers) - if resp.status_code == 404: - raise BlobNotFoundError() - if resp.status_code == 304: - return _build_get_result(resp, blob_url, pathname) - resp.raise_for_status() - return _build_get_result(resp, blob_url, pathname) - except httpx.HTTPStatusError as exc: - if exc.response is not None and exc.response.status_code == 404: - raise BlobNotFoundError() from exc - raise + return _run_sync_blob_operation( + lambda client: client.get_blob( + url_or_path, + access=access, + token=token, + timeout=timeout, + use_cache=use_cache, + if_none_match=if_none_match, + default_timeout=30.0, + ) + ) async def get_async( @@ -538,31 +184,16 @@ async def get_async( ) -> GetBlobResultType: token = ensure_token(token) validate_access(access) - blob_url, pathname = _resolve_blob_url(url_or_path, token, access) - - headers: dict[str, str] = {} - if access == "private": - headers["authorization"] = f"Bearer {token}" - if if_none_match: - headers["if-none-match"] = if_none_match - - fetch_url = _build_cache_bypass_url(blob_url) if not use_cache else blob_url - - try: - async with httpx.AsyncClient( - follow_redirects=True, timeout=httpx.Timeout(timeout or 120.0) - ) as client: - resp = await client.get(fetch_url, headers=headers) - if resp.status_code == 404: - raise 
BlobNotFoundError() - if resp.status_code == 304: - return _build_get_result(resp, blob_url, pathname) - resp.raise_for_status() - return _build_get_result(resp, blob_url, pathname) - except httpx.HTTPStatusError as exc: - if exc.response is not None and exc.response.status_code == 404: - raise BlobNotFoundError() from exc - raise + async with AsyncBlobOpsClient() as client: + return await client.get_blob( + url_or_path, + access=access, + token=token, + timeout=timeout, + use_cache=use_cache, + if_none_match=if_none_match, + default_timeout=30.0, + ) def list_objects( @@ -574,44 +205,14 @@ def list_objects( token: str | None = None, ) -> ListBlobResultType: token = ensure_token(token) - params: dict[str, Any] = {} - if limit is not None: - params["limit"] = int(limit) - if prefix is not None: - params["prefix"] = prefix - if cursor is not None: - params["cursor"] = cursor - if mode is not None: - params["mode"] = mode - - resp = request_api( - "", - "GET", - token=token, - params=params, - ) - blobs_list: list[ListBlobItem] = [] - for b in resp.get("blobs", []): - uploaded_at = ( - parse_datetime(b["uploadedAt"]) - if isinstance(b.get("uploadedAt"), str) - else b["uploadedAt"] - ) - blobs_list.append( - ListBlobItem( - url=b["url"], - download_url=b["downloadUrl"], - pathname=b["pathname"], - size=b["size"], - uploaded_at=uploaded_at, - ) + with SyncBlobOpsClient() as client: + return client.list_objects( + limit=limit, + prefix=prefix, + cursor=cursor, + mode=mode, + token=token, ) - return ListBlobResultType( - blobs=blobs_list, - cursor=resp.get("cursor"), - has_more=resp.get("hasMore", False), - folders=resp.get("folders"), - ) async def list_objects_async( @@ -623,44 +224,14 @@ async def list_objects_async( token: str | None = None, ) -> ListBlobResultType: token = ensure_token(token) - params: dict[str, Any] = {} - if limit is not None: - params["limit"] = int(limit) - if prefix is not None: - params["prefix"] = prefix - if cursor is not None: - 
params["cursor"] = cursor - if mode is not None: - params["mode"] = mode - - resp = await request_api_async( - "", - "GET", - token=token, - params=params, - ) - blobs_list: list[ListBlobItem] = [] - for b in resp.get("blobs", []): - uploaded_at = ( - parse_datetime(b["uploadedAt"]) - if isinstance(b.get("uploadedAt"), str) - else b["uploadedAt"] - ) - blobs_list.append( - ListBlobItem( - url=b["url"], - download_url=b["downloadUrl"], - pathname=b["pathname"], - size=b["size"], - uploaded_at=uploaded_at, - ) + async with AsyncBlobOpsClient() as client: + return await client.list_objects( + limit=limit, + prefix=prefix, + cursor=cursor, + mode=mode, + token=token, ) - return ListBlobResultType( - blobs=blobs_list, - cursor=resp.get("cursor"), - has_more=resp.get("hasMore", False), - folders=resp.get("folders"), - ) def iter_objects( @@ -673,32 +244,15 @@ def iter_objects( cursor: str | None = None, ) -> Iterator[ListBlobItem]: token = ensure_token(token) - next_cursor = cursor - yielded_count = 0 - while True: - effective_limit: int | None = batch_size - if limit is not None: - remaining = limit - yielded_count - if remaining <= 0: - break - if effective_limit is None or effective_limit > remaining: - effective_limit = remaining - page = list_objects( - limit=effective_limit, + with SyncBlobOpsClient() as client: + yield from client.iter_objects( prefix=prefix, - cursor=next_cursor, mode=mode, token=token, + batch_size=batch_size, + limit=limit, + cursor=cursor, ) - for item in page.blobs: - yield item - if limit is not None: - yielded_count += 1 - if yielded_count >= limit: - return - if not page.has_more or not page.cursor: - break - next_cursor = page.cursor async def iter_objects_async( @@ -711,32 +265,16 @@ async def iter_objects_async( cursor: str | None = None, ) -> AsyncIterator[ListBlobItem]: token = ensure_token(token) - next_cursor = cursor - yielded_count = 0 - while True: - effective_limit: int | None = batch_size - if limit is not None: - remaining = 
limit - yielded_count - if remaining <= 0: - break - if effective_limit is None or effective_limit > remaining: - effective_limit = remaining - page = await list_objects_async( - limit=effective_limit, + async with AsyncBlobOpsClient() as client: + async for item in client.iter_objects( prefix=prefix, - cursor=next_cursor, mode=mode, token=token, - ) - for item in page.blobs: + batch_size=batch_size, + limit=limit, + cursor=cursor, + ): yield item - if limit is not None: - yielded_count += 1 - if yielded_count >= limit: - return - if not page.has_more or not page.cursor: - break - next_cursor = page.cursor def copy( @@ -751,33 +289,17 @@ def copy( token: str | None = None, ) -> PutBlobResultType: token = ensure_token(token) - validate_path(dst_path) - validate_access(access) - if not is_url(src_path): - meta = head(src_path, token=token) - src_path = meta.url - - headers: PutHeaders = create_put_headers( - content_type=content_type, - add_random_suffix=add_random_suffix, - allow_overwrite=overwrite, - cache_control_max_age=cache_control_max_age, - access=access, - ) - params = {"pathname": dst_path, "fromUrl": src_path} - raw = request_api( - "", - "PUT", - token=token, - headers=headers, - params=params, - ) - return PutBlobResultType( - url=raw["url"], - download_url=raw["downloadUrl"], - pathname=raw["pathname"], - content_type=raw["contentType"], - content_disposition=raw["contentDisposition"], + return _run_sync_blob_operation( + lambda client: client.copy_blob( + src_path, + dst_path, + access=access, + content_type=content_type, + add_random_suffix=add_random_suffix, + overwrite=overwrite, + cache_control_max_age=cache_control_max_age, + token=token, + ) ) @@ -793,36 +315,17 @@ async def copy_async( token: str | None = None, ) -> PutBlobResultType: token = ensure_token(token) - validate_path(dst_path) - validate_access(access) - - if not is_url(src_path): - meta = head(src_path, token=token) - src_path = meta.url - dst_path = str(dst_path) - - headers: 
PutHeaders = create_put_headers( - content_type=content_type, - add_random_suffix=add_random_suffix, - allow_overwrite=overwrite, - cache_control_max_age=cache_control_max_age, - access=access, - ) - params = {"pathname": dst_path, "fromUrl": src_path} - raw = await request_api_async( - "", - "PUT", - token=token, - headers=headers, - params=params, - ) - return PutBlobResultType( - url=raw["url"], - download_url=raw["downloadUrl"], - pathname=raw["pathname"], - content_type=raw["contentType"], - content_disposition=raw["contentDisposition"], - ) + async with AsyncBlobOpsClient() as client: + return await client.copy_blob( + src_path, + dst_path, + access=access, + content_type=content_type, + add_random_suffix=add_random_suffix, + overwrite=overwrite, + cache_control_max_age=cache_control_max_age, + token=token, + ) def create_folder( @@ -832,20 +335,13 @@ def create_folder( overwrite: bool = False, ) -> CreateFolderResultType: token = ensure_token(token) - folder_path = path if path.endswith("/") else path + "/" - headers = create_put_headers( - add_random_suffix=False, - allow_overwrite=overwrite, - ) - params = {"pathname": folder_path} - raw = request_api( - "", - "PUT", - token=token, - headers=headers, - params=params, + return _run_sync_blob_operation( + lambda client: client.create_folder( + path, + token=token, + overwrite=overwrite, + ) ) - return CreateFolderResultType(pathname=raw["pathname"], url=raw["url"]) async def create_folder_async( @@ -855,20 +351,12 @@ async def create_folder_async( overwrite: bool = False, ) -> CreateFolderResultType: token = ensure_token(token) - folder_path = path if path.endswith("/") else path + "/" - headers = create_put_headers( - add_random_suffix=False, - allow_overwrite=overwrite, - ) - params = {"pathname": folder_path} - raw = await request_api_async( - "", - "PUT", - token=token, - headers=headers, - params=params, - ) - return CreateFolderResultType(pathname=raw["pathname"], url=raw["url"]) + async with 
AsyncBlobOpsClient() as client: + return await client.create_folder( + path, + token=token, + overwrite=overwrite, + ) def upload_file( @@ -884,33 +372,21 @@ def upload_file( multipart: bool = False, on_upload_progress: Callable[[UploadProgressEvent], None] | None = None, ) -> PutBlobResultType: - token = ensure_token(token) - if not local_path: - raise BlobError("src_path is required") - if not path: - raise BlobError("path is required") - if not os.path.exists(os.fspath(local_path)): - raise BlobError("local_path does not exist") - if not os.path.isfile(os.fspath(local_path)): - raise BlobError("local_path is not a file") - - # Auto-enable multipart if file size exceeds 5 MiB - size_bytes = os.path.getsize(os.fspath(local_path)) - use_multipart = multipart or (size_bytes > 5 * 1024 * 1024) - - with open(os.fspath(local_path), "rb") as f: - return put( + return _run_sync_blob_operation( + lambda client: client.upload_file( + local_path, path, - f, access=access, content_type=content_type, add_random_suffix=add_random_suffix, overwrite=overwrite, cache_control_max_age=cache_control_max_age, token=token, - multipart=use_multipart, + multipart=multipart, on_upload_progress=on_upload_progress, + missing_local_path_error="src_path is required", ) + ) async def upload_file_async( @@ -930,32 +406,19 @@ async def upload_file_async( | None ) = None, ) -> PutBlobResultType: - token = ensure_token(token) - if not local_path: - raise BlobError("local_path is required") - if not path: - raise BlobError("path is required") - if not os.path.exists(os.fspath(local_path)): - raise BlobError("local_path does not exist") - if not os.path.isfile(os.fspath(local_path)): - raise BlobError("local_path is not a file") - - # Auto-enable multipart if file size exceeds 5 MiB - size_bytes = os.path.getsize(os.fspath(local_path)) - use_multipart = multipart or (size_bytes > 5 * 1024 * 1024) - - with open(os.fspath(local_path), "rb") as f: - return await put_async( + async with 
AsyncBlobOpsClient() as client: + return await client.upload_file( + local_path, path, - f, access=access, content_type=content_type, add_random_suffix=add_random_suffix, overwrite=overwrite, cache_control_max_age=cache_control_max_age, token=token, - multipart=use_multipart, + multipart=multipart, on_upload_progress=on_upload_progress, + missing_local_path_error="local_path is required", ) @@ -972,48 +435,18 @@ def download_file( ) -> str: token = ensure_token(token) validate_access(access) - - # Resolve remote URL from url_or_path - blob_url, _ = _resolve_blob_url(url_or_path, token, access) - target_url = blob_url - - # Prepare destination - dst = os.fspath(local_path) - if not overwrite and os.path.exists(dst): - raise BlobError("destination exists; pass overwrite=True to replace it") - if create_parents: - os.makedirs(os.path.dirname(dst) or ".", exist_ok=True) - - tmp = dst + ".part" - bytes_read = 0 - - req_headers: dict[str, str] = {} - if access == "private": - req_headers["authorization"] = f"Bearer {token}" - - try: - with httpx.Client(follow_redirects=True, timeout=httpx.Timeout(timeout or 120.0)) as client: - with client.stream("GET", target_url, headers=req_headers) as resp: - if resp.status_code == 404: - raise BlobNotFoundError() - resp.raise_for_status() - total = int(resp.headers.get("Content-Length", "0")) or None - with open(tmp, "wb") as f: - for chunk in resp.iter_bytes(): - if chunk: - f.write(chunk) - bytes_read += len(chunk) - if progress: - progress(bytes_read, total) - - os.replace(tmp, dst) # atomic finalize - except Exception: - try: - if os.path.exists(tmp): - os.remove(tmp) - finally: - raise - return dst + return _run_sync_blob_operation( + lambda client: client.download_file( + url_or_path, + local_path, + access=access, + token=token, + timeout=timeout, + overwrite=overwrite, + create_parents=create_parents, + progress=progress, + ) + ) async def download_file_async( @@ -1031,52 +464,14 @@ async def download_file_async( ) -> str: 
token = ensure_token(token) validate_access(access) - - # Resolve remote URL from url_or_path - blob_url, _ = _resolve_blob_url(url_or_path, token, access) - target_url = blob_url - - # Prepare destination - dst = os.fspath(local_path) - if not overwrite and os.path.exists(dst): - raise BlobError("destination exists; pass overwrite=True to replace it") - if create_parents: - os.makedirs(os.path.dirname(dst) or ".", exist_ok=True) - - tmp = dst + ".part" - bytes_read = 0 - - req_headers: dict[str, str] = {} - if access == "private": - req_headers["authorization"] = f"Bearer {token}" - - try: - async with ( - httpx.AsyncClient( - follow_redirects=True, - timeout=httpx.Timeout(timeout or 120.0), - ) as client, - client.stream("GET", target_url, headers=req_headers) as resp, - ): - if resp.status_code == 404: - raise BlobNotFoundError() - resp.raise_for_status() - total = int(resp.headers.get("Content-Length", "0")) or None - with open(tmp, "wb") as f: - async for chunk in resp.aiter_bytes(): - if chunk: - f.write(chunk) - bytes_read += len(chunk) - if progress: - maybe = progress(bytes_read, total) - if inspect.isawaitable(maybe): - await maybe - - os.replace(tmp, dst) # atomic finalize - except Exception: - try: - if os.path.exists(tmp): - os.remove(tmp) - finally: - raise - return dst + async with AsyncBlobOpsClient() as client: + return await client.download_file( + url_or_path, + local_path, + access=access, + token=token, + timeout=timeout, + overwrite=overwrite, + create_parents=create_parents, + progress=progress, + ) diff --git a/src/vercel/blob/types.py b/src/vercel/blob/types.py index 10a7d12..991bd1d 100644 --- a/src/vercel/blob/types.py +++ b/src/vercel/blob/types.py @@ -1,75 +1,27 @@ -from __future__ import annotations - -from dataclasses import dataclass -from datetime import datetime - - -@dataclass(slots=True) -class PutBlobResult: - url: str - download_url: str - pathname: str - content_type: str - content_disposition: str - - -@dataclass(slots=True) 
-class HeadBlobResult: - size: int - uploaded_at: datetime - pathname: str - content_type: str - content_disposition: str - url: str - download_url: str - cache_control: str - - -@dataclass(slots=True) -class ListBlobItem: - url: str - download_url: str - pathname: str - size: int - uploaded_at: datetime - - -@dataclass(slots=True) -class ListBlobResult: - blobs: list[ListBlobItem] - cursor: str | None - has_more: bool - folders: list[str] | None = None - - -@dataclass(slots=True) -class CreateFolderResult: - pathname: str - url: str - - -@dataclass(slots=True) -class MultipartCreateResult: - upload_id: str - key: str - - -@dataclass(slots=True) -class GetBlobResult: - url: str - download_url: str - pathname: str - content_type: str | None - size: int | None - content_disposition: str - cache_control: str - uploaded_at: datetime - etag: str - content: bytes - status_code: int - - -@dataclass(slots=True) -class MultipartPart: - part_number: int - etag: str +from vercel._internal.blob.types import ( + Access, + CreateFolderResult, + GetBlobResult, + HeadBlobResult, + ListBlobItem, + ListBlobResult, + MultipartCreateResult, + MultipartPart, + OnUploadProgressCallback, + PutBlobResult, + UploadProgressEvent, +) + +__all__ = [ + "Access", + "CreateFolderResult", + "GetBlobResult", + "HeadBlobResult", + "ListBlobItem", + "ListBlobResult", + "MultipartCreateResult", + "MultipartPart", + "OnUploadProgressCallback", + "PutBlobResult", + "UploadProgressEvent", +] diff --git a/src/vercel/cache/cache_build.py b/src/vercel/cache/cache_build.py index ecf99b2..6510dd0 100644 --- a/src/vercel/cache/cache_build.py +++ b/src/vercel/cache/cache_build.py @@ -5,7 +5,8 @@ import httpx -from .._telemetry.tracker import track +from vercel._internal.telemetry.tracker import track + from .types import AsyncCache, Cache HEADERS_VERCEL_CACHE_STATE = "x-vercel-cache-state" diff --git a/src/vercel/cache/cache_in_memory.py b/src/vercel/cache/cache_in_memory.py index 4df434f..765dac0 100644 --- 
a/src/vercel/cache/cache_in_memory.py +++ b/src/vercel/cache/cache_in_memory.py @@ -2,7 +2,8 @@ from collections.abc import Sequence -from .._telemetry.tracker import track +from vercel._internal.telemetry.tracker import track + from .types import AsyncCache, Cache diff --git a/src/vercel/deployments/deployments.py b/src/vercel/deployments/deployments.py index c70af29..3acdd0f 100644 --- a/src/vercel/deployments/deployments.py +++ b/src/vercel/deployments/deployments.py @@ -5,7 +5,7 @@ import httpx -from .._telemetry.tracker import track +from vercel._internal.telemetry.tracker import track DEFAULT_API_BASE_URL = "https://api.vercel.com" DEFAULT_TIMEOUT = 60.0 diff --git a/src/vercel/projects/projects.py b/src/vercel/projects/projects.py index 957ae34..ba80f58 100644 --- a/src/vercel/projects/projects.py +++ b/src/vercel/projects/projects.py @@ -6,7 +6,7 @@ import httpx -from .._telemetry.tracker import track +from vercel._internal.telemetry.tracker import track __all__ = [ "get_projects", @@ -269,7 +269,7 @@ def update_project( team_id: str | None = None, slug: str | None = None, base_url: str = DEFAULT_API_BASE_URL, - timeout: float = 30.0, + timeout: float = DEFAULT_TIMEOUT, ) -> dict[str, Any]: """Update an existing project by id or name.""" params: dict[str, Any] = {} diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..632cbcc --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,102 @@ +"""Shared fixtures for all tests.""" + +import os +import time +import uuid +from collections.abc import Generator + +import pytest + + +@pytest.fixture +def mock_env_clear(monkeypatch: pytest.MonkeyPatch) -> Generator[None, None, None]: + """Clear all Vercel-related environment variables for testing. + + This ensures tests don't accidentally use real credentials from the environment. 
+ """ + env_vars_to_clear = [ + # General Vercel + "VERCEL_TOKEN", + "VERCEL_TEAM_ID", + "VERCEL_PROJECT_ID", + # Blob storage + "BLOB_READ_WRITE_TOKEN", + "BLOB_STORE_ID", + # OIDC + "VERCEL_OIDC_TOKEN", + "VERCEL_OIDC_TOKEN_HEADER", + # Cache + "VERCEL_CACHE_API_TOKEN", + "VERCEL_CACHE_API_URL", + # Functions + "VERCEL_URL", + "VERCEL_ENV", + "VERCEL_REGION", + ] + for var in env_vars_to_clear: + monkeypatch.delenv(var, raising=False) + yield + + +@pytest.fixture +def mock_token() -> str: + """Mock Vercel API token for testing.""" + return "test_token_123456789" + + +@pytest.fixture +def mock_team_id() -> str: + """Mock Vercel team ID for testing.""" + return "team_test123456789" + + +@pytest.fixture +def mock_project_id() -> str: + """Mock Vercel project ID for testing.""" + return "prj_test123456789" + + +@pytest.fixture +def mock_blob_token() -> str: + """Mock blob storage token for testing.""" + return "vercel_blob_rw_test_token_123456789" + + +@pytest.fixture +def unique_test_name() -> str: + """Generate a unique test resource name with timestamp.""" + timestamp = int(time.time()) + unique_id = uuid.uuid4().hex[:8] + return f"vercel-py-test-{timestamp}-{unique_id}" + + +def has_vercel_credentials() -> bool: + """Check if Vercel API credentials are available.""" + return bool(os.getenv("VERCEL_TOKEN") and os.getenv("VERCEL_TEAM_ID")) + + +def has_blob_credentials() -> bool: + """Check if Blob storage credentials are available.""" + return bool(os.getenv("BLOB_READ_WRITE_TOKEN")) + + +def has_sandbox_credentials() -> bool: + """Check if Sandbox credentials are available.""" + return has_vercel_credentials() + + +# Skip markers for live tests +requires_vercel_credentials = pytest.mark.skipif( + not has_vercel_credentials(), + reason="Requires VERCEL_TOKEN and VERCEL_TEAM_ID environment variables", +) + +requires_blob_credentials = pytest.mark.skipif( + not has_blob_credentials(), + reason="Requires BLOB_READ_WRITE_TOKEN environment variable", +) + 
+requires_sandbox_credentials = pytest.mark.skipif( + not has_sandbox_credentials(), + reason="Requires VERCEL_TOKEN and VERCEL_TEAM_ID environment variables for sandbox", +) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..d89aa36 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,386 @@ +"""Fixtures for integration tests using respx mocking.""" + +import pytest + +# API base URLs +BLOB_API_BASE = "https://blob.vercel-storage.com" +VERCEL_API_BASE = "https://api.vercel.com" +SANDBOX_API_BASE = "https://sandbox.vercel.com" + + +# ============================================================================= +# Blob Module Mock Responses +# ============================================================================= + + +@pytest.fixture +def mock_blob_put_response() -> dict: + """Mock response for blob put operation.""" + return { + "url": f"{BLOB_API_BASE}/test-abc123/test.txt", + "downloadUrl": f"{BLOB_API_BASE}/test-abc123/test.txt?download=1", + "pathname": "test.txt", + "contentType": "text/plain", + "contentDisposition": 'inline; filename="test.txt"', + } + + +@pytest.fixture +def mock_blob_head_response() -> dict: + """Mock response for blob head operation.""" + return { + "url": f"{BLOB_API_BASE}/test-abc123/test.txt", + "downloadUrl": f"{BLOB_API_BASE}/test-abc123/test.txt?download=1", + "pathname": "test.txt", + "contentType": "text/plain", + "contentDisposition": 'inline; filename="test.txt"', + "size": 13, + "uploadedAt": "2024-01-15T10:30:00.000Z", + "cacheControl": "max-age=31536000", + } + + +@pytest.fixture +def mock_blob_list_response() -> dict: + """Mock response for blob list operation.""" + return { + "blobs": [ + { + "url": f"{BLOB_API_BASE}/test-abc123/file1.txt", + "downloadUrl": f"{BLOB_API_BASE}/test-abc123/file1.txt?download=1", + "pathname": 
"file1.txt", + "contentType": "text/plain", + "contentDisposition": 'inline; filename="file1.txt"', + "size": 100, + "uploadedAt": "2024-01-15T10:30:00.000Z", + }, + { + "url": f"{BLOB_API_BASE}/test-abc123/file2.txt", + "downloadUrl": f"{BLOB_API_BASE}/test-abc123/file2.txt?download=1", + "pathname": "file2.txt", + "contentType": "text/plain", + "contentDisposition": 'inline; filename="file2.txt"', + "size": 200, + "uploadedAt": "2024-01-15T10:31:00.000Z", + }, + ], + "cursor": None, + "hasMore": False, + "folders": [], + } + + +@pytest.fixture +def mock_blob_list_response_paginated() -> dict: + """Mock response for paginated blob list operation.""" + return { + "blobs": [ + { + "url": f"{BLOB_API_BASE}/test-abc123/page1.txt", + "downloadUrl": f"{BLOB_API_BASE}/test-abc123/page1.txt?download=1", + "pathname": "page1.txt", + "contentType": "text/plain", + "contentDisposition": 'inline; filename="page1.txt"', + "size": 50, + "uploadedAt": "2024-01-15T10:30:00.000Z", + }, + ], + "cursor": "next_cursor_abc123", + "hasMore": True, + "folders": [], + } + + +@pytest.fixture +def mock_blob_create_folder_response() -> dict: + """Mock response for create folder operation.""" + return { + "url": f"{BLOB_API_BASE}/test-abc123/my-folder/", + "pathname": "my-folder/", + } + + +@pytest.fixture +def mock_blob_copy_response() -> dict: + """Mock response for blob copy operation.""" + return { + "url": f"{BLOB_API_BASE}/test-abc123/copied.txt", + "downloadUrl": f"{BLOB_API_BASE}/test-abc123/copied.txt?download=1", + "pathname": "copied.txt", + "contentType": "text/plain", + "contentDisposition": 'inline; filename="copied.txt"', + } + + +# ============================================================================= +# Sandbox Module Mock Responses +# ============================================================================= + + +@pytest.fixture +def mock_sandbox_create_response() -> dict: + """Mock response for sandbox create operation. 
+ + Matches the Sandbox model schema with all required fields. + """ + return { + "id": "sbx_test123456", + "memory": 512, + "vcpus": 1, + "region": "iad1", + "runtime": "nodejs20.x", + "timeout": 300, + "status": "running", + "requestedAt": 1705320600000, + "startedAt": 1705320601000, + "requestedStopAt": None, + "stoppedAt": None, + "duration": None, + "sourceSnapshotId": None, + "snapshottedAt": None, + "createdAt": 1705320600000, + "cwd": "/app", + "updatedAt": 1705320601000, + "interactivePort": None, + } + + +@pytest.fixture +def mock_sandbox_get_response() -> dict: + """Mock response for sandbox get operation.""" + return { + "id": "sbx_test123456", + "memory": 512, + "vcpus": 1, + "region": "iad1", + "runtime": "nodejs20.x", + "timeout": 300, + "status": "running", + "requestedAt": 1705320600000, + "startedAt": 1705320601000, + "requestedStopAt": None, + "stoppedAt": None, + "duration": None, + "sourceSnapshotId": None, + "snapshottedAt": None, + "createdAt": 1705320600000, + "cwd": "/app", + "updatedAt": 1705320601000, + "interactivePort": None, + } + + +@pytest.fixture +def mock_sandbox_command_response() -> dict: + """Mock response for sandbox run_command operation.""" + return { + "commandId": "cmd_test123", + "exitCode": 0, + "stdout": "Hello, World!\n", + "stderr": "", + } + + +@pytest.fixture +def mock_sandbox_command_detached_response() -> dict: + """Mock response for sandbox run_command_detached operation.""" + return { + "commandId": "cmd_detached_test123", + "status": "running", + } + + +@pytest.fixture +def mock_sandbox_read_file_content() -> bytes: + """Mock content for sandbox read_file operation.""" + return b"file content from sandbox" + + +@pytest.fixture +def mock_sandbox_snapshot_response() -> dict: + """Mock response for sandbox snapshot operation. + + Matches the Snapshot model schema with all required fields. 
+ """ + return { + "id": "snap_test123456", + "sourceSandboxId": "sbx_test123456", + "region": "iad1", + "status": "created", + "sizeBytes": 1024000, + "expiresAt": 1705924600000, + "createdAt": 1705320600000, + "updatedAt": 1705320600000, + } + + +# ============================================================================= +# Cache Module Mock Responses +# ============================================================================= + + +@pytest.fixture +def mock_cache_get_response() -> dict: + """Mock response for cache get operation.""" + return { + "value": "cached_value_123", + "status": "HIT", + } + + +@pytest.fixture +def mock_cache_set_response() -> dict: + """Mock response for cache set operation.""" + return { + "status": "OK", + } + + +@pytest.fixture +def mock_cache_delete_response() -> dict: + """Mock response for cache delete operation.""" + return { + "status": "OK", + } + + +# ============================================================================= +# OIDC Module Mock Responses +# ============================================================================= + + +@pytest.fixture +def mock_oidc_token() -> str: + """Mock OIDC JWT token for testing. + + This is a valid JWT structure with a test payload. 
+ Header: {"alg": "RS256", "typ": "JWT"} + Payload: {"sub": "test_subject", "aud": "vercel", + "iss": "https://oidc.vercel.com", "exp": 9999999999} + """ + # This is a properly formatted JWT (though not cryptographically valid) + # Base64url encoded: header.payload.signature + header = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9" + # fmt: off + payload = "eyJzdWIiOiJ0ZXN0X3N1YmplY3QiLCJhdWQiOiJ2ZXJjZWwiLCJpc3MiOiJodHRwczovL29pZGMudmVyY2VsLmNvbSIsImV4cCI6OTk5OTk5OTk5OX0" # noqa: E501 + # fmt: on + signature = "test_signature" + return f"{header}.{payload}.{signature}" + + +@pytest.fixture +def mock_oidc_token_payload() -> dict: + """Mock decoded OIDC token payload.""" + return { + "sub": "test_subject", + "aud": "vercel", + "iss": "https://oidc.vercel.com", + "exp": 9999999999, + } + + +# ============================================================================= +# Projects Module Mock Responses (for reference/compatibility) +# ============================================================================= + + +@pytest.fixture +def mock_project_data() -> dict: + """Mock project data based on actual API response structure.""" + return { + "id": "prj_abc123456789", + "name": "test-project", + "accountId": "team_test123456789", + "createdAt": 1705320600000, + "updatedAt": 1705320600000, + "framework": None, + "devCommand": None, + "installCommand": None, + "buildCommand": None, + "outputDirectory": None, + "rootDirectory": None, + "nodeVersion": "20.x", + "serverlessFunctionRegion": None, + "sourceFilesOutsideRootDirectory": False, + "speedInsights": None, + "webAnalytics": None, + "autoAssignCustomDomains": True, + "autoAssignCustomDomainsUpdatedBy": None, + "gitForkProtection": True, + "directoryListing": False, + "skewProtectionBoundaryAt": None, + "skewProtectionMaxAge": None, + } + + +@pytest.fixture +def mock_projects_list_response(mock_project_data: dict) -> dict: + """Mock projects list response with pagination.""" + return { + "projects": [mock_project_data], + 
"pagination": { + "count": 1, + "next": None, + "prev": None, + }, + } + + +# ============================================================================= +# Error Response Fixtures +# ============================================================================= + + +@pytest.fixture +def mock_error_not_found() -> dict: + """Mock 404 Not Found error response.""" + return { + "error": { + "code": "not_found", + "message": "The requested resource was not found.", + } + } + + +@pytest.fixture +def mock_error_unauthorized() -> dict: + """Mock 401 Unauthorized error response.""" + return { + "error": { + "code": "unauthorized", + "message": "Authentication required.", + } + } + + +@pytest.fixture +def mock_error_forbidden() -> dict: + """Mock 403 Forbidden error response.""" + return { + "error": { + "code": "forbidden", + "message": "You do not have permission to access this resource.", + } + } + + +@pytest.fixture +def mock_error_bad_request() -> dict: + """Mock 400 Bad Request error response.""" + return { + "error": { + "code": "bad_request", + "message": "The request was invalid.", + } + } + + +@pytest.fixture +def mock_error_server_error() -> dict: + """Mock 500 Internal Server Error response.""" + return { + "error": { + "code": "internal_server_error", + "message": "An unexpected error occurred.", + } + } diff --git a/tests/integration/test_blob_multipart_auto_upload.py b/tests/integration/test_blob_multipart_auto_upload.py new file mode 100644 index 0000000..dc851ef --- /dev/null +++ b/tests/integration/test_blob_multipart_auto_upload.py @@ -0,0 +1,335 @@ +"""Integration tests for multipart auto upload using respx.""" + +from __future__ import annotations + +import json +from collections.abc import Callable +from typing import Any + +import httpx +import pytest +import respx + +from vercel._internal.blob.multipart import MIN_PART_SIZE +from vercel.blob.multipart import ( + AsyncMultipartUploader, + MultipartUploader, + auto_multipart_upload, + 
auto_multipart_upload_async, + complete_multipart_upload, + complete_multipart_upload_async, + create_multipart_upload, + create_multipart_upload_async, + create_multipart_uploader, + create_multipart_uploader_async, + upload_part, + upload_part_async, +) +from vercel.blob.types import UploadProgressEvent + +BLOB_API_BASE = "https://vercel.com/api/blob" + + +def _build_complete_response(pathname: str) -> dict[str, str]: + return { + "url": f"https://blob.vercel-storage.com/test-abc123/{pathname}", + "downloadUrl": f"https://blob.vercel-storage.com/test-abc123/{pathname}?download=1", + "pathname": pathname, + "contentType": "application/octet-stream", + "contentDisposition": 'inline; filename="file.bin"', + } + + +def _manual_mpu_handler( + pathname: str, +) -> tuple[Callable[[httpx.Request], httpx.Response], dict[str, Any]]: + upload_part_numbers: list[int] = [] + completed_parts: list[dict[str, str | int]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + action = request.headers["x-mpu-action"] + + if action == "create": + assert request.url.params["pathname"] == pathname + return httpx.Response(200, json={"uploadId": "upload-id", "key": "blob-key"}) + + if action == "upload": + part_number = int(request.headers["x-mpu-part-number"]) + upload_part_numbers.append(part_number) + assert request.headers["x-mpu-upload-id"] == "upload-id" + assert request.headers["x-mpu-key"] == "blob-key" + return httpx.Response(200, json={"etag": f"etag-{part_number}"}) + + if action == "complete": + completed_parts.extend(json.loads(request.content.decode())) + return httpx.Response(200, json=_build_complete_response(pathname)) + + raise AssertionError(f"unexpected multipart action: {action}") + + state = { + "upload_part_numbers": upload_part_numbers, + "completed_parts": completed_parts, + } + return handler, state + + +@respx.mock +def test_auto_multipart_upload_sync_uses_blob_api_flow(mock_env_clear) -> None: + upload_part_numbers: list[int] = [] + 
upload_part_lengths: list[int] = [] + completed_parts: list[dict[str, str | int]] = [] + progress_events: list[UploadProgressEvent] = [] + + def mpu_handler(request: httpx.Request) -> httpx.Response: + action = request.headers["x-mpu-action"] + + if action == "create": + assert request.url.params["pathname"] == "folder/file.bin" + return httpx.Response(200, json={"uploadId": "upload-id", "key": "blob-key"}) + + if action == "upload": + part_number = int(request.headers["x-mpu-part-number"]) + upload_part_numbers.append(part_number) + upload_part_lengths.append(len(request.content)) + assert request.headers["x-mpu-upload-id"] == "upload-id" + assert request.headers["x-mpu-key"] == "blob-key" + return httpx.Response(200, json={"etag": f"etag-{part_number}"}) + + if action == "complete": + completed_parts.extend(json.loads(request.content.decode())) + return httpx.Response(200, json=_build_complete_response("folder/file.bin")) + + raise AssertionError(f"unexpected multipart action: {action}") + + route = respx.post(f"{BLOB_API_BASE}/mpu").mock(side_effect=mpu_handler) + + body = (b"a" * MIN_PART_SIZE) + b"b" + result = auto_multipart_upload( + "folder/file.bin", + body, + token="test_token", + part_size=MIN_PART_SIZE, + on_upload_progress=progress_events.append, + ) + + assert route.call_count == 4 + assert sorted(upload_part_numbers) == [1, 2] + assert sorted(upload_part_lengths) == [1, MIN_PART_SIZE] + assert [part["partNumber"] for part in completed_parts] == [1, 2] + assert result["pathname"] == "folder/file.bin" + assert isinstance(progress_events[-1], UploadProgressEvent) + assert progress_events[-1] == UploadProgressEvent( + loaded=len(body), + total=len(body), + percentage=100.0, + ) + + +@respx.mock +def test_manual_multipart_sync_uses_blob_api_flow(mock_env_clear) -> None: + handler, state = _manual_mpu_handler("folder/manual.bin") + route = respx.post(f"{BLOB_API_BASE}/mpu").mock(side_effect=handler) + + created = 
create_multipart_upload("folder/manual.bin", token="test_token") + part = upload_part( + "folder/manual.bin", + b"chunk", + token="test_token", + upload_id=created.upload_id, + key=created.key, + part_number=1, + ) + result = complete_multipart_upload( + "folder/manual.bin", + [part], + token="test_token", + upload_id=created.upload_id, + key=created.key, + ) + + assert route.call_count == 3 + assert state["upload_part_numbers"] == [1] + assert [part["partNumber"] for part in state["completed_parts"]] == [1] + assert result.pathname == "folder/manual.bin" + + +@respx.mock +@pytest.mark.asyncio +async def test_auto_multipart_upload_async_uses_blob_api_flow(mock_env_clear) -> None: + upload_part_numbers: list[int] = [] + upload_part_lengths: list[int] = [] + completed_parts: list[dict[str, str | int]] = [] + progress_events: list[UploadProgressEvent] = [] + + def mpu_handler(request: httpx.Request) -> httpx.Response: + action = request.headers["x-mpu-action"] + + if action == "create": + assert request.url.params["pathname"] == "folder/file.bin" + return httpx.Response(200, json={"uploadId": "upload-id", "key": "blob-key"}) + + if action == "upload": + part_number = int(request.headers["x-mpu-part-number"]) + upload_part_numbers.append(part_number) + upload_part_lengths.append(len(request.content)) + assert request.headers["x-mpu-upload-id"] == "upload-id" + assert request.headers["x-mpu-key"] == "blob-key" + return httpx.Response(200, json={"etag": f"etag-{part_number}"}) + + if action == "complete": + completed_parts.extend(json.loads(request.content.decode())) + return httpx.Response(200, json=_build_complete_response("folder/file.bin")) + + raise AssertionError(f"unexpected multipart action: {action}") + + route = respx.post(f"{BLOB_API_BASE}/mpu").mock(side_effect=mpu_handler) + + body = (b"a" * MIN_PART_SIZE) + b"b" + + async def on_progress(event: UploadProgressEvent) -> None: + progress_events.append(event) + + result = await auto_multipart_upload_async( + 
"folder/file.bin", + body, + token="test_token", + part_size=MIN_PART_SIZE, + on_upload_progress=on_progress, + ) + + assert route.call_count == 4 + assert sorted(upload_part_numbers) == [1, 2] + assert sorted(upload_part_lengths) == [1, MIN_PART_SIZE] + assert [part["partNumber"] for part in completed_parts] == [1, 2] + assert result["pathname"] == "folder/file.bin" + assert isinstance(progress_events[-1], UploadProgressEvent) + assert progress_events[-1] == UploadProgressEvent( + loaded=len(body), + total=len(body), + percentage=100.0, + ) + + +@respx.mock +@pytest.mark.asyncio +async def test_auto_multipart_upload_async_unknown_total_reports_loaded_bytes( + mock_env_clear, +) -> None: + upload_part_numbers: list[int] = [] + progress_events: list[UploadProgressEvent] = [] + + def mpu_handler(request: httpx.Request) -> httpx.Response: + action = request.headers["x-mpu-action"] + + if action == "create": + assert request.url.params["pathname"] == "folder/unknown-total.bin" + return httpx.Response(200, json={"uploadId": "upload-id", "key": "blob-key"}) + + if action == "upload": + part_number = int(request.headers["x-mpu-part-number"]) + upload_part_numbers.append(part_number) + return httpx.Response(200, json={"etag": f"etag-{part_number}"}) + + if action == "complete": + return httpx.Response(200, json=_build_complete_response("folder/unknown-total.bin")) + + raise AssertionError(f"unexpected multipart action: {action}") + + route = respx.post(f"{BLOB_API_BASE}/mpu").mock(side_effect=mpu_handler) + + chunk_one = b"a" * (MIN_PART_SIZE // 2) + chunk_two = b"b" + + async def async_chunks(): + yield chunk_one + yield chunk_two + + async def on_progress(event: UploadProgressEvent) -> None: + progress_events.append(event) + + result = await auto_multipart_upload_async( + "folder/unknown-total.bin", + async_chunks(), + token="test_token", + part_size=MIN_PART_SIZE, + on_upload_progress=on_progress, + ) + + assert route.call_count == 3 + assert upload_part_numbers == [1] 
+ assert result["pathname"] == "folder/unknown-total.bin" + assert progress_events[-1] == UploadProgressEvent( + loaded=len(chunk_one) + len(chunk_two), + total=0, + percentage=100.0, + ) + + +@respx.mock +@pytest.mark.asyncio +async def test_manual_multipart_async_uses_blob_api_flow(mock_env_clear) -> None: + handler, state = _manual_mpu_handler("folder/manual-async.bin") + route = respx.post(f"{BLOB_API_BASE}/mpu").mock(side_effect=handler) + + created = await create_multipart_upload_async("folder/manual-async.bin", token="test_token") + part = await upload_part_async( + "folder/manual-async.bin", + b"chunk", + token="test_token", + upload_id=created.upload_id, + key=created.key, + part_number=1, + ) + result = await complete_multipart_upload_async( + "folder/manual-async.bin", + [part], + token="test_token", + upload_id=created.upload_id, + key=created.key, + ) + + assert route.call_count == 3 + assert state["upload_part_numbers"] == [1] + assert [part["partNumber"] for part in state["completed_parts"]] == [1] + assert result.pathname == "folder/manual-async.bin" + + +@respx.mock +def test_create_multipart_uploader_sync_uses_blob_api_flow(mock_env_clear) -> None: + handler, state = _manual_mpu_handler("folder/uploader-sync.bin") + route = respx.post(f"{BLOB_API_BASE}/mpu").mock(side_effect=handler) + + uploader = create_multipart_uploader("folder/uploader-sync.bin", token="test_token") + assert isinstance(uploader, MultipartUploader) + assert uploader.upload_id == "upload-id" + assert uploader.key == "blob-key" + + part = uploader.upload_part(1, b"chunk") + result = uploader.complete([part]) + + assert route.call_count == 3 + assert state["upload_part_numbers"] == [1] + assert [part["partNumber"] for part in state["completed_parts"]] == [1] + assert result.pathname == "folder/uploader-sync.bin" + + +@respx.mock +@pytest.mark.asyncio +async def test_create_multipart_uploader_async_uses_blob_api_flow(mock_env_clear) -> None: + handler, state = 
_manual_mpu_handler("folder/uploader-async.bin") + route = respx.post(f"{BLOB_API_BASE}/mpu").mock(side_effect=handler) + + uploader = await create_multipart_uploader_async( + "folder/uploader-async.bin", token="test_token" + ) + assert isinstance(uploader, AsyncMultipartUploader) + assert uploader.upload_id == "upload-id" + assert uploader.key == "blob-key" + + part = await uploader.upload_part(1, b"chunk") + result = await uploader.complete([part]) + + assert route.call_count == 3 + assert state["upload_part_numbers"] == [1] + assert [part["partNumber"] for part in state["completed_parts"]] == [1] + assert result.pathname == "folder/uploader-async.bin" diff --git a/tests/integration/test_blob_sync_async.py b/tests/integration/test_blob_sync_async.py new file mode 100644 index 0000000..2151ab2 --- /dev/null +++ b/tests/integration/test_blob_sync_async.py @@ -0,0 +1,1562 @@ +"""Integration tests for Vercel Blob API using respx mocking. + +Tests both sync and async variants to ensure API parity. 
+""" + +import io + +import httpx +import pytest +import respx + +from vercel._internal.blob.core import decode_blob_response_json +from vercel.blob import ( + AsyncBlobClient, + BlobClient, + BlobNotFoundError, + aioblob, + copy, + copy_async, + create_folder, + create_folder_async, + delete, + delete_async, + download_file, + download_file_async, + get_download_url, + head, + head_async, + iter_objects, + iter_objects_async, + list_objects, + list_objects_async, + put, + put_async, + upload_file, + upload_file_async, +) +from vercel.blob.ops import get, get_async + +# Base URL for Vercel Blob API +BLOB_API_BASE = "https://vercel.com/api/blob" + + +class TestBlobPut: + """Test blob put operations.""" + + @respx.mock + def test_put_sync(self, mock_env_clear, mock_blob_put_response): + """Test synchronous blob upload.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_put_response) + ) + + result = put("test.txt", b"Hello, World!", token="test_token") + + assert route.called + assert result.url == mock_blob_put_response["url"] + assert result.pathname == "test.txt" + assert result.content_type == "text/plain" + + # Verify request had correct headers + request = route.calls.last.request + assert "Bearer test_token" in request.headers.get("authorization", "") + + @respx.mock + @pytest.mark.asyncio + async def test_put_async(self, mock_env_clear, mock_blob_put_response): + """Test asynchronous blob upload.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_put_response) + ) + + result = await put_async("test.txt", b"Hello, World!", token="test_token") + + assert route.called + assert result.url == mock_blob_put_response["url"] + assert result.pathname == "test.txt" + + @respx.mock + def test_put_sync_progress_per_chunk(self, mock_env_clear, mock_blob_put_response): + """Test sync put emits multiple progress callbacks for streamed file-like input.""" + payload = b"a" * (64 * 1024) + 
b"b" * 32 + progress_percentages: list[float] = [] + + def handler(request: httpx.Request) -> httpx.Response: + body = b"".join(request.stream) # type: ignore[arg-type] + assert body == payload + return httpx.Response(200, json=mock_blob_put_response) + + route = respx.put(BLOB_API_BASE).mock(side_effect=handler) + + put( + "test.txt", + io.BytesIO(payload), + token="test_token", + on_upload_progress=lambda event: progress_percentages.append(event.percentage), + ) + + assert route.called + assert len(progress_percentages) >= 2 + assert progress_percentages[-1] == 100.0 + assert any(percentage < 100.0 for percentage in progress_percentages) + + @respx.mock + @pytest.mark.asyncio + async def test_put_async_progress_per_chunk(self, mock_env_clear, mock_blob_put_response): + """Test async put emits multiple progress callbacks for streamed file-like input.""" + payload = b"a" * (64 * 1024) + b"b" * 32 + progress_percentages: list[float] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + body = b"" + async for chunk in request.stream: # type: ignore[union-attr] + body += chunk + assert body == payload + return httpx.Response(200, json=mock_blob_put_response) + + route = respx.put(BLOB_API_BASE).mock(side_effect=handler) + + async def on_progress(event) -> None: + progress_percentages.append(event.percentage) + + await put_async( + "test.txt", + io.BytesIO(payload), + token="test_token", + on_upload_progress=on_progress, + ) + + assert route.called + assert len(progress_percentages) >= 2 + assert progress_percentages[-1] == 100.0 + assert any(percentage < 100.0 for percentage in progress_percentages) + + @respx.mock + def test_put_sync_multipart_uses_runtime_upload(self, mock_env_clear): + """Test sync multipart put uses create/upload/complete flow.""" + import json + + actions: list[str] = [] + completed_parts: list[dict[str, str | int]] = [] + + def mpu_handler(request: httpx.Request) -> httpx.Response: + action = request.headers["x-mpu-action"] + 
actions.append(action) + + if action == "create": + assert request.url.params["pathname"] == "folder/put-sync.bin" + return httpx.Response(200, json={"uploadId": "upload-id", "key": "blob-key"}) + + if action == "upload": + assert request.headers["x-mpu-upload-id"] == "upload-id" + assert request.headers["x-mpu-key"] == "blob-key" + assert request.headers["x-mpu-part-number"] == "1" + return httpx.Response(200, json={"etag": "etag-1"}) + + if action == "complete": + completed_parts.extend(json.loads(request.content.decode())) + return httpx.Response( + 200, + json={ + "url": "https://blob.vercel-storage.com/test-abc123/folder/put-sync.bin", + "downloadUrl": ( + "https://blob.vercel-storage.com/" + "test-abc123/folder/put-sync.bin?download=1" + ), + "pathname": "folder/put-sync.bin", + "contentType": "application/octet-stream", + "contentDisposition": 'inline; filename="put-sync.bin"', + }, + ) + + raise AssertionError(f"unexpected multipart action: {action}") + + route = respx.post(f"{BLOB_API_BASE}/mpu").mock(side_effect=mpu_handler) + + result = put("folder/put-sync.bin", b"hello", token="test_token", multipart=True) + + assert route.call_count == 3 + assert actions == ["create", "upload", "complete"] + assert [part["partNumber"] for part in completed_parts] == [1] + assert result.pathname == "folder/put-sync.bin" + + @respx.mock + def test_put_sync_multipart_aggregates_progress(self, mock_env_clear): + """Test sync multipart put aggregates part progress into total progress.""" + import json + + from vercel.blob.multipart.uploader import DEFAULT_PART_SIZE + + completed_parts: list[dict[str, str | int]] = [] + progress_events = [] + + def mpu_handler(request: httpx.Request) -> httpx.Response: + action = request.headers["x-mpu-action"] + + if action == "create": + return httpx.Response(200, json={"uploadId": "upload-id", "key": "blob-key"}) + + if action == "upload": + part_number = int(request.headers["x-mpu-part-number"]) + return httpx.Response(200, 
json={"etag": f"etag-{part_number}"}) + + if action == "complete": + completed_parts.extend(json.loads(request.content.decode())) + return httpx.Response( + 200, + json={ + "url": "https://blob.vercel-storage.com/test-abc123/folder/progress-sync.bin", + "downloadUrl": ( + "https://blob.vercel-storage.com/" + "test-abc123/folder/progress-sync.bin?download=1" + ), + "pathname": "folder/progress-sync.bin", + "contentType": "application/octet-stream", + "contentDisposition": 'inline; filename="progress-sync.bin"', + }, + ) + + raise AssertionError(f"unexpected multipart action: {action}") + + route = respx.post(f"{BLOB_API_BASE}/mpu").mock(side_effect=mpu_handler) + body = (b"a" * DEFAULT_PART_SIZE) + b"b" + result = put( + "folder/progress-sync.bin", + body, + token="test_token", + multipart=True, + on_upload_progress=progress_events.append, + ) + + assert route.call_count == 4 + assert [part["partNumber"] for part in completed_parts] == [1, 2] + assert result.pathname == "folder/progress-sync.bin" + assert progress_events[-1].loaded == len(body) + assert progress_events[-1].total == len(body) + assert progress_events[-1].percentage == 100.0 + assert any(event.loaded < len(body) for event in progress_events) + + @respx.mock + @pytest.mark.asyncio + async def test_put_async_multipart_uses_runtime_upload(self, mock_env_clear): + """Test async multipart put uses create/upload/complete flow.""" + import json + + actions: list[str] = [] + completed_parts: list[dict[str, str | int]] = [] + + def mpu_handler(request: httpx.Request) -> httpx.Response: + action = request.headers["x-mpu-action"] + actions.append(action) + + if action == "create": + assert request.url.params["pathname"] == "folder/put-async.bin" + return httpx.Response(200, json={"uploadId": "upload-id", "key": "blob-key"}) + + if action == "upload": + assert request.headers["x-mpu-upload-id"] == "upload-id" + assert request.headers["x-mpu-key"] == "blob-key" + assert request.headers["x-mpu-part-number"] == "1" 
+ return httpx.Response(200, json={"etag": "etag-1"}) + + if action == "complete": + completed_parts.extend(json.loads(request.content.decode())) + return httpx.Response( + 200, + json={ + "url": "https://blob.vercel-storage.com/test-abc123/folder/put-async.bin", + "downloadUrl": ( + "https://blob.vercel-storage.com/" + "test-abc123/folder/put-async.bin?download=1" + ), + "pathname": "folder/put-async.bin", + "contentType": "application/octet-stream", + "contentDisposition": 'inline; filename="put-async.bin"', + }, + ) + + raise AssertionError(f"unexpected multipart action: {action}") + + route = respx.post(f"{BLOB_API_BASE}/mpu").mock(side_effect=mpu_handler) + + result = await put_async( + "folder/put-async.bin", + b"hello", + token="test_token", + multipart=True, + ) + + assert route.call_count == 3 + assert actions == ["create", "upload", "complete"] + assert [part["partNumber"] for part in completed_parts] == [1] + assert result.pathname == "folder/put-async.bin" + + @respx.mock + @pytest.mark.asyncio + async def test_put_async_multipart_aggregates_progress(self, mock_env_clear): + """Test async multipart put aggregates part progress into total progress.""" + import json + + from vercel.blob.multipart.uploader import DEFAULT_PART_SIZE + + completed_parts: list[dict[str, str | int]] = [] + progress_events = [] + + def mpu_handler(request: httpx.Request) -> httpx.Response: + action = request.headers["x-mpu-action"] + + if action == "create": + return httpx.Response(200, json={"uploadId": "upload-id", "key": "blob-key"}) + + if action == "upload": + part_number = int(request.headers["x-mpu-part-number"]) + return httpx.Response(200, json={"etag": f"etag-{part_number}"}) + + if action == "complete": + completed_parts.extend(json.loads(request.content.decode())) + return httpx.Response( + 200, + json={ + "url": "https://blob.vercel-storage.com/test-abc123/folder/progress-async.bin", + "downloadUrl": ( + "https://blob.vercel-storage.com/" + 
"test-abc123/folder/progress-async.bin?download=1" + ), + "pathname": "folder/progress-async.bin", + "contentType": "application/octet-stream", + "contentDisposition": 'inline; filename="progress-async.bin"', + }, + ) + + raise AssertionError(f"unexpected multipart action: {action}") + + route = respx.post(f"{BLOB_API_BASE}/mpu").mock(side_effect=mpu_handler) + body = (b"a" * DEFAULT_PART_SIZE) + b"b" + + async def on_progress(event) -> None: + progress_events.append(event) + + result = await put_async( + "folder/progress-async.bin", + body, + token="test_token", + multipart=True, + on_upload_progress=on_progress, + ) + + assert route.call_count == 4 + assert [part["partNumber"] for part in completed_parts] == [1, 2] + assert result.pathname == "folder/progress-async.bin" + assert progress_events[-1].loaded == len(body) + assert progress_events[-1].total == len(body) + assert progress_events[-1].percentage == 100.0 + assert any(event.loaded < len(body) for event in progress_events) + + @respx.mock + def test_put_with_content_type(self, mock_env_clear, mock_blob_put_response): + """Test put with explicit content type.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_put_response) + ) + + result = put( + "test.json", + b'{"key": "value"}', + token="test_token", + content_type="application/json", + ) + + assert route.called + request = route.calls.last.request + assert request.headers.get("x-content-type") == "application/json" + # Verify result is properly parsed + assert result.url == mock_blob_put_response["url"] + assert result.content_type == mock_blob_put_response["contentType"] + + @respx.mock + def test_put_with_cache_control(self, mock_env_clear, mock_blob_put_response): + """Test put with cache control max age.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_put_response) + ) + + put("test.txt", b"data", token="test_token", cache_control_max_age=3600) + + assert 
route.called + request = route.calls.last.request + assert request.headers.get("x-cache-control-max-age") == "3600" + + @respx.mock + def test_put_sync_async_parity(self, mock_env_clear, mock_blob_put_response): + """Verify sync and async produce identical results.""" + respx.put(BLOB_API_BASE).mock(return_value=httpx.Response(200, json=mock_blob_put_response)) + + sync_result = put("test.txt", b"data", token="test_token") + + # Reset mock for async call + respx.reset() + respx.put(BLOB_API_BASE).mock(return_value=httpx.Response(200, json=mock_blob_put_response)) + + import asyncio + + async_result = asyncio.get_event_loop().run_until_complete( + put_async("test.txt", b"data", token="test_token") + ) + + assert sync_result.url == async_result.url + assert sync_result.pathname == async_result.pathname + assert sync_result.content_type == async_result.content_type + + +class TestBlobDelete: + """Test blob delete operations.""" + + @respx.mock + def test_delete_single_sync(self, mock_env_clear): + """Test synchronous single blob delete.""" + route = respx.post(f"{BLOB_API_BASE}/delete").mock( + return_value=httpx.Response(200, json={}) + ) + + delete("https://blob.vercel-storage.com/test.txt", token="test_token") + + assert route.called + request = route.calls.last.request + # Verify the URL was sent in the request body + import json + + body = json.loads(request.content) + assert "urls" in body + assert "https://blob.vercel-storage.com/test.txt" in body["urls"] + + @respx.mock + @pytest.mark.asyncio + async def test_delete_single_async(self, mock_env_clear): + """Test asynchronous single blob delete.""" + route = respx.post(f"{BLOB_API_BASE}/delete").mock( + return_value=httpx.Response(200, json={}) + ) + + await delete_async("https://blob.vercel-storage.com/test.txt", token="test_token") + + assert route.called + + @respx.mock + def test_delete_batch_sync(self, mock_env_clear): + """Test synchronous batch blob delete.""" + route = 
respx.post(f"{BLOB_API_BASE}/delete").mock( + return_value=httpx.Response(200, json={}) + ) + + urls = [ + "https://blob.vercel-storage.com/file1.txt", + "https://blob.vercel-storage.com/file2.txt", + "https://blob.vercel-storage.com/file3.txt", + ] + delete(urls, token="test_token") + + assert route.called + import json + + body = json.loads(route.calls.last.request.content) + assert len(body["urls"]) == 3 + + @respx.mock + def test_delete_batch_sync_generator(self, mock_env_clear): + """Test synchronous delete with generator input preserves all URLs.""" + route = respx.post(f"{BLOB_API_BASE}/delete").mock( + return_value=httpx.Response(200, json={}) + ) + + urls = (f"https://blob.vercel-storage.com/file{i}.txt" for i in range(1, 4)) + delete(urls, token="test_token") + + assert route.called + import json + + body = json.loads(route.calls.last.request.content) + assert len(body["urls"]) == 3 + + @respx.mock + @pytest.mark.asyncio + async def test_delete_batch_async(self, mock_env_clear): + """Test asynchronous batch blob delete.""" + route = respx.post(f"{BLOB_API_BASE}/delete").mock( + return_value=httpx.Response(200, json={}) + ) + + urls = [ + "https://blob.vercel-storage.com/file1.txt", + "https://blob.vercel-storage.com/file2.txt", + ] + await delete_async(urls, token="test_token") + + assert route.called + + @respx.mock + @pytest.mark.asyncio + async def test_delete_batch_async_generator(self, mock_env_clear): + """Test asynchronous delete with generator input preserves all URLs.""" + route = respx.post(f"{BLOB_API_BASE}/delete").mock( + return_value=httpx.Response(200, json={}) + ) + + urls = (f"https://blob.vercel-storage.com/file{i}.txt" for i in range(1, 3)) + await delete_async(urls, token="test_token") + + assert route.called + import json + + body = json.loads(route.calls.last.request.content) + assert len(body["urls"]) == 2 + + +class TestBlobHead: + """Test blob head/metadata operations.""" + + @respx.mock + def test_head_sync(self, mock_env_clear, 
mock_blob_head_response): + """Test synchronous blob metadata retrieval.""" + route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_head_response) + ) + + result = head("https://blob.vercel-storage.com/test.txt", token="test_token") + + assert route.called + assert result.size == 13 + assert result.pathname == "test.txt" + assert result.content_type == "text/plain" + + @respx.mock + @pytest.mark.asyncio + async def test_head_async(self, mock_env_clear, mock_blob_head_response): + """Test asynchronous blob metadata retrieval.""" + route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_head_response) + ) + + result = await head_async("https://blob.vercel-storage.com/test.txt", token="test_token") + + assert route.called + assert result.size == 13 + assert result.pathname == "test.txt" + + @respx.mock + def test_head_not_found(self, mock_env_clear, mock_error_not_found): + """Test 404 error handling for head operation.""" + respx.get(BLOB_API_BASE).mock(return_value=httpx.Response(404, json=mock_error_not_found)) + + with pytest.raises(BlobNotFoundError): + head("https://blob.vercel-storage.com/nonexistent.txt", token="test_token") + + @respx.mock + @pytest.mark.asyncio + async def test_head_not_found_async(self, mock_env_clear, mock_error_not_found): + """Test 404 error handling for async head operation.""" + respx.get(BLOB_API_BASE).mock(return_value=httpx.Response(404, json=mock_error_not_found)) + + with pytest.raises(BlobNotFoundError): + await head_async("https://blob.vercel-storage.com/nonexistent.txt", token="test_token") + + +class TestBlobReadAndDownload: + """Test blob read and download operations.""" + + @respx.mock + def test_get_sync_with_url(self, mock_env_clear, mock_blob_head_response): + """Test synchronous read from a direct blob URL.""" + payload = b"hello sync" + route = respx.get(mock_blob_head_response["url"]).mock( + return_value=httpx.Response(200, content=payload) + ) + + 
result = get(mock_blob_head_response["url"], token="test_token") + + assert route.called + assert result.content == payload + timeout = route.calls.last.request.extensions["timeout"] + assert timeout["connect"] == 30.0 + + @respx.mock + @pytest.mark.asyncio + async def test_get_async_with_path(self, mock_env_clear, mock_blob_head_response): + """Test async read resolves path metadata before fetching bytes.""" + payload = b"hello async" + head_route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_head_response) + ) + blob_route = respx.get(mock_blob_head_response["url"]).mock( + return_value=httpx.Response(200, content=payload) + ) + + result = await get_async("test.txt", token="test_token") + + assert head_route.called + assert blob_route.called + assert result.content == payload + timeout = blob_route.calls.last.request.extensions["timeout"] + assert timeout["connect"] == 30.0 + + @respx.mock + def test_download_file_sync_progress(self, mock_env_clear, mock_blob_head_response, tmp_path): + """Test sync file download writes bytes and emits progress.""" + payload = b"download-sync-payload" + route = respx.get(mock_blob_head_response["downloadUrl"]).mock( + return_value=httpx.Response( + 200, + content=payload, + headers={"Content-Length": str(len(payload))}, + ) + ) + destination = tmp_path / "sync-download.bin" + progress_updates: list[tuple[int, int | None]] = [] + + result = download_file( + mock_blob_head_response["downloadUrl"], + destination, + token="test_token", + progress=lambda loaded, total: progress_updates.append((loaded, total)), + ) + + assert route.called + assert result == str(destination) + assert destination.read_bytes() == payload + assert progress_updates[-1] == (len(payload), len(payload)) + + @respx.mock + def test_download_file_sync_progress_per_chunk( + self, mock_env_clear, mock_blob_head_response, tmp_path + ): + """Test sync file download emits progress for each streamed chunk.""" + chunks = [b"chunk-1", 
b"chunk-2"] + payload = b"".join(chunks) + + class ChunkedSyncStream(httpx.SyncByteStream): + def __iter__(self): + yield from chunks + + route = respx.get(mock_blob_head_response["downloadUrl"]).mock( + return_value=httpx.Response( + 200, + stream=ChunkedSyncStream(), + headers={"Content-Length": str(len(payload))}, + ) + ) + destination = tmp_path / "sync-download-chunked.bin" + progress_updates: list[tuple[int, int | None]] = [] + + result = download_file( + mock_blob_head_response["downloadUrl"], + destination, + token="test_token", + progress=lambda loaded, total: progress_updates.append((loaded, total)), + ) + + assert route.called + assert result == str(destination) + assert destination.read_bytes() == payload + assert len(progress_updates) >= 2 + assert progress_updates[-1] == (len(payload), len(payload)) + assert any(update[0] < len(payload) for update in progress_updates) + + @respx.mock + @pytest.mark.asyncio + async def test_download_file_async_progress( + self, mock_env_clear, mock_blob_head_response, tmp_path + ): + """Test async file download supports awaitable progress callbacks.""" + payload = b"download-async-payload" + head_route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_head_response) + ) + download_route = respx.get(mock_blob_head_response["downloadUrl"]).mock( + return_value=httpx.Response( + 200, + content=payload, + headers={"Content-Length": str(len(payload))}, + ) + ) + destination = tmp_path / "async-download.bin" + progress_updates: list[tuple[int, int | None]] = [] + + async def progress_callback(loaded: int, total: int | None) -> None: + progress_updates.append((loaded, total)) + + result = await download_file_async( + "test.txt", + destination, + token="test_token", + progress=progress_callback, + ) + + assert head_route.called + assert download_route.called + assert result == str(destination) + assert destination.read_bytes() == payload + assert progress_updates[-1] == (len(payload), 
len(payload)) + + @respx.mock + @pytest.mark.asyncio + async def test_download_file_async_progress_per_chunk( + self, mock_env_clear, mock_blob_head_response, tmp_path + ): + """Test async file download emits progress for each streamed chunk.""" + chunks = [b"chunk-a", b"chunk-b"] + payload = b"".join(chunks) + + class ChunkedAsyncStream(httpx.AsyncByteStream): + async def __aiter__(self): + for chunk in chunks: + yield chunk + + head_route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_head_response) + ) + download_route = respx.get(mock_blob_head_response["downloadUrl"]).mock( + return_value=httpx.Response( + 200, + stream=ChunkedAsyncStream(), + headers={"Content-Length": str(len(payload))}, + ) + ) + destination = tmp_path / "async-download-chunked.bin" + progress_updates: list[tuple[int, int | None]] = [] + + async def progress_callback(loaded: int, total: int | None) -> None: + progress_updates.append((loaded, total)) + + result = await download_file_async( + "test.txt", + destination, + token="test_token", + progress=progress_callback, + ) + + assert head_route.called + assert download_route.called + assert result == str(destination) + assert destination.read_bytes() == payload + assert len(progress_updates) >= 2 + assert progress_updates[-1] == (len(payload), len(payload)) + assert any(update[0] < len(payload) for update in progress_updates) + + +class TestBlobList: + """Test blob list operations.""" + + @respx.mock + def test_list_objects_sync(self, mock_env_clear, mock_blob_list_response): + """Test synchronous blob listing.""" + route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_list_response) + ) + + result = list_objects(token="test_token") + + assert route.called + assert len(result.blobs) == 2 + assert result.blobs[0].pathname == "file1.txt" + assert result.blobs[1].pathname == "file2.txt" + assert result.has_more is False + + @respx.mock + @pytest.mark.asyncio + async def 
test_list_objects_async(self, mock_env_clear, mock_blob_list_response): + """Test asynchronous blob listing.""" + route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_list_response) + ) + + result = await list_objects_async(token="test_token") + + assert route.called + assert len(result.blobs) == 2 + + @respx.mock + def test_list_objects_with_prefix(self, mock_env_clear, mock_blob_list_response): + """Test list with prefix filter.""" + route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_list_response) + ) + + list_objects(prefix="files/", token="test_token") + + assert route.called + request = route.calls.last.request + assert "prefix=files%2F" in str(request.url) or "prefix=files/" in str(request.url) + + @respx.mock + def test_list_objects_with_limit(self, mock_env_clear, mock_blob_list_response): + """Test list with limit parameter.""" + route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_list_response) + ) + + list_objects(limit=10, token="test_token") + + assert route.called + request = route.calls.last.request + assert "limit=10" in str(request.url) + + @respx.mock + def test_list_objects_pagination(self, mock_env_clear, mock_blob_list_response_paginated): + """Test pagination with cursor.""" + route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_list_response_paginated) + ) + + result = list_objects(token="test_token") + + assert route.called + assert result.has_more is True + assert result.cursor == "next_cursor_abc123" + + +class TestBlobIterObjects: + """Test blob iteration operations.""" + + @respx.mock + def test_iter_objects_sync(self, mock_env_clear, mock_blob_list_response): + """Test synchronous blob iteration.""" + respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_list_response) + ) + + items = list(iter_objects(token="test_token")) + + assert len(items) == 2 + assert 
items[0].pathname == "file1.txt" + assert items[1].pathname == "file2.txt" + + @respx.mock + @pytest.mark.asyncio + async def test_iter_objects_async(self, mock_env_clear, mock_blob_list_response): + """Test asynchronous blob iteration.""" + respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_list_response) + ) + + items = [] + async for item in iter_objects_async(token="test_token"): + items.append(item) + + assert len(items) == 2 + + @respx.mock + def test_iter_objects_with_limit(self, mock_env_clear, mock_blob_list_response): + """Test iteration with limit.""" + respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_list_response) + ) + + items = list(iter_objects(limit=1, token="test_token")) + + assert len(items) == 1 + + @respx.mock + def test_iter_objects_pagination( + self, mock_env_clear, mock_blob_list_response_paginated, mock_blob_list_response + ): + """Test iteration across multiple pages.""" + # First call returns paginated response + respx.get(BLOB_API_BASE).mock( + side_effect=[ + httpx.Response(200, json=mock_blob_list_response_paginated), + httpx.Response(200, json=mock_blob_list_response), + ] + ) + + items = list(iter_objects(token="test_token")) + + # Should have items from both pages (1 + 2 = 3) + assert len(items) == 3 + + @respx.mock + def test_iter_objects_sync_batch_size_limit_pagination(self, mock_env_clear): + """Test sync iteration uses paginated list requests with limit-aware batching.""" + first_page = { + "blobs": [ + { + "url": "https://blob.vercel-storage.com/test-abc123/page1.txt", + "downloadUrl": "https://blob.vercel-storage.com/test-abc123/page1.txt?download=1", + "pathname": "page1.txt", + "contentType": "text/plain", + "contentDisposition": 'inline; filename="page1.txt"', + "size": 50, + "uploadedAt": "2024-01-15T10:30:00.000Z", + } + ], + "cursor": "cursor-1", + "hasMore": True, + "folders": [], + } + second_page = { + "blobs": [ + { + "url": 
"https://blob.vercel-storage.com/test-abc123/page2.txt", + "downloadUrl": "https://blob.vercel-storage.com/test-abc123/page2.txt?download=1", + "pathname": "page2.txt", + "contentType": "text/plain", + "contentDisposition": 'inline; filename="page2.txt"', + "size": 60, + "uploadedAt": "2024-01-15T10:31:00.000Z", + } + ], + "cursor": "cursor-2", + "hasMore": True, + "folders": [], + } + route = respx.get(BLOB_API_BASE).mock( + side_effect=[ + httpx.Response(200, json=first_page), + httpx.Response(200, json=second_page), + ] + ) + + items = list(iter_objects(batch_size=1, limit=2, token="test_token")) + + assert [item.pathname for item in items] == ["page1.txt", "page2.txt"] + assert route.call_count == 2 + assert route.calls[0].request.url.params.get("limit") == "1" + assert route.calls[0].request.url.params.get("cursor") is None + assert route.calls[1].request.url.params.get("limit") == "1" + assert route.calls[1].request.url.params.get("cursor") == "cursor-1" + + @respx.mock + @pytest.mark.asyncio + async def test_iter_objects_async_batch_size_limit_pagination(self, mock_env_clear): + """Test async iteration uses paginated list requests with limit-aware batching.""" + first_page = { + "blobs": [ + { + "url": "https://blob.vercel-storage.com/test-abc123/page1.txt", + "downloadUrl": "https://blob.vercel-storage.com/test-abc123/page1.txt?download=1", + "pathname": "page1.txt", + "contentType": "text/plain", + "contentDisposition": 'inline; filename="page1.txt"', + "size": 50, + "uploadedAt": "2024-01-15T10:30:00.000Z", + } + ], + "cursor": "cursor-1", + "hasMore": True, + "folders": [], + } + second_page = { + "blobs": [ + { + "url": "https://blob.vercel-storage.com/test-abc123/page2.txt", + "downloadUrl": "https://blob.vercel-storage.com/test-abc123/page2.txt?download=1", + "pathname": "page2.txt", + "contentType": "text/plain", + "contentDisposition": 'inline; filename="page2.txt"', + "size": 60, + "uploadedAt": "2024-01-15T10:31:00.000Z", + } + ], + "cursor": 
"cursor-2", + "hasMore": True, + "folders": [], + } + route = respx.get(BLOB_API_BASE).mock( + side_effect=[ + httpx.Response(200, json=first_page), + httpx.Response(200, json=second_page), + ] + ) + + items = [] + async for item in iter_objects_async(batch_size=1, limit=2, token="test_token"): + items.append(item) + + assert [item.pathname for item in items] == ["page1.txt", "page2.txt"] + assert route.call_count == 2 + assert route.calls[0].request.url.params.get("limit") == "1" + assert route.calls[0].request.url.params.get("cursor") is None + assert route.calls[1].request.url.params.get("limit") == "1" + assert route.calls[1].request.url.params.get("cursor") == "cursor-1" + + +class TestBlobCopy: + """Test blob copy operations.""" + + @respx.mock + def test_copy_sync(self, mock_env_clear, mock_blob_copy_response): + """Test synchronous blob copy.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_copy_response) + ) + + result = copy( + "https://blob.vercel-storage.com/source.txt", + "copied.txt", + token="test_token", + ) + + assert route.called + assert result.pathname == "copied.txt" + + # Verify fromUrl parameter was sent + request = route.calls.last.request + assert "fromUrl" in str(request.url) + + @respx.mock + @pytest.mark.asyncio + async def test_copy_async(self, mock_env_clear, mock_blob_copy_response): + """Test asynchronous blob copy.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_copy_response) + ) + + result = await copy_async( + "https://blob.vercel-storage.com/source.txt", + "copied.txt", + token="test_token", + ) + + assert route.called + assert result.pathname == "copied.txt" + + +class TestBlobUploadFile: + """Test upload_file helpers.""" + + @respx.mock + def test_upload_file_sync(self, mock_env_clear, mock_blob_put_response, tmp_path): + """Test synchronous upload_file helper.""" + route = respx.put(BLOB_API_BASE).mock( + 
return_value=httpx.Response(200, json=mock_blob_put_response) + ) + local_file = tmp_path / "upload-sync.txt" + local_file.write_bytes(b"sync file upload") + + result = upload_file(local_file, "test.txt", token="test_token") + + assert route.called + assert result.url == mock_blob_put_response["url"] + + @respx.mock + @pytest.mark.asyncio + async def test_upload_file_async(self, mock_env_clear, mock_blob_put_response, tmp_path): + """Test asynchronous upload_file helper.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_put_response) + ) + local_file = tmp_path / "upload-async.txt" + local_file.write_bytes(b"async file upload") + + result = await upload_file_async(local_file, "test.txt", token="test_token") + + assert route.called + assert result.url == mock_blob_put_response["url"] + + +class TestBlobCreateFolder: + """Test blob folder creation.""" + + @respx.mock + def test_create_folder_sync(self, mock_env_clear, mock_blob_create_folder_response): + """Test synchronous folder creation.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_create_folder_response) + ) + + result = create_folder("my-folder", token="test_token") + + assert route.called + assert result.pathname == "my-folder/" + + # Verify pathname ends with / + request = route.calls.last.request + assert "pathname=my-folder%2F" in str(request.url) or "pathname=my-folder/" in str( + request.url + ) + + @respx.mock + @pytest.mark.asyncio + async def test_create_folder_async(self, mock_env_clear, mock_blob_create_folder_response): + """Test asynchronous folder creation.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_create_folder_response) + ) + + result = await create_folder_async("my-folder", token="test_token") + + assert route.called + assert result.pathname == "my-folder/" + + +class TestBlobClient: + """Test BlobClient and AsyncBlobClient classes.""" + + @respx.mock + 
def test_blob_client_put(self, mock_env_clear, mock_blob_put_response): + """Test BlobClient put method.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_put_response) + ) + + client = BlobClient(token="test_token") + result = client.put("test.txt", b"Hello, World!") + + assert route.called + assert result.url == mock_blob_put_response["url"] + + @respx.mock + def test_blob_client_head(self, mock_env_clear, mock_blob_head_response): + """Test BlobClient head method.""" + route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_head_response) + ) + + client = BlobClient(token="test_token") + result = client.head("https://blob.vercel-storage.com/test.txt") + + assert route.called + assert result.size == 13 + + @respx.mock + def test_blob_client_delete(self, mock_env_clear): + """Test BlobClient delete method.""" + route = respx.post(f"{BLOB_API_BASE}/delete").mock( + return_value=httpx.Response(200, json={}) + ) + + client = BlobClient(token="test_token") + client.delete("https://blob.vercel-storage.com/test.txt") + + assert route.called + + @respx.mock + def test_blob_client_list_objects(self, mock_env_clear, mock_blob_list_response): + """Test BlobClient list_objects method.""" + route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_list_response) + ) + + client = BlobClient(token="test_token") + result = client.list_objects() + + assert route.called + assert len(result.blobs) == 2 + + @respx.mock + def test_blob_client_get(self, mock_env_clear): + """Test BlobClient get method.""" + route = respx.get("https://blob.vercel-storage.com/test.txt").mock( + return_value=httpx.Response(200, content=b"blob data") + ) + + client = BlobClient(token="test_token") + result = client.get("https://blob.vercel-storage.com/test.txt") + + assert route.called + assert result.content == b"blob data" + + @respx.mock + def test_blob_client_iter_objects(self, mock_env_clear, 
mock_blob_list_response): + """Test BlobClient iter_objects method.""" + route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_list_response) + ) + + client = BlobClient(token="test_token") + items = list(client.iter_objects()) + + assert route.called + assert len(items) == 2 + + @respx.mock + def test_blob_client_copy(self, mock_env_clear, mock_blob_copy_response): + """Test BlobClient copy method.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_copy_response) + ) + + client = BlobClient(token="test_token") + result = client.copy("https://blob.vercel-storage.com/source.txt", "copied.txt") + + assert route.called + assert result.pathname == "copied.txt" + + @respx.mock + def test_blob_client_create_folder(self, mock_env_clear, mock_blob_create_folder_response): + """Test BlobClient create_folder method.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_create_folder_response) + ) + + client = BlobClient(token="test_token") + result = client.create_folder("my-folder") + + assert route.called + assert result.pathname == "my-folder/" + + @respx.mock + def test_blob_client_download_file(self, mock_env_clear, tmp_path): + """Test BlobClient download_file method.""" + route = respx.get("https://blob.vercel-storage.com/download.txt").mock( + return_value=httpx.Response(200, content=b"downloaded") + ) + local_path = tmp_path / "downloaded.txt" + + client = BlobClient(token="test_token") + result = client.download_file("https://blob.vercel-storage.com/download.txt", local_path) + + assert route.called + assert result == str(local_path) + assert local_path.read_bytes() == b"downloaded" + + @respx.mock + def test_blob_client_upload_file(self, mock_env_clear, mock_blob_put_response, tmp_path): + """Test BlobClient upload_file method.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_put_response) + ) + 
local_file = tmp_path / "client-upload.txt" + local_file.write_bytes(b"client upload") + + client = BlobClient(token="test_token") + result = client.upload_file(local_file, "test.txt") + + assert route.called + assert result.url == mock_blob_put_response["url"] + + @respx.mock + def test_blob_client_create_multipart_uploader(self, mock_env_clear): + """Test BlobClient create_multipart_uploader method.""" + import json + + completed_parts: list[dict[str, str | int]] = [] + + def mpu_handler(request: httpx.Request) -> httpx.Response: + action = request.headers["x-mpu-action"] + if action == "create": + return httpx.Response(200, json={"uploadId": "upload-id", "key": "blob-key"}) + if action == "upload": + return httpx.Response(200, json={"etag": "etag-1"}) + if action == "complete": + completed_parts.extend(json.loads(request.content.decode())) + return httpx.Response( + 200, + json={ + "url": "https://blob.vercel-storage.com/test-abc123/folder/client-mpu.bin", + "downloadUrl": ( + "https://blob.vercel-storage.com/" + "test-abc123/folder/client-mpu.bin?download=1" + ), + "pathname": "folder/client-mpu.bin", + "contentType": "application/octet-stream", + "contentDisposition": 'inline; filename="client-mpu.bin"', + }, + ) + raise AssertionError(f"unexpected multipart action: {action}") + + route = respx.post(f"{BLOB_API_BASE}/mpu").mock(side_effect=mpu_handler) + client = BlobClient(token="test_token") + + uploader = client.create_multipart_uploader("folder/client-mpu.bin") + part = uploader.upload_part(1, b"chunk") + result = uploader.complete([part]) + + assert route.call_count == 3 + assert [part["partNumber"] for part in completed_parts] == [1] + assert result.pathname == "folder/client-mpu.bin" + + @respx.mock + @pytest.mark.asyncio + async def test_async_blob_client_put(self, mock_env_clear, mock_blob_put_response): + """Test AsyncBlobClient put method.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_put_response) + ) 
+ + client = AsyncBlobClient(token="test_token") + result = await client.put("test.txt", b"Hello, World!") + + assert route.called + assert result.url == mock_blob_put_response["url"] + + @respx.mock + @pytest.mark.asyncio + async def test_async_blob_client_head(self, mock_env_clear, mock_blob_head_response): + """Test AsyncBlobClient head method.""" + route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_head_response) + ) + + client = AsyncBlobClient(token="test_token") + result = await client.head("https://blob.vercel-storage.com/test.txt") + + assert route.called + assert result.size == 13 + + @respx.mock + @pytest.mark.asyncio + async def test_async_blob_client_get(self, mock_env_clear): + """Test AsyncBlobClient get method.""" + route = respx.get("https://blob.vercel-storage.com/test.txt").mock( + return_value=httpx.Response(200, content=b"blob data") + ) + + client = AsyncBlobClient(token="test_token") + result = await client.get("https://blob.vercel-storage.com/test.txt") + + assert route.called + assert result.content == b"blob data" + + @respx.mock + @pytest.mark.asyncio + async def test_async_blob_client_list_objects(self, mock_env_clear, mock_blob_list_response): + """Test AsyncBlobClient list_objects method.""" + route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_list_response) + ) + + client = AsyncBlobClient(token="test_token") + result = await client.list_objects() + + assert route.called + assert len(result.blobs) == 2 + + @respx.mock + @pytest.mark.asyncio + async def test_async_blob_client_iter_objects(self, mock_env_clear, mock_blob_list_response): + """Test AsyncBlobClient iter_objects method.""" + route = respx.get(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_list_response) + ) + + client = AsyncBlobClient(token="test_token") + iterator = await client.iter_objects() + items = [item async for item in iterator] + + assert route.called + assert 
len(items) == 2 + + @respx.mock + @pytest.mark.asyncio + async def test_async_blob_client_delete(self, mock_env_clear): + """Test AsyncBlobClient delete method.""" + route = respx.post(f"{BLOB_API_BASE}/delete").mock( + return_value=httpx.Response(200, json={}) + ) + + client = AsyncBlobClient(token="test_token") + await client.delete("https://blob.vercel-storage.com/test.txt") + + assert route.called + + @respx.mock + @pytest.mark.asyncio + async def test_async_blob_client_copy(self, mock_env_clear, mock_blob_copy_response): + """Test AsyncBlobClient copy method.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_copy_response) + ) + + client = AsyncBlobClient(token="test_token") + result = await client.copy("https://blob.vercel-storage.com/source.txt", "copied.txt") + + assert route.called + assert result.pathname == "copied.txt" + + @respx.mock + @pytest.mark.asyncio + async def test_async_blob_client_create_folder( + self, mock_env_clear, mock_blob_create_folder_response + ): + """Test AsyncBlobClient create_folder method.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_create_folder_response) + ) + + client = AsyncBlobClient(token="test_token") + result = await client.create_folder("my-folder") + + assert route.called + assert result.pathname == "my-folder/" + + @respx.mock + @pytest.mark.asyncio + async def test_async_blob_client_download_file(self, mock_env_clear, tmp_path): + """Test AsyncBlobClient download_file method.""" + route = respx.get("https://blob.vercel-storage.com/download.txt").mock( + return_value=httpx.Response(200, content=b"downloaded") + ) + local_path = tmp_path / "downloaded-async.txt" + + client = AsyncBlobClient(token="test_token") + result = await client.download_file( + "https://blob.vercel-storage.com/download.txt", local_path + ) + + assert route.called + assert result == str(local_path) + assert local_path.read_bytes() == b"downloaded" + + 
@respx.mock + @pytest.mark.asyncio + async def test_async_blob_client_upload_file( + self, mock_env_clear, mock_blob_put_response, tmp_path + ): + """Test AsyncBlobClient upload_file method.""" + route = respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response(200, json=mock_blob_put_response) + ) + local_file = tmp_path / "client-upload-async.txt" + local_file.write_bytes(b"async client upload") + + client = AsyncBlobClient(token="test_token") + result = await client.upload_file(local_file, "test.txt") + + assert route.called + assert result.url == mock_blob_put_response["url"] + + @respx.mock + @pytest.mark.asyncio + async def test_async_blob_client_create_multipart_uploader(self, mock_env_clear): + """Test AsyncBlobClient create_multipart_uploader method.""" + import json + + completed_parts: list[dict[str, str | int]] = [] + + def mpu_handler(request: httpx.Request) -> httpx.Response: + action = request.headers["x-mpu-action"] + if action == "create": + return httpx.Response(200, json={"uploadId": "upload-id", "key": "blob-key"}) + if action == "upload": + return httpx.Response(200, json={"etag": "etag-1"}) + if action == "complete": + completed_parts.extend(json.loads(request.content.decode())) + return httpx.Response( + 200, + json={ + "url": ( + "https://blob.vercel-storage.com/test-abc123/folder/" + "client-mpu-async.bin" + ), + "downloadUrl": ( + "https://blob.vercel-storage.com/" + "test-abc123/folder/client-mpu-async.bin?download=1" + ), + "pathname": "folder/client-mpu-async.bin", + "contentType": "application/octet-stream", + "contentDisposition": 'inline; filename="client-mpu-async.bin"', + }, + ) + raise AssertionError(f"unexpected multipart action: {action}") + + route = respx.post(f"{BLOB_API_BASE}/mpu").mock(side_effect=mpu_handler) + client = AsyncBlobClient(token="test_token") + + uploader = await client.create_multipart_uploader("folder/client-mpu-async.bin") + part = await uploader.upload_part(1, b"chunk") + result = await 
uploader.complete([part]) + + assert route.call_count == 3 + assert [part["partNumber"] for part in completed_parts] == [1] + assert result.pathname == "folder/client-mpu-async.bin" + + +class TestBlobErrorHandling: + """Test error handling for blob operations.""" + + @respx.mock + def test_missing_token_raises_error(self, mock_env_clear): + """Test that missing token raises BlobError.""" + from vercel.blob.errors import BlobError + + # Don't mock any routes - we expect failure before HTTP call + with pytest.raises(BlobError): + put("test.txt", b"data") + + @respx.mock + def test_not_found_error(self, mock_env_clear, mock_error_not_found): + """Test BlobNotFoundError is raised on 404.""" + respx.get(BLOB_API_BASE).mock(return_value=httpx.Response(404, json=mock_error_not_found)) + + with pytest.raises(BlobNotFoundError): + head("https://blob.vercel-storage.com/missing.txt", token="test_token") + + @respx.mock + def test_access_error(self, mock_env_clear, mock_error_forbidden): + """Test BlobAccessError is raised on 403.""" + from vercel.blob import BlobAccessError + + respx.get(BLOB_API_BASE).mock(return_value=httpx.Response(403, json=mock_error_forbidden)) + + with pytest.raises(BlobAccessError): + head("https://blob.vercel-storage.com/forbidden.txt", token="test_token") + + @respx.mock + @pytest.mark.asyncio + async def test_not_found_error_async(self, mock_env_clear, mock_error_not_found): + """Test BlobNotFoundError is raised on 404 for async.""" + respx.get(BLOB_API_BASE).mock(return_value=httpx.Response(404, json=mock_error_not_found)) + + with pytest.raises(BlobNotFoundError): + await head_async("https://blob.vercel-storage.com/missing.txt", token="test_token") + + @respx.mock + def test_put_invalid_json_raises_blob_error(self, mock_env_clear): + """Test malformed 2xx JSON response raises a dedicated parse error.""" + from vercel.blob.errors import BlobInvalidResponseJSONError + + respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response( + 200, + 
headers={"content-type": "application/json"}, + content=b'{"url":', + ) + ) + + with pytest.raises(BlobInvalidResponseJSONError, match=r"parse JSON response body"): + put("test.txt", b"data", token="test_token") + + @respx.mock + def test_put_unexpected_content_type_raises_blob_error(self, mock_env_clear): + """Test non-JSON content type on 2xx response raises content-type error.""" + from vercel.blob.errors import BlobUnexpectedResponseContentTypeError + + respx.put(BLOB_API_BASE).mock( + return_value=httpx.Response( + 200, + headers={"content-type": "text/plain"}, + content=b'{"url":"https://blob.vercel-storage.com/test.txt"}', + ) + ) + + with pytest.raises( + BlobUnexpectedResponseContentTypeError, + match=r"Unexpected response content type: text/plain", + ): + put("test.txt", b"data", token="test_token") + + def test_decode_blob_response_json_accepts_non_object_json(self): + """Test strict JSON decode mode accepts valid non-object JSON values.""" + response = httpx.Response( + 200, + headers={"content-type": "application/json"}, + content=b"[1, 2, 3]", + ) + assert decode_blob_response_json(response) == [1, 2, 3] + + +class TestBlobPublicHelpers: + """Test public non-network blob helpers.""" + + def test_get_download_url_preserves_existing_query(self): + """Test get_download_url adds download=1 while preserving existing query params.""" + url = "https://blob.vercel-storage.com/test-abc123/file.txt?foo=bar" + download_url = get_download_url(url) + + assert "foo=bar" in download_url + assert "download=1" in download_url + + def test_aioblob_module_alias_exports_async_api(self): + """Test aioblob alias exposes async API module.""" + assert hasattr(aioblob, "put") + assert hasattr(aioblob, "create_multipart_upload") diff --git a/tests/integration/test_cache_sync_async.py b/tests/integration/test_cache_sync_async.py new file mode 100644 index 0000000..794a88a --- /dev/null +++ b/tests/integration/test_cache_sync_async.py @@ -0,0 +1,365 @@ +"""Integration tests for 
Vercel Cache API. + +Tests RuntimeCache and AsyncRuntimeCache using the in-memory fallback +when cache environment variables are not set. +""" + +import pytest + + +class TestRuntimeCacheInMemory: + """Test RuntimeCache with in-memory fallback.""" + + def test_get_cache_returns_runtime_cache(self, mock_env_clear): + """Test get_cache returns a RuntimeCache instance.""" + from vercel.cache import RuntimeCache, get_cache + + cache = get_cache() + assert isinstance(cache, RuntimeCache) + + def test_get_set_delete_sync(self, mock_env_clear): + """Test basic get/set/delete operations (sync).""" + from vercel.cache import get_cache + + cache = get_cache() + + # Set a value + cache.set("test_key", "test_value") + + # Get the value + result = cache.get("test_key") + assert result == "test_value" + + # Delete the value + cache.delete("test_key") + + # Get should return None after delete + result = cache.get("test_key") + assert result is None + + def test_set_with_dict_value(self, mock_env_clear): + """Test setting dict values.""" + from vercel.cache import get_cache + + cache = get_cache() + + test_data = {"name": "test", "count": 42, "nested": {"key": "value"}} + cache.set("dict_key", test_data) + + result = cache.get("dict_key") + assert result == test_data + + cache.delete("dict_key") + + def test_set_with_list_value(self, mock_env_clear): + """Test setting list values.""" + from vercel.cache import get_cache + + cache = get_cache() + + test_list = [1, 2, 3, "four", {"five": 5}] + cache.set("list_key", test_list) + + result = cache.get("list_key") + assert result == test_list + + cache.delete("list_key") + + def test_contains_operator(self, mock_env_clear): + """Test __contains__ for 'in' operator.""" + from vercel.cache import get_cache + + cache = get_cache() + + cache.set("exists_key", "value") + + assert "exists_key" in cache + assert "nonexistent_key" not in cache + + cache.delete("exists_key") + + def test_getitem_operator(self, mock_env_clear): + """Test 
__getitem__ for bracket notation.""" + from vercel.cache import get_cache + + cache = get_cache() + + cache.set("bracket_key", "bracket_value") + + assert cache["bracket_key"] == "bracket_value" + + # Should raise KeyError for missing keys + with pytest.raises(KeyError): + _ = cache["missing_key"] + + cache.delete("bracket_key") + + def test_namespace_option(self, mock_env_clear): + """Test namespace option for key prefixing.""" + from vercel.cache import get_cache + + cache1 = get_cache(namespace="ns1") + cache2 = get_cache(namespace="ns2") + + # Set same key in different namespaces + cache1.set("shared_key", "value1") + cache2.set("shared_key", "value2") + + # Each should get their own value + assert cache1.get("shared_key") == "value1" + assert cache2.get("shared_key") == "value2" + + # Cleanup + cache1.delete("shared_key") + cache2.delete("shared_key") + + def test_namespace_separator_option(self, mock_env_clear): + """Test custom namespace separator.""" + from vercel.cache import get_cache + + cache = get_cache(namespace="myns", namespace_separator="::") + + cache.set("key", "value") + # The key transformation should use the custom separator + # The value should still be retrievable + assert cache.get("key") == "value" + + cache.delete("key") + + +class TestAsyncRuntimeCacheInMemory: + """Test AsyncRuntimeCache with in-memory fallback.""" + + @pytest.mark.asyncio + async def test_get_set_delete_async(self, mock_env_clear): + """Test basic get/set/delete operations (async).""" + from vercel.cache import AsyncRuntimeCache + + cache = AsyncRuntimeCache() + + # Set a value + await cache.set("async_key", "async_value") + + # Get the value + result = await cache.get("async_key") + assert result == "async_value" + + # Delete the value + await cache.delete("async_key") + + # Get should return None after delete + result = await cache.get("async_key") + assert result is None + + @pytest.mark.asyncio + async def test_async_dict_value(self, mock_env_clear): + """Test 
async setting dict values.""" + from vercel.cache import AsyncRuntimeCache + + cache = AsyncRuntimeCache() + + test_data = {"async": True, "data": [1, 2, 3]} + await cache.set("async_dict", test_data) + + result = await cache.get("async_dict") + assert result == test_data + + await cache.delete("async_dict") + + @pytest.mark.asyncio + async def test_async_contains(self, mock_env_clear): + """Test async contains method.""" + from vercel.cache import AsyncRuntimeCache + + cache = AsyncRuntimeCache() + + await cache.set("async_exists", "value") + + assert await cache.contains("async_exists") is True + assert await cache.contains("async_missing") is False + + await cache.delete("async_exists") + + @pytest.mark.asyncio + async def test_async_namespace(self, mock_env_clear): + """Test async cache with namespace.""" + from vercel.cache import AsyncRuntimeCache + + cache1 = AsyncRuntimeCache(namespace="async_ns1") + cache2 = AsyncRuntimeCache(namespace="async_ns2") + + await cache1.set("key", "value1") + await cache2.set("key", "value2") + + assert await cache1.get("key") == "value1" + assert await cache2.get("key") == "value2" + + await cache1.delete("key") + await cache2.delete("key") + + +class TestSyncAsyncParity: + """Test that sync and async caches produce consistent results.""" + + @pytest.mark.asyncio + async def test_sync_async_share_in_memory_store(self, mock_env_clear): + """Test that sync and async caches share the same in-memory store.""" + from vercel.cache import AsyncRuntimeCache, get_cache + + sync_cache = get_cache() + async_cache = AsyncRuntimeCache() + + # Set with sync + sync_cache.set("shared_store_key", "sync_value") + + # Get with async - should see the value + result = await async_cache.get("shared_store_key") + assert result == "sync_value" + + # Delete with async + await async_cache.delete("shared_store_key") + + # Sync should see it's gone + assert sync_cache.get("shared_store_key") is None + + +class TestCacheTagOperations: + """Test cache 
tag-based operations.""" + + def test_expire_tag_sync(self, mock_env_clear): + """Test expiring cache entries by tag (sync).""" + from vercel.cache import get_cache + + cache = get_cache() + + # Set values with tags + cache.set("tagged1", "value1", {"tags": ["tag1", "tag2"]}) + cache.set("tagged2", "value2", {"tags": ["tag1"]}) + cache.set("untagged", "value3") # No tag1 + + # Verify entries exist before expiring + assert cache.get("tagged1") == "value1" + assert cache.get("tagged2") == "value2" + assert cache.get("untagged") == "value3" + + # Expire by tag - should remove entries with tag1 + cache.expire_tag("tag1") + + # Entries with tag1 should be expired + assert cache.get("tagged1") is None + assert cache.get("tagged2") is None + # Entry without tag1 should still exist + assert cache.get("untagged") == "value3" + + def test_expire_tag_list(self, mock_env_clear): + """Test expiring cache entries by multiple tags.""" + from vercel.cache import get_cache + + cache = get_cache() + + cache.set("tagged_a", "value_a", {"tags": ["tag_a"]}) + cache.set("tagged_b", "value_b", {"tags": ["tag_b"]}) + cache.set("tagged_both", "value_both", {"tags": ["tag_a", "tag_b"]}) + + # Verify entries exist + assert cache.get("tagged_a") == "value_a" + assert cache.get("tagged_b") == "value_b" + assert cache.get("tagged_both") == "value_both" + + # Expire by multiple tags + cache.expire_tag(["tag_a", "tag_b"]) + + # All tagged entries should be expired + assert cache.get("tagged_a") is None + assert cache.get("tagged_b") is None + assert cache.get("tagged_both") is None + + @pytest.mark.asyncio + async def test_expire_tag_async(self, mock_env_clear): + """Test expiring cache entries by tag (async).""" + from vercel.cache import AsyncRuntimeCache + + cache = AsyncRuntimeCache() + + await cache.set("async_tagged", "value", {"tags": ["async_tag"]}) + await cache.set("async_untagged", "other_value") + + # Verify entry exists + assert await cache.get("async_tagged") == "value" + assert 
await cache.get("async_untagged") == "other_value" + + await cache.expire_tag("async_tag") + + # Tagged entry should be expired + assert await cache.get("async_tagged") is None + # Untagged entry should still exist + assert await cache.get("async_untagged") == "other_value" + + +class TestCacheWithOptions: + """Test cache operations with options.""" + + def test_set_with_ttl_option(self, mock_env_clear): + """Test setting cache with TTL option.""" + from vercel.cache import get_cache + + cache = get_cache() + + # Set with TTL (time to live) + cache.set("ttl_key", "ttl_value", {"ttl": 60}) # 60 seconds + + # Value should be retrievable immediately + result = cache.get("ttl_key") + assert result == "ttl_value" + + cache.delete("ttl_key") + + def test_set_with_tags_option(self, mock_env_clear): + """Test setting cache with tags option.""" + from vercel.cache import get_cache + + cache = get_cache() + + # Set with tags + cache.set("tags_key", "tags_value", {"tags": ["category:test", "type:demo"]}) + + result = cache.get("tags_key") + assert result == "tags_value" + + cache.delete("tags_key") + + +class TestCacheKeyTransformation: + """Test cache key transformation with hash function.""" + + def test_custom_hash_function(self, mock_env_clear): + """Test custom key hash function.""" + import hashlib + + from vercel.cache import get_cache + + def custom_hash(key: str) -> str: + return hashlib.md5(key.encode()).hexdigest() + + cache = get_cache(key_hash_function=custom_hash) + + cache.set("original_key", "value") + result = cache.get("original_key") + assert result == "value" + + cache.delete("original_key") + + def test_namespace_with_hash_function(self, mock_env_clear): + """Test namespace combined with hash function.""" + from vercel.cache import get_cache + + def simple_hash(key: str) -> str: + return f"hashed_{key}" + + cache = get_cache(namespace="ns", key_hash_function=simple_hash) + + cache.set("key", "value") + result = cache.get("key") + assert result == "value" 
+ + cache.delete("key") diff --git a/tests/integration/test_functions_sync_async.py b/tests/integration/test_functions_sync_async.py new file mode 100644 index 0000000..daae1cf --- /dev/null +++ b/tests/integration/test_functions_sync_async.py @@ -0,0 +1,339 @@ +"""Integration tests for Vercel Functions module. + +Tests environment access, IP/geo extraction, and header context management. +""" + +import pytest + + +class TestGetEnv: + """Test get_env and Env dataclass.""" + + def test_get_env_from_os_environ(self, mock_env_clear, monkeypatch): + """Test get_env reads from os.environ.""" + from vercel.functions import Env, get_env + + monkeypatch.setenv("VERCEL", "1") + monkeypatch.setenv("VERCEL_ENV", "production") + monkeypatch.setenv("VERCEL_URL", "my-app.vercel.app") + monkeypatch.setenv("VERCEL_REGION", "iad1") + + env = get_env() + + assert isinstance(env, Env) + assert env.VERCEL == "1" + assert env.VERCEL_ENV == "production" + assert env.VERCEL_URL == "my-app.vercel.app" + assert env.VERCEL_REGION == "iad1" + + def test_get_env_from_custom_mapping(self, mock_env_clear): + """Test get_env reads from custom mapping.""" + from vercel.functions import get_env + + custom_env = { + "VERCEL": "1", + "VERCEL_ENV": "preview", + "VERCEL_DEPLOYMENT_ID": "dpl_test123", + } + + env = get_env(custom_env) + + assert env.VERCEL == "1" + assert env.VERCEL_ENV == "preview" + assert env.VERCEL_DEPLOYMENT_ID == "dpl_test123" + + def test_get_env_normalizes_empty_strings(self, mock_env_clear, monkeypatch): + """Test that empty strings are normalized to None.""" + from vercel.functions import get_env + + monkeypatch.setenv("VERCEL", "1") + monkeypatch.setenv("VERCEL_URL", "") # Empty string + + env = get_env() + + assert env.VERCEL == "1" + assert env.VERCEL_URL is None + + def test_env_to_dict(self, mock_env_clear, monkeypatch): + """Test Env.to_dict method.""" + from vercel.functions import get_env + + monkeypatch.setenv("VERCEL", "1") + monkeypatch.setenv("CI", "true") + + env 
= get_env() + env_dict = env.to_dict() + + assert isinstance(env_dict, dict) + assert env_dict["VERCEL"] == "1" + assert env_dict["CI"] == "true" + + def test_env_getitem(self, mock_env_clear, monkeypatch): + """Test Env bracket notation access.""" + from vercel.functions import get_env + + monkeypatch.setenv("VERCEL_ENV", "development") + + env = get_env() + + assert env["VERCEL_ENV"] == "development" + + with pytest.raises(KeyError): + _ = env["NONEXISTENT_KEY"] + + def test_env_get_with_default(self, mock_env_clear, monkeypatch): + """Test Env.get method with default.""" + from vercel.functions import get_env + + monkeypatch.setenv("VERCEL", "1") + + env = get_env() + + assert env.get("VERCEL") == "1" + assert env.get("NONEXISTENT", "default") == "default" + assert env.get("VERCEL_URL") is None + + def test_env_git_fields(self, mock_env_clear, monkeypatch): + """Test Git-related environment fields.""" + from vercel.functions import get_env + + monkeypatch.setenv("VERCEL_GIT_PROVIDER", "github") + monkeypatch.setenv("VERCEL_GIT_REPO_SLUG", "my-repo") + monkeypatch.setenv("VERCEL_GIT_REPO_OWNER", "my-org") + monkeypatch.setenv("VERCEL_GIT_COMMIT_REF", "main") + monkeypatch.setenv("VERCEL_GIT_COMMIT_SHA", "abc123") + + env = get_env() + + assert env.VERCEL_GIT_PROVIDER == "github" + assert env.VERCEL_GIT_REPO_SLUG == "my-repo" + assert env.VERCEL_GIT_REPO_OWNER == "my-org" + assert env.VERCEL_GIT_COMMIT_REF == "main" + assert env.VERCEL_GIT_COMMIT_SHA == "abc123" + + +class TestIpAddress: + """Test ip_address function.""" + + def test_ip_address_from_request_object(self, mock_env_clear): + """Test extracting IP from request-like object.""" + from vercel.functions import ip_address + + class MockHeaders: + def get(self, name): + if name == "x-real-ip": + return "203.0.113.42" + return None + + class MockRequest: + headers = MockHeaders() + + ip = ip_address(MockRequest()) + assert ip == "203.0.113.42" + + def test_ip_address_from_headers_object(self, 
mock_env_clear): + """Test extracting IP from headers-like object.""" + from vercel.functions import ip_address + + class MockHeaders: + def get(self, name): + if name == "x-real-ip": + return "192.168.1.100" + return None + + ip = ip_address(MockHeaders()) + assert ip == "192.168.1.100" + + def test_ip_address_missing(self, mock_env_clear): + """Test IP is None when header missing.""" + from vercel.functions import ip_address + + class MockHeaders: + def get(self, name): + return None + + class MockRequest: + headers = MockHeaders() + + ip = ip_address(MockRequest()) + assert ip is None + + +class TestGeolocation: + """Test geolocation function.""" + + def test_geolocation_full(self, mock_env_clear): + """Test extracting full geolocation data.""" + from vercel.functions import geolocation + + headers_data = { + "x-vercel-ip-city": "San%20Francisco", # URL encoded + "x-vercel-ip-country": "US", + "x-vercel-ip-country-region": "CA", + "x-vercel-ip-latitude": "37.7749", + "x-vercel-ip-longitude": "-122.4194", + "x-vercel-ip-postal-code": "94103", + "x-vercel-id": "iad1::12345", + } + + class MockHeaders: + def get(self, name): + return headers_data.get(name) + + class MockRequest: + headers = MockHeaders() + + geo = geolocation(MockRequest()) + + assert geo["city"] == "San Francisco" # Decoded + assert geo["country"] == "US" + assert geo["countryRegion"] == "CA" + assert geo["latitude"] == "37.7749" + assert geo["longitude"] == "-122.4194" + assert geo["postalCode"] == "94103" + assert geo["region"] == "iad1" + + def test_geolocation_flag_generation(self, mock_env_clear): + """Test country flag emoji generation.""" + from vercel.functions import geolocation + + class MockHeaders: + def get(self, name): + if name == "x-vercel-ip-country": + return "US" + return None + + class MockRequest: + headers = MockHeaders() + + geo = geolocation(MockRequest()) + + # US flag emoji + assert geo["flag"] is not None + # The flag should be the US flag emoji (two regional indicator 
symbols) + + def test_geolocation_empty(self, mock_env_clear): + """Test geolocation with no headers.""" + from vercel.functions import geolocation + + class MockHeaders: + def get(self, name): + return None + + class MockRequest: + headers = MockHeaders() + + geo = geolocation(MockRequest()) + + assert geo["city"] is None + assert geo["country"] is None + assert geo["region"] == "dev1" # Default when no request ID + + def test_geolocation_region_from_request_id(self, mock_env_clear): + """Test region extraction from request ID.""" + from vercel.functions import geolocation + + class MockHeaders: + def get(self, name): + if name == "x-vercel-id": + return "sfo1::request-123" + return None + + class MockRequest: + headers = MockHeaders() + + geo = geolocation(MockRequest()) + assert geo["region"] == "sfo1" + + +class TestSetGetHeaders: + """Test header context management.""" + + def test_set_and_get_headers(self, mock_env_clear): + """Test setting and getting headers in context.""" + from vercel.functions import get_headers, set_headers + + # Initially None + assert get_headers() is None + + # Set headers + set_headers({"Content-Type": "application/json", "X-Custom": "value"}) + + headers = get_headers() + assert headers is not None + assert headers["Content-Type"] == "application/json" + assert headers["X-Custom"] == "value" + + # Clear headers + set_headers(None) + assert get_headers() is None + + def test_headers_overwrite(self, mock_env_clear): + """Test that set_headers overwrites previous values.""" + from vercel.functions import get_headers, set_headers + + set_headers({"First": "value1"}) + assert get_headers()["First"] == "value1" + + set_headers({"Second": "value2"}) + headers = get_headers() + assert headers.get("First") is None + assert headers["Second"] == "value2" + + set_headers(None) + + +class TestRuntimeCacheExport: + """Test that cache classes are properly exported from functions.""" + + def test_get_cache_export(self, mock_env_clear): + """Test 
get_cache is accessible from functions module.""" + from vercel.functions import RuntimeCache, get_cache + + cache = get_cache() + assert isinstance(cache, RuntimeCache) + + def test_async_runtime_cache_export(self, mock_env_clear): + """Test AsyncRuntimeCache is accessible from functions module.""" + from vercel.functions import AsyncRuntimeCache + + cache = AsyncRuntimeCache() + assert cache is not None + + +class TestEnvImmutability: + """Test Env dataclass immutability.""" + + def test_env_is_frozen(self, mock_env_clear, monkeypatch): + """Test that Env instances are immutable.""" + from vercel.functions import get_env + + monkeypatch.setenv("VERCEL", "1") + + env = get_env() + + with pytest.raises(AttributeError): # Frozen dataclass + env.VERCEL = "2" + + +class TestGeoTypedDict: + """Test Geo TypedDict structure.""" + + def test_geo_fields(self, mock_env_clear): + """Test Geo TypedDict has expected fields.""" + from vercel.functions import Geo + + # Create a Geo dict manually + geo: Geo = { + "city": "New York", + "country": "US", + "flag": None, + "region": "iad1", + "countryRegion": "NY", + "latitude": "40.7128", + "longitude": "-74.0060", + "postalCode": "10001", + } + + assert geo["city"] == "New York" + assert geo["country"] == "US" diff --git a/tests/integration/test_http_transport_raw_body.py b/tests/integration/test_http_transport_raw_body.py new file mode 100644 index 0000000..aef18f1 --- /dev/null +++ b/tests/integration/test_http_transport_raw_body.py @@ -0,0 +1,74 @@ +"""Tests for RawBody support in HTTP transports.""" + +import pytest +import respx +from httpx import Response + +from vercel._internal.http import ( + AsyncTransport, + RawBody, + SyncTransport, + create_base_async_client, + create_base_client, +) +from vercel._internal.iter_coroutine import iter_coroutine + + +class TestRawBodySupport: + """Test that RawBody content is passed through transport unchanged.""" + + @pytest.mark.parametrize("stream", [False, True], ids=["non_stream", 
"stream"]) + @respx.mock + def test_sync_raw_body_iterable(self, stream: bool): + """SyncTransport should forward iterable bodies without JSON encoding.""" + base_url = "https://upload.example.com" + expected = b"chunk-1chunk-2" + + def handler(request): + payload = b"".join(request.stream) + assert payload == expected + return Response(200, json={"ok": True}) + + route = respx.post(f"{base_url}/upload").mock(side_effect=handler) + + client = create_base_client(timeout=30.0, base_url=base_url) + transport = SyncTransport(client) + try: + body = RawBody(iter([b"chunk-1", b"chunk-2"])) + response = iter_coroutine(transport.send("POST", "/upload", body=body, stream=stream)) + assert response.status_code == 200 + assert response.read() == b'{"ok":true}' + assert route.called + finally: + transport.close() + + @pytest.mark.parametrize("stream", [False, True], ids=["non_stream", "stream"]) + @respx.mock + @pytest.mark.asyncio + async def test_async_raw_body_async_iterable(self, stream: bool): + """AsyncTransport should forward async iterable bodies without JSON encoding.""" + base_url = "https://upload.example.com" + expected = b"part-apart-b" + + async def chunks(): + yield b"part-a" + yield b"part-b" + + async def handler(request): + body = b"" + async for chunk in request.stream: + body += chunk + assert body == expected + return Response(200, json={"ok": True}) + + route = respx.post(f"{base_url}/upload").mock(side_effect=handler) + + client = create_base_async_client(timeout=30.0, base_url=base_url) + transport = AsyncTransport(client) + try: + response = await transport.send("POST", "upload", body=RawBody(chunks()), stream=stream) + assert response.status_code == 200 + assert await response.aread() == b'{"ok":true}' + assert route.called + finally: + await transport.aclose() diff --git a/tests/integration/test_http_url_construction.py b/tests/integration/test_http_url_construction.py new file mode 100644 index 0000000..1797d24 --- /dev/null +++ 
b/tests/integration/test_http_url_construction.py @@ -0,0 +1,260 @@ +"""Tests for HTTP URL construction using httpx base_url. + +These tests verify that the transport layer normalizes URLs consistently: +- base_url is always normalized to end with a trailing slash +- path is always normalized to not start with a leading slash +- This ensures consistent URL joining: final_url = base_url + path + +Users can pass base_url with or without trailing slash, and paths with or +without leading slash - the result will be the same. +""" + +import pytest +import respx +from httpx import Response + +from vercel._internal.http import ( + AsyncTransport, + JSONBody, + SyncTransport, + create_base_async_client, + create_base_client, +) +from vercel._internal.iter_coroutine import iter_coroutine + + +class TestUrlNormalization: + """Test that URL normalization produces consistent results regardless of input format.""" + + @pytest.mark.parametrize( + "base_url,path,expected_url", + [ + # All four combinations of trailing/leading slash produce the same result + ("https://api.example.com", "/v1/projects", "https://api.example.com/v1/projects"), + ("https://api.example.com/", "/v1/projects", "https://api.example.com/v1/projects"), + ("https://api.example.com", "v1/projects", "https://api.example.com/v1/projects"), + ("https://api.example.com/", "v1/projects", "https://api.example.com/v1/projects"), + # Same for paths with multiple segments + ("https://api.example.com", "/v1/cache/keys", "https://api.example.com/v1/cache/keys"), + ("https://api.example.com/", "v1/cache/keys", "https://api.example.com/v1/cache/keys"), + # Base URL with path segment - all variants work + ("https://api.example.com/v1", "/projects", "https://api.example.com/v1/projects"), + ("https://api.example.com/v1/", "/projects", "https://api.example.com/v1/projects"), + ("https://api.example.com/v1", "projects", "https://api.example.com/v1/projects"), + ("https://api.example.com/v1/", "projects", 
"https://api.example.com/v1/projects"), + ], + ids=[ + "no_trailing_with_leading", + "trailing_with_leading", + "no_trailing_no_leading", + "trailing_no_leading", + "multi_segment_no_trailing_with_leading", + "multi_segment_trailing_no_leading", + "base_with_path_no_trailing_with_leading", + "base_with_path_trailing_with_leading", + "base_with_path_no_trailing_no_leading", + "base_with_path_trailing_no_leading", + ], + ) + @respx.mock + def test_sync_normalization_consistency(self, base_url: str, path: str, expected_url: str): + """Test that SyncTransport normalizes URLs consistently.""" + route = respx.get(expected_url).mock(return_value=Response(200, json={"ok": True})) + + client = create_base_client(timeout=30.0, base_url=base_url) + transport = SyncTransport(client) + + try: + response = iter_coroutine(transport.send("GET", path)) + assert response.status_code == 200 + assert route.called + finally: + transport.close() + + @pytest.mark.parametrize( + "base_url,path,expected_url", + [ + # All four combinations produce the same result + ("https://api.example.com", "/v1/projects", "https://api.example.com/v1/projects"), + ("https://api.example.com/", "/v1/projects", "https://api.example.com/v1/projects"), + ("https://api.example.com", "v1/projects", "https://api.example.com/v1/projects"), + ("https://api.example.com/", "v1/projects", "https://api.example.com/v1/projects"), + ], + ids=[ + "no_trailing_with_leading", + "trailing_with_leading", + "no_trailing_no_leading", + "trailing_no_leading", + ], + ) + @respx.mock + @pytest.mark.asyncio + async def test_async_normalization_consistency( + self, base_url: str, path: str, expected_url: str + ): + """Test that AsyncTransport normalizes URLs consistently.""" + route = respx.get(expected_url).mock(return_value=Response(200, json={"ok": True})) + + client = create_base_async_client(timeout=30.0, base_url=base_url) + transport = AsyncTransport(client) + + try: + response = await transport.send("GET", path) + assert 
response.status_code == 200 + assert route.called + finally: + await transport.aclose() + + +class TestEdgeCases: + """Test edge cases for URL construction.""" + + @pytest.mark.parametrize( + "base_url,path,expected_url", + [ + # Empty path + ("https://api.example.com/", "", "https://api.example.com/"), + ("https://api.example.com", "", "https://api.example.com/"), + # Root path + ("https://api.example.com/", "/", "https://api.example.com/"), + ("https://api.example.com", "/", "https://api.example.com/"), + # Nested base URL + ( + "https://api.example.com/v1/cache/", + "my-key", + "https://api.example.com/v1/cache/my-key", + ), + ], + ids=[ + "empty_path_trailing_base", + "empty_path_no_trailing_base", + "root_path_trailing_base", + "root_path_no_trailing_base", + "nested_base_url", + ], + ) + @respx.mock + def test_edge_cases(self, base_url: str, path: str, expected_url: str): + """Test edge cases for URL construction.""" + route = respx.get(expected_url).mock(return_value=Response(200, json={"ok": True})) + + client = create_base_client(timeout=30.0, base_url=base_url) + transport = SyncTransport(client) + + try: + response = iter_coroutine(transport.send("GET", path)) + assert response.status_code == 200 + assert route.called + finally: + transport.close() + + +class TestCacheUrlPatterns: + """Test URL patterns specifically used by the cache module. + + The cache module uses paths without leading slashes (e.g., cache keys like + "my-key" and actions like "revalidate"). These tests verify that pattern + works correctly with base_url ending in a trailing slash. 
+ """ + + @respx.mock + def test_cache_get_key(self): + """Test GET request for a cache key.""" + base_url = "https://cache.example.com/v1/" + key = "user-123-profile" + expected = f"{base_url}{key}" + + route = respx.get(expected).mock(return_value=Response(200, json={"data": "cached"})) + + client = create_base_client(timeout=30.0, base_url=base_url) + transport = SyncTransport(client) + + try: + response = iter_coroutine(transport.send("GET", key)) + assert response.status_code == 200 + assert route.called + finally: + transport.close() + + @respx.mock + def test_cache_set_key(self): + """Test POST request to set a cache key.""" + base_url = "https://cache.example.com/v1/" + key = "user-456-settings" + expected = f"{base_url}{key}" + + route = respx.post(expected).mock(return_value=Response(200, json={"ok": True})) + + client = create_base_client(timeout=30.0, base_url=base_url) + transport = SyncTransport(client) + + try: + response = iter_coroutine(transport.send("POST", key, body=JSONBody({"value": "test"}))) + assert response.status_code == 200 + assert route.called + finally: + transport.close() + + @respx.mock + def test_cache_revalidate(self): + """Test POST request to revalidate endpoint.""" + base_url = "https://cache.example.com/v1/" + expected = f"{base_url}revalidate" + + route = respx.post(expected).mock(return_value=Response(200, json={"ok": True})) + + client = create_base_client(timeout=30.0, base_url=base_url) + transport = SyncTransport(client) + + try: + response = iter_coroutine(transport.send("POST", "revalidate", params={"tags": "foo"})) + assert response.status_code == 200 + assert route.called + finally: + transport.close() + + +class TestApiUrlPatterns: + """Test URL patterns used by the API clients (projects, deployments, etc.). + + The API clients use paths with leading slashes (e.g., "/v10/projects"). 
+ """ + + @respx.mock + def test_projects_list(self): + """Test GET request for projects list.""" + base_url = "https://api.vercel.com" + path = "/v10/projects" + expected = f"{base_url}{path}" + + route = respx.get(expected).mock(return_value=Response(200, json={"projects": []})) + + client = create_base_client(timeout=30.0, base_url=base_url) + transport = SyncTransport(client) + + try: + response = iter_coroutine(transport.send("GET", path)) + assert response.status_code == 200 + assert route.called + finally: + transport.close() + + @respx.mock + def test_project_by_id(self): + """Test GET request for a specific project.""" + base_url = "https://api.vercel.com" + project_id = "prj_abc123" + path = f"/v9/projects/{project_id}" + expected = f"{base_url}{path}" + + route = respx.get(expected).mock(return_value=Response(200, json={"id": project_id})) + + client = create_base_client(timeout=30.0, base_url=base_url) + transport = SyncTransport(client) + + try: + response = iter_coroutine(transport.send("GET", path)) + assert response.status_code == 200 + assert route.called + finally: + transport.close() diff --git a/tests/integration/test_oidc_sync_async.py b/tests/integration/test_oidc_sync_async.py new file mode 100644 index 0000000..702c607 --- /dev/null +++ b/tests/integration/test_oidc_sync_async.py @@ -0,0 +1,257 @@ +"""Integration tests for Vercel OIDC module. + +Tests token retrieval from context, environment, and JWT payload decoding. 
+""" + +import base64 +import json + +import pytest + + +class TestOidcTokenFromContext: + """Test OIDC token retrieval from request context.""" + + def test_get_token_from_header_context(self, mock_env_clear): + """Test getting OIDC token from request headers.""" + from vercel.cache.context import set_headers + from vercel.oidc import get_vercel_oidc_token_sync + + # Set headers via context + set_headers({"x-vercel-oidc-token": "test_oidc_token_from_header"}) + + try: + token = get_vercel_oidc_token_sync() + assert token == "test_oidc_token_from_header" + finally: + # Clear headers + set_headers(None) + + def test_get_token_from_env_variable(self, mock_env_clear, monkeypatch): + """Test getting OIDC token from environment variable.""" + from vercel.oidc import get_vercel_oidc_token_sync + + monkeypatch.setenv("VERCEL_OIDC_TOKEN", "test_oidc_token_from_env") + + token = get_vercel_oidc_token_sync() + assert token == "test_oidc_token_from_env" + + def test_missing_token_raises_error(self, mock_env_clear): + """Test that missing token raises VercelOidcTokenError.""" + from vercel.oidc import VercelOidcTokenError, get_vercel_oidc_token_sync + + with pytest.raises(VercelOidcTokenError) as exc_info: + get_vercel_oidc_token_sync() + + assert "x-vercel-oidc-token" in str(exc_info.value) + + def test_header_takes_precedence_over_env(self, mock_env_clear, monkeypatch): + """Test that header takes precedence over environment variable.""" + from vercel.cache.context import set_headers + from vercel.oidc import get_vercel_oidc_token_sync + + monkeypatch.setenv("VERCEL_OIDC_TOKEN", "env_token") + set_headers({"x-vercel-oidc-token": "header_token"}) + + try: + token = get_vercel_oidc_token_sync() + assert token == "header_token" + finally: + set_headers(None) + + +class TestDecodeOidcPayload: + """Test JWT payload decoding.""" + + def test_decode_valid_jwt_payload(self, mock_env_clear): + """Test decoding a valid JWT token payload.""" + from vercel.oidc import 
decode_oidc_payload + + # Create a valid JWT-like token with a base64url encoded payload + payload = { + "sub": "test_subject", + "aud": "vercel", + "iss": "https://oidc.vercel.com", + "exp": 9999999999, + "project_id": "prj_test123", + "owner_id": "team_test456", + } + # Base64url encode the payload + payload_json = json.dumps(payload) + payload_b64 = base64.urlsafe_b64encode(payload_json.encode()).decode().rstrip("=") + + # Create a mock JWT (header.payload.signature) + mock_token = f"header.{payload_b64}.signature" + + decoded = decode_oidc_payload(mock_token) + + assert decoded["sub"] == "test_subject" + assert decoded["aud"] == "vercel" + assert decoded["iss"] == "https://oidc.vercel.com" + assert decoded["project_id"] == "prj_test123" + assert decoded["owner_id"] == "team_test456" + + def test_decode_payload_with_padding(self, mock_env_clear): + """Test decoding payload that requires base64 padding.""" + from vercel.oidc import decode_oidc_payload + + payload = {"short": "data"} + payload_json = json.dumps(payload) + payload_b64 = base64.urlsafe_b64encode(payload_json.encode()).decode().rstrip("=") + + mock_token = f"header.{payload_b64}.signature" + + decoded = decode_oidc_payload(mock_token) + assert decoded["short"] == "data" + + +class TestGetCredentials: + """Test get_credentials function.""" + + def test_get_credentials_with_explicit_values(self, mock_env_clear): + """Test getting credentials with explicitly provided values.""" + from vercel.oidc import Credentials, get_credentials + + creds = get_credentials( + token="explicit_token", + project_id="prj_explicit", + team_id="team_explicit", + ) + + assert isinstance(creds, Credentials) + assert creds.token == "explicit_token" + assert creds.project_id == "prj_explicit" + assert creds.team_id == "team_explicit" + + def test_get_credentials_from_env(self, mock_env_clear, monkeypatch): + """Test getting credentials from environment variables.""" + from vercel.oidc import get_credentials + + 
monkeypatch.setenv("VERCEL_TOKEN", "env_token") + monkeypatch.setenv("VERCEL_PROJECT_ID", "prj_from_env") + monkeypatch.setenv("VERCEL_TEAM_ID", "team_from_env") + + creds = get_credentials() + + assert creds.token == "env_token" + assert creds.project_id == "prj_from_env" + assert creds.team_id == "team_from_env" + + def test_get_credentials_missing_raises_error(self, mock_env_clear): + """Test that missing credentials raises RuntimeError.""" + from vercel.oidc import get_credentials + + with pytest.raises(RuntimeError) as exc_info: + get_credentials() + + assert "Missing credentials" in str(exc_info.value) + + def test_get_credentials_partial_explicit_with_env(self, mock_env_clear, monkeypatch): + """Test partial explicit credentials filled from env.""" + from vercel.oidc import get_credentials + + monkeypatch.setenv("VERCEL_TOKEN", "env_token") + monkeypatch.setenv("VERCEL_PROJECT_ID", "prj_from_env") + monkeypatch.setenv("VERCEL_TEAM_ID", "team_from_env") + + # Only provide token explicitly + creds = get_credentials(token="explicit_only_token") + + # Should use explicit token but env for project_id and team_id + assert creds.token == "explicit_only_token" + assert creds.project_id == "prj_from_env" + assert creds.team_id == "team_from_env" + + def test_get_credentials_from_oidc_token(self, mock_env_clear, monkeypatch): + """Test getting credentials from OIDC token payload.""" + from vercel.oidc import get_credentials + + # Create a valid OIDC token with embedded project and team info + payload = { + "project_id": "prj_oidc_embedded", + "owner_id": "team_oidc_embedded", + "exp": 9999999999, + } + payload_json = json.dumps(payload) + payload_b64 = base64.urlsafe_b64encode(payload_json.encode()).decode().rstrip("=") + oidc_token = f"header.{payload_b64}.signature" + + monkeypatch.setenv("VERCEL_OIDC_TOKEN", oidc_token) + + creds = get_credentials() + + assert creds.token == oidc_token + assert creds.project_id == "prj_oidc_embedded" + assert creds.team_id == 
"team_oidc_embedded" + + def test_get_credentials_oidc_with_env_project_team(self, mock_env_clear, monkeypatch): + """Test OIDC token with project/team from env vars.""" + from vercel.oidc import get_credentials + + # OIDC token without embedded project/team + payload = {"sub": "test", "exp": 9999999999} + payload_json = json.dumps(payload) + payload_b64 = base64.urlsafe_b64encode(payload_json.encode()).decode().rstrip("=") + oidc_token = f"header.{payload_b64}.signature" + + monkeypatch.setenv("VERCEL_OIDC_TOKEN", oidc_token) + monkeypatch.setenv("VERCEL_PROJECT_ID", "prj_env_override") + monkeypatch.setenv("VERCEL_TEAM_ID", "team_env_override") + + creds = get_credentials() + + assert creds.token == oidc_token + assert creds.project_id == "prj_env_override" + assert creds.team_id == "team_env_override" + + +class TestCredentialsDataclass: + """Test Credentials dataclass.""" + + def test_credentials_fields(self, mock_env_clear): + """Test Credentials dataclass has expected fields.""" + from vercel.oidc import Credentials + + creds = Credentials( + token="test_token", + project_id="prj_test", + team_id="team_test", + ) + + assert creds.token == "test_token" + assert creds.project_id == "prj_test" + assert creds.team_id == "team_test" + + def test_credentials_equality(self, mock_env_clear): + """Test Credentials equality comparison.""" + from vercel.oidc import Credentials + + creds1 = Credentials(token="tok", project_id="prj", team_id="team") + creds2 = Credentials(token="tok", project_id="prj", team_id="team") + creds3 = Credentials(token="different", project_id="prj", team_id="team") + + assert creds1 == creds2 + assert creds1 != creds3 + + +class TestVercelOidcTokenError: + """Test VercelOidcTokenError exception.""" + + def test_error_message(self, mock_env_clear): + """Test error message formatting.""" + from vercel.oidc import VercelOidcTokenError + + error = VercelOidcTokenError("Test error message") + assert str(error) == "Test error message" + assert 
error.cause is None + + def test_error_with_cause(self, mock_env_clear): + """Test error with cause exception.""" + from vercel.oidc import VercelOidcTokenError + + cause = ValueError("Original error") + error = VercelOidcTokenError("Wrapped error", cause) + + assert "Wrapped error" in str(error) + assert "Original error" in str(error) + assert error.cause is cause diff --git a/tests/integration/test_projects_sync_async.py b/tests/integration/test_projects_sync_async.py index 7f99a4e..de80af7 100644 --- a/tests/integration/test_projects_sync_async.py +++ b/tests/integration/test_projects_sync_async.py @@ -7,7 +7,7 @@ """ import os -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -180,11 +180,16 @@ def mock_projects_response(self): def test_get_projects_sync(self, mock_token, mock_projects_response): """Test sync get_projects function with comprehensive output validation.""" - with patch("vercel.projects.projects._request") as mock_request: + with patch("vercel.projects.projects.httpx.Client") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = mock_projects_response - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_class.return_value = mock_client result = get_projects(token=mock_token) @@ -253,27 +258,27 @@ def test_get_projects_sync(self, mock_token, mock_projects_response): assert pagination["prev"] is None or isinstance(pagination["prev"], int) # Validate request was made correctly - mock_request.assert_called_once() - call_args = mock_request.call_args + mock_client.request.assert_called_once() + call_args = mock_client.request.call_args # Validate HTTP method and path assert call_args[0][0] == "GET" # method - assert 
call_args[0][1] == "/v10/projects" # path - - # Validate token parameter - assert call_args[1]["token"] == mock_token - - # Validate no body for GET request (json parameter not passed when None) - assert "json" not in call_args[1] or call_args[1]["json"] is None + assert "v10/projects" in call_args[0][1] # url contains path (leading / stripped) @pytest.mark.asyncio async def test_get_projects_async(self, mock_token, mock_projects_response): """Test async get_projects_async function with request validation.""" - with patch("vercel.projects.projects._request_async") as mock_request: + with patch("vercel.projects.projects.httpx.AsyncClient") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = mock_projects_response - mock_request.return_value = mock_response + + # Create a mock client that properly supports async context manager + mock_client = MagicMock() + mock_client.request = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client result = await get_projects_async(token=mock_token) @@ -281,26 +286,25 @@ async def test_get_projects_async(self, mock_token, mock_projects_response): assert result == mock_projects_response # Validate request was made correctly - mock_request.assert_called_once() - call_args = mock_request.call_args + mock_client.request.assert_called_once() + call_args = mock_client.request.call_args # Validate HTTP method and path assert call_args[0][0] == "GET" # method - assert call_args[0][1] == "/v10/projects" # path - - # Validate token parameter - assert call_args[1]["token"] == mock_token - - # Validate no body for GET request (json parameter not passed when None) - assert "json" not in call_args[1] or call_args[1]["json"] is None + assert "v10/projects" in call_args[0][1] # url contains path (leading / stripped) def 
test_create_project_sync(self, mock_token, mock_project_data): """Test sync create_project function with comprehensive output validation.""" - with patch("vercel.projects.projects._request") as mock_request: + with patch("vercel.projects.projects.httpx.Client") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 201 mock_response.json.return_value = mock_project_data - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_class.return_value = mock_client project_body = {"name": "test-project", "framework": "nextjs"} result = create_project(body=project_body, token=mock_token) @@ -352,15 +356,12 @@ def test_create_project_sync(self, mock_token, mock_project_data): assert result["updatedAt"] == 1640995200000 # Validate request was made correctly - mock_request.assert_called_once() - call_args = mock_request.call_args + mock_client.request.assert_called_once() + call_args = mock_client.request.call_args # Validate HTTP method and path assert call_args[0][0] == "POST" # method - assert call_args[0][1] == "/v11/projects" # path - - # Validate token parameter - assert call_args[1]["token"] == mock_token + assert "v11/projects" in call_args[0][1] # url contains path (leading / stripped) # Validate request body assert call_args[1]["json"] == project_body @@ -368,11 +369,16 @@ def test_create_project_sync(self, mock_token, mock_project_data): @pytest.mark.asyncio async def test_create_project_async(self, mock_token, mock_project_data): """Test async create_project_async function with request validation.""" - with patch("vercel.projects.projects._request_async") as mock_request: + with patch("vercel.projects.projects.httpx.AsyncClient") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 201 mock_response.json.return_value = 
mock_project_data - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client project_body = {"name": "test-project", "framework": "nextjs"} result = await create_project_async(body=project_body, token=mock_token) @@ -381,26 +387,28 @@ async def test_create_project_async(self, mock_token, mock_project_data): assert result == mock_project_data # Validate request was made correctly - mock_request.assert_called_once() - call_args = mock_request.call_args + mock_client.request.assert_called_once() + call_args = mock_client.request.call_args # Validate HTTP method and path assert call_args[0][0] == "POST" # method - assert call_args[0][1] == "/v11/projects" # path - - # Validate token parameter - assert call_args[1]["token"] == mock_token + assert "v11/projects" in call_args[0][1] # url contains path (leading / stripped) # Validate request body assert call_args[1]["json"] == project_body def test_update_project_sync(self, mock_token, mock_project_data): """Test sync update_project function with request validation.""" - with patch("vercel.projects.projects._request") as mock_request: + with patch("vercel.projects.projects.httpx.Client") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = mock_project_data - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_class.return_value = mock_client project_id = "test_project_123" update_body = {"framework": "nextjs", "buildCommand": "npm run build"} @@ -410,15 +418,14 @@ def test_update_project_sync(self, mock_token, 
mock_project_data): assert result == mock_project_data # Validate request was made correctly - mock_request.assert_called_once() - call_args = mock_request.call_args + mock_client.request.assert_called_once() + call_args = mock_client.request.call_args # Validate HTTP method and path assert call_args[0][0] == "PATCH" # method - assert call_args[0][1] == f"/v9/projects/{project_id}" # path - - # Validate token parameter - assert call_args[1]["token"] == mock_token + assert ( + f"v9/projects/{project_id}" in call_args[0][1] + ) # url contains path (leading / stripped) # Validate request body assert call_args[1]["json"] == update_body @@ -426,11 +433,16 @@ def test_update_project_sync(self, mock_token, mock_project_data): @pytest.mark.asyncio async def test_update_project_async(self, mock_token, mock_project_data): """Test async update_project_async function with request validation.""" - with patch("vercel.projects.projects._request_async") as mock_request: + with patch("vercel.projects.projects.httpx.AsyncClient") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = mock_project_data - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client project_id = "test_project_123" update_body = {"framework": "nextjs", "buildCommand": "npm run build"} @@ -440,25 +452,29 @@ async def test_update_project_async(self, mock_token, mock_project_data): assert result == mock_project_data # Validate request was made correctly - mock_request.assert_called_once() - call_args = mock_request.call_args + mock_client.request.assert_called_once() + call_args = mock_client.request.call_args # Validate HTTP method and path assert call_args[0][0] == "PATCH" # method - assert call_args[0][1] == 
f"/v9/projects/{project_id}" # path - - # Validate token parameter - assert call_args[1]["token"] == mock_token + assert ( + f"v9/projects/{project_id}" in call_args[0][1] + ) # url contains path (leading / stripped) # Validate request body assert call_args[1]["json"] == update_body def test_delete_project_sync(self, mock_token): """Test sync delete_project function with request validation.""" - with patch("vercel.projects.projects._request") as mock_request: + with patch("vercel.projects.projects.httpx.Client") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 204 - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_class.return_value = mock_client project_id = "test_project_123" result = delete_project(project_id, token=mock_token) @@ -467,26 +483,27 @@ def test_delete_project_sync(self, mock_token): assert result is None # Validate request was made correctly - mock_request.assert_called_once() - call_args = mock_request.call_args + mock_client.request.assert_called_once() + call_args = mock_client.request.call_args # Validate HTTP method and path assert call_args[0][0] == "DELETE" # method - assert call_args[0][1] == f"/v9/projects/{project_id}" # path - - # Validate token parameter - assert call_args[1]["token"] == mock_token - - # Validate no body for DELETE request (json parameter not passed when None) - assert "json" not in call_args[1] or call_args[1]["json"] is None + assert ( + f"v9/projects/{project_id}" in call_args[0][1] + ) # url contains path (leading / stripped) @pytest.mark.asyncio async def test_delete_project_async(self, mock_token): """Test async delete_project_async function with request validation.""" - with patch("vercel.projects.projects._request_async") as mock_request: + with 
patch("vercel.projects.projects.httpx.AsyncClient") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 204 - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client project_id = "test_project_123" result = await delete_project_async(project_id, token=mock_token) @@ -495,65 +512,76 @@ async def test_delete_project_async(self, mock_token): assert result is None # Validate request was made correctly - mock_request.assert_called_once() - call_args = mock_request.call_args + mock_client.request.assert_called_once() + call_args = mock_client.request.call_args # Validate HTTP method and path assert call_args[0][0] == "DELETE" # method - assert call_args[0][1] == f"/v9/projects/{project_id}" # path - - # Validate token parameter - assert call_args[1]["token"] == mock_token - - # Validate no body for DELETE request (json parameter not passed when None) - assert "json" not in call_args[1] or call_args[1]["json"] is None + assert ( + f"v9/projects/{project_id}" in call_args[0][1] + ) # url contains path (leading / stripped) def test_get_projects_with_team_id_sync(self, mock_token): """Test sync get_projects with team_id parameter validation.""" - with patch("vercel.projects.projects._request") as mock_request: + with patch("vercel.projects.projects.httpx.Client") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = {"projects": []} - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_class.return_value = mock_client team_id = "team_123" 
get_projects(token=mock_token, team_id=team_id) # Validate request was made with correct params - call_args = mock_request.call_args + call_args = mock_client.request.call_args params = call_args[1]["params"] assert params["teamId"] == team_id @pytest.mark.asyncio async def test_get_projects_with_team_id_async(self, mock_token): """Test async get_projects_async with team_id parameter validation.""" - with patch("vercel.projects.projects._request_async") as mock_request: + with patch("vercel.projects.projects.httpx.AsyncClient") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = {"projects": []} - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client team_id = "team_123" await get_projects_async(token=mock_token, team_id=team_id) # Validate request was made with correct params - call_args = mock_request.call_args + call_args = mock_client.request.call_args params = call_args[1]["params"] assert params["teamId"] == team_id def test_get_projects_with_query_params_sync(self, mock_token): """Test sync get_projects with query parameters validation.""" - with patch("vercel.projects.projects._request") as mock_request: + with patch("vercel.projects.projects.httpx.Client") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = {"projects": []} - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_class.return_value = mock_client query_params = {"search": "test", "limit": 10} get_projects(token=mock_token, 
query=query_params) # Validate request was made with correct params - call_args = mock_request.call_args + call_args = mock_client.request.call_args params = call_args[1]["params"] assert params["search"] == "test" assert params["limit"] == 10 @@ -561,64 +589,84 @@ def test_get_projects_with_query_params_sync(self, mock_token): @pytest.mark.asyncio async def test_get_projects_with_query_params_async(self, mock_token): """Test async get_projects_async with query parameters validation.""" - with patch("vercel.projects.projects._request_async") as mock_request: + with patch("vercel.projects.projects.httpx.AsyncClient") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = {"projects": []} - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client query_params = {"search": "test", "limit": 10} await get_projects_async(token=mock_token, query=query_params) # Validate request was made with correct params - call_args = mock_request.call_args + call_args = mock_client.request.call_args params = call_args[1]["params"] assert params["search"] == "test" assert params["limit"] == 10 def test_create_project_with_team_id_sync(self, mock_token, mock_project_data): """Test sync create_project with team_id parameter validation.""" - with patch("vercel.projects.projects._request") as mock_request: + with patch("vercel.projects.projects.httpx.Client") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 201 mock_response.json.return_value = mock_project_data - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + 
mock_client.__exit__ = MagicMock(return_value=False) + mock_client_class.return_value = mock_client project_body = {"name": "test-project"} team_id = "team_123" create_project(body=project_body, token=mock_token, team_id=team_id) # Validate request was made with correct params - call_args = mock_request.call_args + call_args = mock_client.request.call_args params = call_args[1]["params"] assert params["teamId"] == team_id @pytest.mark.asyncio async def test_create_project_with_team_id_async(self, mock_token, mock_project_data): """Test async create_project_async with team_id parameter validation.""" - with patch("vercel.projects.projects._request_async") as mock_request: + with patch("vercel.projects.projects.httpx.AsyncClient") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 201 mock_response.json.return_value = mock_project_data - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client project_body = {"name": "test-project"} team_id = "team_123" await create_project_async(body=project_body, token=mock_token, team_id=team_id) # Validate request was made with correct params - call_args = mock_request.call_args + call_args = mock_client.request.call_args params = call_args[1]["params"] assert params["teamId"] == team_id def test_error_handling_sync(self, mock_token): """Test sync error handling with comprehensive output validation.""" - with patch("vercel.projects.projects._request") as mock_request: + with patch("vercel.projects.projects.httpx.Client") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 400 mock_response.reason_phrase = "Bad Request" mock_response.json.return_value = {"error": "Invalid request"} - mock_request.return_value = mock_response + + mock_client 
= MagicMock() + mock_client.request.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_class.return_value = mock_client # Validate that the correct exception is raised with pytest.raises(RuntimeError) as exc_info: @@ -632,17 +680,22 @@ def test_error_handling_sync(self, mock_token): assert "Invalid request" in error_message # Validate that the request was still made - mock_request.assert_called_once() + mock_client.request.assert_called_once() @pytest.mark.asyncio async def test_error_handling_async(self, mock_token): """Test async error handling with detailed validation.""" - with patch("vercel.projects.projects._request_async") as mock_request: + with patch("vercel.projects.projects.httpx.AsyncClient") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 400 mock_response.reason_phrase = "Bad Request" mock_response.json.return_value = {"error": "Invalid request"} - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client with pytest.raises(RuntimeError, match="Failed to get projects"): await get_projects_async(token=mock_token) @@ -662,63 +715,81 @@ async def test_missing_token_error_async(self): def test_timeout_parameter_sync(self, mock_token): """Test sync functions accept timeout parameter.""" - with patch("vercel.projects.projects._request") as mock_request: + with patch("vercel.projects.projects.httpx.Client") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = {"projects": []} - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request.return_value = mock_response + mock_client.__enter__ = 
MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_class.return_value = mock_client get_projects(token=mock_token, timeout=120.0) - # Validate timeout was passed correctly - call_args = mock_request.call_args - assert call_args[1]["timeout"] == 120.0 + # Validate that httpx.Client was called (timeout is passed to constructor) + mock_client_class.assert_called_once() @pytest.mark.asyncio async def test_timeout_parameter_async(self, mock_token): """Test async functions accept timeout parameter.""" - with patch("vercel.projects.projects._request_async") as mock_request: + with patch("vercel.projects.projects.httpx.AsyncClient") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = {"projects": []} - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client await get_projects_async(token=mock_token, timeout=120.0) - # Validate timeout was passed correctly - call_args = mock_request.call_args - assert call_args[1]["timeout"] == 120.0 + # Validate that httpx.AsyncClient was called (timeout is passed to constructor) + mock_client_class.assert_called_once() def test_base_url_parameter_sync(self, mock_token): """Test sync functions accept base_url parameter.""" - with patch("vercel.projects.projects._request") as mock_request: + with patch("vercel.projects.projects.httpx.Client") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = {"projects": []} - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = 
MagicMock(return_value=False) + mock_client_class.return_value = mock_client custom_base_url = "https://custom-api.example.com" get_projects(token=mock_token, base_url=custom_base_url) - # Validate base_url was passed correctly - call_args = mock_request.call_args - assert call_args[1]["base_url"] == custom_base_url + # Validate request URL uses custom base URL + call_args = mock_client.request.call_args + assert call_args[0][1].startswith(custom_base_url) @pytest.mark.asyncio async def test_base_url_parameter_async(self, mock_token): """Test async functions accept base_url parameter.""" - with patch("vercel.projects.projects._request_async") as mock_request: + with patch("vercel.projects.projects.httpx.AsyncClient") as mock_client_class: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = {"projects": []} - mock_request.return_value = mock_response + + mock_client = MagicMock() + mock_client.request = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client custom_base_url = "https://custom-api.example.com" await get_projects_async(token=mock_token, base_url=custom_base_url) - # Validate base_url was passed correctly - call_args = mock_request.call_args - assert call_args[1]["base_url"] == custom_base_url + # Validate request URL uses custom base URL + call_args = mock_client.request.call_args + assert call_args[0][1].startswith(custom_base_url) class TestConsistency: @@ -732,21 +803,30 @@ async def test_sync_async_consistency(self): "pagination": {"count": 1}, } - # Mock both sync and async request functions + # Mock both sync and async clients with ( - patch("vercel.projects.projects._request") as mock_sync_request, - patch("vercel.projects.projects._request_async") as mock_async_request, + patch("vercel.projects.projects.httpx.Client") as mock_sync_client_class, + 
patch("vercel.projects.projects.httpx.AsyncClient") as mock_async_client_class, ): - # Setup mock responses + # Setup sync mock mock_sync_response = MagicMock() mock_sync_response.status_code = 200 mock_sync_response.json.return_value = mock_response_data - mock_sync_request.return_value = mock_sync_response + mock_sync_client = MagicMock() + mock_sync_client.request.return_value = mock_sync_response + mock_sync_client.__enter__ = MagicMock(return_value=mock_sync_client) + mock_sync_client.__exit__ = MagicMock(return_value=False) + mock_sync_client_class.return_value = mock_sync_client + # Setup async mock mock_async_response = MagicMock() mock_async_response.status_code = 200 mock_async_response.json.return_value = mock_response_data - mock_async_request.return_value = mock_async_response + mock_async_client = MagicMock() + mock_async_client.request = AsyncMock(return_value=mock_async_response) + mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client) + mock_async_client.__aexit__ = AsyncMock(return_value=None) + mock_async_client_class.return_value = mock_async_client # Call both versions sync_result = get_projects(token="test_token") diff --git a/tests/integration/test_sandbox_sync_async.py b/tests/integration/test_sandbox_sync_async.py new file mode 100644 index 0000000..5417b95 --- /dev/null +++ b/tests/integration/test_sandbox_sync_async.py @@ -0,0 +1,797 @@ +"""Integration tests for Vercel Sandbox API using respx mocking. + +Tests both sync and async variants (Sandbox and AsyncSandbox). 
+""" + +import httpx +import pytest +import respx + +# Base URL for Vercel Sandbox API +SANDBOX_API_BASE = "https://api.vercel.com" + + +class TestSandboxCreate: + """Test sandbox creation operations.""" + + @respx.mock + def test_create_sandbox_sync(self, mock_env_clear, mock_sandbox_create_response): + """Test synchronous sandbox creation.""" + from vercel.sandbox import Sandbox + + # Mock the sandbox creation endpoint + route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_create_response, + "routes": [ + { + "port": 3000, + "subdomain": "test-sbx", + "url": "https://test-sbx.vercel.run", + } + ], + }, + ) + ) + + sandbox = Sandbox.create( + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + assert route.called + assert sandbox.sandbox_id == "sbx_test123456" + assert sandbox.status == "running" + assert sandbox.timeout == 300 + + # Cleanup + sandbox.client.close() + + @respx.mock + @pytest.mark.asyncio + async def test_create_sandbox_async(self, mock_env_clear, mock_sandbox_create_response): + """Test asynchronous sandbox creation.""" + from vercel.sandbox import AsyncSandbox + + route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_create_response, + "routes": [ + { + "port": 3000, + "subdomain": "test-sbx", + "url": "https://test-sbx.vercel.run", + } + ], + }, + ) + ) + + sandbox = await AsyncSandbox.create( + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + assert route.called + assert sandbox.sandbox_id == "sbx_test123456" + assert sandbox.status == "running" + + # Cleanup + await sandbox.client.aclose() + + @respx.mock + def test_create_sandbox_with_options(self, mock_env_clear, mock_sandbox_create_response): + """Test sandbox creation with all options.""" + from vercel.sandbox import Sandbox + + route = 
respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_create_response, + "routes": [], + }, + ) + ) + + sandbox = Sandbox.create( + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ports=[3000, 8080], + timeout=600000, + runtime="nodejs20.x", + ) + + assert route.called + # Verify request body + import json + + body = json.loads(route.calls.last.request.content) + assert body["ports"] == [3000, 8080] + assert body["timeout"] == 600000 + assert body["runtime"] == "nodejs20.x" + + sandbox.client.close() + + +class TestSandboxGet: + """Test sandbox get operations.""" + + @respx.mock + def test_get_sandbox_sync(self, mock_env_clear, mock_sandbox_get_response): + """Test synchronous get sandbox by ID.""" + from vercel.sandbox import Sandbox + + sandbox_id = "sbx_test123456" + route = respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + sandbox = Sandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + assert route.called + assert sandbox.sandbox_id == sandbox_id + assert sandbox.status == "running" + + sandbox.client.close() + + @respx.mock + @pytest.mark.asyncio + async def test_get_sandbox_async(self, mock_env_clear, mock_sandbox_get_response): + """Test asynchronous get sandbox by ID.""" + from vercel.sandbox import AsyncSandbox + + sandbox_id = "sbx_test123456" + route = respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + sandbox = await AsyncSandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + assert route.called + assert sandbox.sandbox_id == sandbox_id + + await 
sandbox.client.aclose() + + +class TestSandboxRunCommand: + """Test sandbox command execution.""" + + @respx.mock + def test_run_command_sync( + self, mock_env_clear, mock_sandbox_get_response, mock_sandbox_command_response + ): + """Test synchronous command execution.""" + from vercel.sandbox import Sandbox + + sandbox_id = "sbx_test123456" + + # Mock get sandbox + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + # Mock run command + cmd_id = mock_sandbox_command_response["commandId"] + respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/cmd").mock( + return_value=httpx.Response( + 200, + json={ + "command": { + "id": cmd_id, + "name": "echo", + "args": ["Hello, World!"], + "cwd": "/app", + "sandboxId": sandbox_id, + "exitCode": None, + "startedAt": 1705320600000, + } + }, + ) + ) + + # Mock wait for command (with query param wait=true) + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/cmd/{cmd_id}").mock( + return_value=httpx.Response( + 200, + json={ + "command": { + "id": cmd_id, + "name": "echo", + "args": ["Hello, World!"], + "cwd": "/app", + "sandboxId": sandbox_id, + "exitCode": 0, + "startedAt": 1705320600000, + "stdout": "Hello, World!\n", + "stderr": "", + } + }, + ) + ) + + sandbox = Sandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + result = sandbox.run_command("echo", ["Hello, World!"]) + + assert result.exit_code == 0 + # Note: stdout() is a method that fetches logs from API, not tested here + + sandbox.client.close() + + @respx.mock + @pytest.mark.asyncio + async def test_run_command_async( + self, mock_env_clear, mock_sandbox_get_response, mock_sandbox_command_response + ): + """Test asynchronous command execution.""" + from vercel.sandbox import AsyncSandbox + + sandbox_id = "sbx_test123456" + cmd_id = 
mock_sandbox_command_response["commandId"] + + # Mock get sandbox + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + # Mock run command + respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/cmd").mock( + return_value=httpx.Response( + 200, + json={ + "command": { + "id": cmd_id, + "name": "echo", + "args": ["Hello, World!"], + "cwd": "/app", + "sandboxId": sandbox_id, + "exitCode": None, + "startedAt": 1705320600000, + } + }, + ) + ) + + # Mock wait for command + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/cmd/{cmd_id}").mock( + return_value=httpx.Response( + 200, + json={ + "command": { + "id": cmd_id, + "name": "echo", + "args": ["Hello, World!"], + "cwd": "/app", + "sandboxId": sandbox_id, + "exitCode": 0, + "startedAt": 1705320600000, + "stdout": "Hello, World!\n", + "stderr": "", + } + }, + ) + ) + + sandbox = await AsyncSandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + result = await sandbox.run_command("echo", ["Hello, World!"]) + + assert result.exit_code == 0 + + await sandbox.client.aclose() + + +class TestSandboxFileOperations: + """Test sandbox file operations.""" + + @respx.mock + def test_read_file_sync( + self, mock_env_clear, mock_sandbox_get_response, mock_sandbox_read_file_content + ): + """Test synchronous file read.""" + from vercel.sandbox import Sandbox + + sandbox_id = "sbx_test123456" + + # Mock get sandbox + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + # Mock read file + respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/fs/read").mock( + return_value=httpx.Response(200, content=mock_sandbox_read_file_content) + ) + + sandbox = Sandbox.get( + sandbox_id=sandbox_id, + 
token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + content = sandbox.read_file("/etc/hosts") + + assert content is not None + assert content == mock_sandbox_read_file_content + + sandbox.client.close() + + @respx.mock + def test_read_file_not_found(self, mock_env_clear, mock_sandbox_get_response): + """Test file read returns None for non-existent file.""" + from vercel.sandbox import Sandbox + + sandbox_id = "sbx_test123456" + + # Mock get sandbox + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + # Mock read file - 404 + respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/fs/read").mock( + return_value=httpx.Response(404, json={"error": {"code": "not_found"}}) + ) + + sandbox = Sandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + content = sandbox.read_file("/nonexistent/file") + + assert content is None + + sandbox.client.close() + + @respx.mock + @pytest.mark.asyncio + async def test_read_file_async( + self, mock_env_clear, mock_sandbox_get_response, mock_sandbox_read_file_content + ): + """Test asynchronous file read.""" + from vercel.sandbox import AsyncSandbox + + sandbox_id = "sbx_test123456" + + # Mock get sandbox + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + # Mock read file + respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/fs/read").mock( + return_value=httpx.Response(200, content=mock_sandbox_read_file_content) + ) + + sandbox = await AsyncSandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + content = await sandbox.read_file("/etc/hosts") + + assert content is not None + + await 
sandbox.client.aclose() + + @respx.mock + def test_mk_dir_sync(self, mock_env_clear, mock_sandbox_get_response): + """Test synchronous directory creation.""" + from vercel.sandbox import Sandbox + + sandbox_id = "sbx_test123456" + + # Mock get sandbox + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + # Mock mkdir + route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/fs/mkdir").mock( + return_value=httpx.Response(200, json={}) + ) + + sandbox = Sandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + sandbox.mk_dir("/app/data") + + assert route.called + import json + + body = json.loads(route.calls.last.request.content) + assert body["path"] == "/app/data" + + sandbox.client.close() + + +class TestSandboxStop: + """Test sandbox stop operations.""" + + @respx.mock + def test_stop_sync(self, mock_env_clear, mock_sandbox_get_response): + """Test synchronous sandbox stop.""" + from vercel.sandbox import Sandbox + + sandbox_id = "sbx_test123456" + + # Mock get sandbox + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + # Mock stop + stopped_response = dict(mock_sandbox_get_response) + stopped_response["status"] = "stopped" + route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/stop").mock( + return_value=httpx.Response(200, json={"sandbox": stopped_response}) + ) + + sandbox = Sandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + sandbox.stop() + + assert route.called + + sandbox.client.close() + + @respx.mock + @pytest.mark.asyncio + async def test_stop_async(self, mock_env_clear, mock_sandbox_get_response): + """Test asynchronous sandbox 
stop.""" + from vercel.sandbox import AsyncSandbox + + sandbox_id = "sbx_test123456" + + # Mock get sandbox + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + # Mock stop + stopped_response = dict(mock_sandbox_get_response) + stopped_response["status"] = "stopped" + route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/stop").mock( + return_value=httpx.Response(200, json={"sandbox": stopped_response}) + ) + + sandbox = await AsyncSandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + await sandbox.stop() + + assert route.called + + await sandbox.client.aclose() + + +class TestSandboxContextManager: + """Test sandbox context manager behavior.""" + + @respx.mock + def test_context_manager_sync(self, mock_env_clear, mock_sandbox_create_response): + """Test sync context manager stops sandbox on exit.""" + from vercel.sandbox import Sandbox + + sandbox_id = "sbx_test123456" + + # Mock create + respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_create_response, + "routes": [], + }, + ) + ) + + # Mock stop + stopped_response = dict(mock_sandbox_create_response) + stopped_response["status"] = "stopped" + stop_route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/stop").mock( + return_value=httpx.Response(200, json={"sandbox": stopped_response}) + ) + + with Sandbox.create( + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) as sandbox: + assert sandbox.status == "running" + + # Stop should have been called + assert stop_route.called + + @respx.mock + @pytest.mark.asyncio + async def test_context_manager_async(self, mock_env_clear, mock_sandbox_create_response): + """Test async context manager stops sandbox on exit.""" + from vercel.sandbox import 
AsyncSandbox + + sandbox_id = "sbx_test123456" + + # Mock create + respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_create_response, + "routes": [], + }, + ) + ) + + # Mock stop + stopped_response = dict(mock_sandbox_create_response) + stopped_response["status"] = "stopped" + stop_route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/stop").mock( + return_value=httpx.Response(200, json={"sandbox": stopped_response}) + ) + + async with await AsyncSandbox.create( + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) as sandbox: + assert sandbox.status == "running" + + # Stop should have been called + assert stop_route.called + + +class TestSandboxSnapshot: + """Test sandbox snapshot operations.""" + + @respx.mock + def test_create_snapshot_sync( + self, mock_env_clear, mock_sandbox_get_response, mock_sandbox_snapshot_response + ): + """Test synchronous snapshot creation.""" + from vercel.sandbox import Sandbox + + sandbox_id = "sbx_test123456" + + # Mock get sandbox + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + # Mock create snapshot + stopped_response = dict(mock_sandbox_get_response) + stopped_response["status"] = "stopped" + route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/snapshot").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": stopped_response, + "snapshot": mock_sandbox_snapshot_response, + }, + ) + ) + + sandbox = Sandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + snapshot = sandbox.snapshot() + + assert route.called + assert snapshot.snapshot_id == mock_sandbox_snapshot_response["id"] + # Sandbox should be stopped after snapshot + assert sandbox.status == "stopped" + + sandbox.client.close() + + +class 
TestSandboxExtendTimeout: + """Test sandbox timeout extension.""" + + @respx.mock + def test_extend_timeout_sync(self, mock_env_clear, mock_sandbox_get_response): + """Test synchronous timeout extension.""" + from vercel.sandbox import Sandbox + + sandbox_id = "sbx_test123456" + + # Mock get sandbox + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [], + }, + ) + ) + + # Mock extend timeout + extended_response = dict(mock_sandbox_get_response) + extended_response["timeout"] = 600 + route = respx.post(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}/extend-timeout").mock( + return_value=httpx.Response(200, json={"sandbox": extended_response}) + ) + + sandbox = Sandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + sandbox.extend_timeout(300000) # 5 minutes + + assert route.called + import json + + body = json.loads(route.calls.last.request.content) + assert body["duration"] == 300000 + assert sandbox.timeout == 600 + + sandbox.client.close() + + +class TestSandboxDomain: + """Test sandbox domain resolution.""" + + @respx.mock + def test_domain_resolution(self, mock_env_clear, mock_sandbox_get_response): + """Test domain resolution for ports.""" + from vercel.sandbox import Sandbox + + sandbox_id = "sbx_test123456" + + # Mock get sandbox with routes + respx.get(f"{SANDBOX_API_BASE}/v1/sandboxes/{sandbox_id}").mock( + return_value=httpx.Response( + 200, + json={ + "sandbox": mock_sandbox_get_response, + "routes": [ + { + "port": 3000, + "subdomain": "app-3000", + "url": "https://app-3000.vercel.run", + }, + { + "port": 8080, + "subdomain": "api-8080", + "url": "https://api-8080.vercel.run", + }, + ], + }, + ) + ) + + sandbox = Sandbox.get( + sandbox_id=sandbox_id, + token="test_token", + team_id="team_test123", + project_id="prj_test123", + ) + + # Test URL resolution + assert 
sandbox.domain(3000) == "https://app-3000.vercel.run" + assert sandbox.domain(8080) == "https://api-8080.vercel.run" + + # Test invalid port + with pytest.raises(ValueError, match="No route for port"): + sandbox.domain(9999) + + sandbox.client.close() diff --git a/tests/live/__init__.py b/tests/live/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/live/conftest.py b/tests/live/conftest.py new file mode 100644 index 0000000..5eb762a --- /dev/null +++ b/tests/live/conftest.py @@ -0,0 +1,192 @@ +"""Fixtures for live API tests. + +These tests require real API credentials set via environment variables: +- VERCEL_TOKEN: Vercel API token +- VERCEL_TEAM_ID: Vercel team ID +- BLOB_READ_WRITE_TOKEN: Blob storage read/write token +""" + +import os +import time +import uuid +from collections.abc import Generator +from typing import Any + +import pytest + + +def has_vercel_credentials() -> bool: + """Check if Vercel API credentials are available.""" + return bool(os.getenv("VERCEL_TOKEN") and os.getenv("VERCEL_TEAM_ID")) + + +def has_blob_credentials() -> bool: + """Check if Blob storage credentials are available.""" + return bool(os.getenv("BLOB_READ_WRITE_TOKEN")) + + +def has_sandbox_credentials() -> bool: + """Check if Sandbox credentials are available.""" + return has_vercel_credentials() + + +# Skip markers for live tests +requires_vercel_credentials = pytest.mark.skipif( + not has_vercel_credentials(), + reason="Requires VERCEL_TOKEN and VERCEL_TEAM_ID environment variables", +) + +requires_blob_credentials = pytest.mark.skipif( + not has_blob_credentials(), + reason="Requires BLOB_READ_WRITE_TOKEN environment variable", +) + +requires_sandbox_credentials = pytest.mark.skipif( + not has_sandbox_credentials(), + reason="Requires VERCEL_TOKEN and VERCEL_TEAM_ID environment variables for sandbox", +) + + +@pytest.fixture +def vercel_token() -> str: + """Get Vercel API token from environment.""" + token = os.getenv("VERCEL_TOKEN") + if not token: + 
pytest.skip("VERCEL_TOKEN environment variable not set") + return token + + +@pytest.fixture +def vercel_team_id() -> str: + """Get Vercel team ID from environment.""" + team_id = os.getenv("VERCEL_TEAM_ID") + if not team_id: + pytest.skip("VERCEL_TEAM_ID environment variable not set") + return team_id + + +@pytest.fixture +def blob_token() -> str: + """Get Blob storage token from environment.""" + token = os.getenv("BLOB_READ_WRITE_TOKEN") + if not token: + pytest.skip("BLOB_READ_WRITE_TOKEN environment variable not set") + return token + + +@pytest.fixture +def unique_test_name() -> str: + """Generate a unique test resource name with timestamp. + + Format: vercel-py-test-{timestamp}-{uuid} + """ + timestamp = int(time.time()) + unique_id = uuid.uuid4().hex[:8] + return f"vercel-py-test-{timestamp}-{unique_id}" + + +@pytest.fixture +def unique_blob_path() -> str: + """Generate a unique blob path for testing. + + Format: test/{timestamp}-{uuid}/file.txt + """ + timestamp = int(time.time()) + unique_id = uuid.uuid4().hex[:8] + return f"test/{timestamp}-{unique_id}/file.txt" + + +class CleanupRegistry: + """Registry for tracking resources that need cleanup after tests.""" + + def __init__(self) -> None: + self._cleanups: list[tuple[str, Any]] = [] + + def register(self, resource_type: str, resource_id: Any) -> None: + """Register a resource for cleanup. 
+ + Args: + resource_type: Type of resource (e.g., "blob", "project", "sandbox") + resource_id: Identifier for the resource + """ + self._cleanups.append((resource_type, resource_id)) + + def get_resources(self, resource_type: str) -> list[Any]: + """Get all registered resources of a specific type.""" + return [rid for rtype, rid in self._cleanups if rtype == resource_type] + + def clear(self) -> None: + """Clear all registered resources.""" + self._cleanups.clear() + + +@pytest.fixture +def cleanup_registry() -> Generator[CleanupRegistry, None, None]: + """Fixture providing a cleanup registry for tracking test resources. + + Usage: + def test_create_resource(cleanup_registry, blob_token): + result = put("test.txt", b"data", token=blob_token) + cleanup_registry.register("blob", result.url) + # Test continues... + # Cleanup happens automatically after test + """ + registry = CleanupRegistry() + yield registry + + # Cleanup blob resources + blob_urls = registry.get_resources("blob") + if blob_urls: + try: + from vercel.blob import delete + + blob_token = os.getenv("BLOB_READ_WRITE_TOKEN") + if blob_token: + for url in blob_urls: + try: + delete(url, token=blob_token) + except Exception: + pass # Best effort cleanup + except ImportError: + pass + + # Cleanup project resources + project_ids = registry.get_resources("project") + if project_ids: + try: + from vercel.projects import delete_project + + vercel_token = os.getenv("VERCEL_TOKEN") + team_id = os.getenv("VERCEL_TEAM_ID") + if vercel_token and team_id: + for project_id in project_ids: + try: + delete_project(project_id, token=vercel_token, team_id=team_id) + except Exception: + pass # Best effort cleanup + except ImportError: + pass + + # Cleanup sandbox resources + sandbox_ids = registry.get_resources("sandbox") + if sandbox_ids: + try: + from vercel.sandbox import Sandbox + + vercel_token = os.getenv("VERCEL_TOKEN") + team_id = os.getenv("VERCEL_TEAM_ID") + if vercel_token and team_id: + for sandbox_id in 
sandbox_ids: + try: + sandbox = Sandbox.get( + sandbox_id=sandbox_id, + token=vercel_token, + team_id=team_id, + ) + sandbox.stop() + except Exception: + pass # Best effort cleanup + except ImportError: + pass + + registry.clear() diff --git a/tests/live/test_blob_live.py b/tests/live/test_blob_live.py new file mode 100644 index 0000000..ba7677a --- /dev/null +++ b/tests/live/test_blob_live.py @@ -0,0 +1,213 @@ +"""Live API tests for Vercel Blob storage. + +These tests make real API calls and require BLOB_READ_WRITE_TOKEN environment variable. +Run with: pytest tests/live/test_blob_live.py -v +""" + +import pytest + +from .conftest import requires_blob_credentials + + +@requires_blob_credentials +@pytest.mark.live +class TestBlobLive: + """Live tests for Blob API operations.""" + + def test_put_and_delete_lifecycle(self, blob_token, unique_blob_path, cleanup_registry): + """Test complete blob put -> head -> delete lifecycle.""" + from vercel.blob import delete, head, put + from vercel.blob.errors import BlobNotFoundError + + # Put a blob + result = put( + unique_blob_path, + b"Hello, World! 
This is a test blob.", + token=blob_token, + ) + cleanup_registry.register("blob", result.url) + + # Verify the result + assert result.url is not None + assert result.pathname == unique_blob_path + assert result.content_type is not None + + # Head to get metadata + meta = head(result.url, token=blob_token) + assert meta.size > 0 + assert meta.pathname == unique_blob_path + + # Delete the blob + delete(result.url, token=blob_token) + + # Verify deletion - head should raise BlobNotFoundError + with pytest.raises(BlobNotFoundError): + head(result.url, token=blob_token) + + @pytest.mark.asyncio + async def test_put_and_delete_async(self, blob_token, unique_blob_path, cleanup_registry): + """Test async blob put -> head -> delete lifecycle.""" + from vercel.blob import delete_async, head_async, put_async + from vercel.blob.errors import BlobNotFoundError + + # Put a blob + result = await put_async( + unique_blob_path, + b"Hello, World! This is an async test blob.", + token=blob_token, + ) + cleanup_registry.register("blob", result.url) + + # Verify the result + assert result.url is not None + assert result.pathname == unique_blob_path + + # Head to get metadata + meta = await head_async(result.url, token=blob_token) + assert meta.size > 0 + + # Delete the blob + await delete_async(result.url, token=blob_token) + + # Verify deletion + with pytest.raises(BlobNotFoundError): + await head_async(result.url, token=blob_token) + + def test_list_objects(self, blob_token, unique_blob_path, cleanup_registry): + """Test listing blobs with prefix filter.""" + from vercel.blob import delete, list_objects, put + + # Create a blob with a unique prefix + prefix = unique_blob_path.rsplit("/", 1)[0] + "/" + blob_path = f"{prefix}list-test.txt" + + result = put(blob_path, b"list test content", token=blob_token) + cleanup_registry.register("blob", result.url) + + # List objects with the prefix + listing = list_objects(prefix=prefix, token=blob_token) + + # Should find at least our blob + 
assert len(listing.blobs) >= 1 + found = any(b.pathname == blob_path for b in listing.blobs) + assert found, f"Expected to find {blob_path} in listing" + + # Cleanup + delete(result.url, token=blob_token) + + def test_blob_client_class(self, blob_token, unique_blob_path, cleanup_registry): + """Test BlobClient class-based interface.""" + from vercel.blob import BlobClient + from vercel.blob.errors import BlobNotFoundError + + client = BlobClient(token=blob_token) + + # Put using client + result = client.put(unique_blob_path, b"Client test content") + cleanup_registry.register("blob", result.url) + + # Head using client + meta = client.head(result.url) + assert meta.size > 0 + + # List using client + listing = client.list_objects(limit=5) + assert listing.blobs is not None + + # Delete using client + client.delete(result.url) + + # Verify deletion + with pytest.raises(BlobNotFoundError): + client.head(result.url) + + @pytest.mark.asyncio + async def test_async_blob_client_class(self, blob_token, unique_blob_path, cleanup_registry): + """Test AsyncBlobClient class-based interface.""" + from vercel.blob import AsyncBlobClient + from vercel.blob.errors import BlobNotFoundError + + client = AsyncBlobClient(token=blob_token) + + # Put using client + result = await client.put(unique_blob_path, b"Async client test content") + cleanup_registry.register("blob", result.url) + + # Head using client + meta = await client.head(result.url) + assert meta.size > 0 + + # Delete using client + await client.delete(result.url) + + # Verify deletion + with pytest.raises(BlobNotFoundError): + await client.head(result.url) + + def test_copy_operation(self, blob_token, unique_blob_path, cleanup_registry): + """Test server-side copy operation.""" + from vercel.blob import copy, delete, head, put + + # Create source blob + source_path = unique_blob_path + source_result = put(source_path, b"Source content for copy", token=blob_token) + cleanup_registry.register("blob", source_result.url) + + 
# Copy to new destination + dest_path = source_path.replace(".txt", "-copy.txt") + copy_result = copy(source_result.url, dest_path, token=blob_token) + cleanup_registry.register("blob", copy_result.url) + + # Verify copy exists + copy_meta = head(copy_result.url, token=blob_token) + assert copy_meta.size > 0 + assert copy_meta.pathname == dest_path + + # Cleanup + delete(source_result.url, token=blob_token) + delete(copy_result.url, token=blob_token) + + def test_create_folder(self, blob_token, unique_test_name, cleanup_registry): + """Test folder creation.""" + from vercel.blob import create_folder, delete, list_objects + + folder_path = f"test-folders/{unique_test_name}" + + result = create_folder(folder_path, token=blob_token) + cleanup_registry.register("blob", result.url) + + # Verify folder was created + assert result.pathname.endswith("/") + + # List should show the folder + listing = list_objects(prefix=f"test-folders/{unique_test_name}", token=blob_token) + # Verify the folder appears in the listing (as a folder or blob depending on mode) + folder_urls = [b.url for b in listing.blobs] + list(listing.folders or []) + assert result.url in folder_urls, f"Created folder {result.url} not found in listing" + + # Cleanup + delete(result.url, token=blob_token) + + def test_iter_objects(self, blob_token, unique_blob_path, cleanup_registry): + """Test blob iteration.""" + from vercel.blob import delete, iter_objects, put + + # Create multiple blobs + prefix = unique_blob_path.rsplit("/", 1)[0] + "/" + urls = [] + + for i in range(3): + blob_path = f"{prefix}iter-test-{i}.txt" + result = put(blob_path, f"Content {i}".encode(), token=blob_token) + cleanup_registry.register("blob", result.url) + urls.append(result.url) + + # Iterate over objects + items = list(iter_objects(prefix=prefix, token=blob_token)) + + # Should find our blobs + assert len(items) >= 3 + + # Cleanup + for url in urls: + delete(url, token=blob_token) diff --git a/tests/live/test_projects_live.py 
b/tests/live/test_projects_live.py new file mode 100644 index 0000000..dfffa0d --- /dev/null +++ b/tests/live/test_projects_live.py @@ -0,0 +1,374 @@ +""" +Live API tests for Vercel Projects module. + +These tests make actual API calls to Vercel and validate the real responses. +They require VERCEL_TOKEN and VERCEL_TEAM_ID environment variables. + +Run with: pytest tests/live/test_projects_live.py -v +""" + +import time + +import pytest + +# Import the actual functions (not mocked) +from vercel.projects import create_project, delete_project, get_projects, update_project +from vercel.projects.projects import ( + create_project_async, + delete_project_async, + get_projects_async, +) + +from .conftest import requires_vercel_credentials + + +@requires_vercel_credentials +@pytest.mark.live +class TestProjectsLive: + """Test suite for Projects API using real Vercel API calls.""" + + def test_get_projects_real_api(self, vercel_token, vercel_team_id): + """Test get_projects with real API and validate actual response structure.""" + result = get_projects(token=vercel_token, team_id=vercel_team_id) + + # Validate response is a dict + assert isinstance(result, dict), f"Expected dict, got {type(result)}" + + # Validate top-level structure + assert "projects" in result, "Response missing 'projects' key" + assert "pagination" in result, "Response missing 'pagination' key" + + # Validate projects array + projects = result["projects"] + assert isinstance(projects, list), f"Expected list, got {type(projects)}" + + # Validate first project structure if projects exist + if len(projects) > 0: + project = projects[0] + assert isinstance(project, dict), f"Expected dict, got {type(project)}" + + # Validate core required fields exist + required_fields = ["id", "name", "accountId", "createdAt", "updatedAt"] + for field in required_fields: + assert field in project, f"Missing required field: {field}" + + # Validate data types + assert isinstance(project["id"], str), f"Expected string, got 
{type(project['id'])}" + assert isinstance(project["name"], str), f"Expected string, got {type(project['name'])}" + assert isinstance(project["accountId"], str), ( + f"Expected string, got {type(project['accountId'])}" + ) + assert isinstance(project["createdAt"], int), ( + f"Expected int, got {type(project['createdAt'])}" + ) + assert isinstance(project["updatedAt"], int), ( + f"Expected int, got {type(project['updatedAt'])}" + ) + + # Validate ID formats + assert project["id"].startswith("prj_"), ( + f"Project ID should start with 'prj_', got: {project['id']}" + ) + assert project["accountId"].startswith("team_"), ( + f"Account ID should start with 'team_', got: {project['accountId']}" + ) + + # Validate timestamps are reasonable (after 2020) + assert project["createdAt"] > 1577836800000, ( + f"Created timestamp too old: {project['createdAt']}" + ) + assert project["updatedAt"] > 1577836800000, ( + f"Updated timestamp too old: {project['updatedAt']}" + ) + + # Validate pagination structure + pagination = result["pagination"] + assert isinstance(pagination, dict), f"Expected dict, got {type(pagination)}" + assert "count" in pagination, "Pagination missing 'count'" + assert "next" in pagination, "Pagination missing 'next'" + assert "prev" in pagination, "Pagination missing 'prev'" + assert isinstance(pagination["count"], int), ( + f"Expected int, got {type(pagination['count'])}" + ) + + def test_create_project_real_api( + self, vercel_token, vercel_team_id, unique_test_name, cleanup_registry + ): + """Test create_project with real API and validate actual response.""" + project_body = {"name": unique_test_name, "framework": "nextjs"} + + result = create_project(body=project_body, token=vercel_token, team_id=vercel_team_id) + project_id = result["id"] + cleanup_registry.register("project", project_id) + + try: + # Validate response structure + assert isinstance(result, dict), f"Expected dict, got {type(result)}" + + # Validate core fields + assert "id" in result, 
"Response missing 'id'" + assert "name" in result, "Response missing 'name'" + assert "accountId" in result, "Response missing 'accountId'" + assert "createdAt" in result, "Response missing 'createdAt'" + assert "updatedAt" in result, "Response missing 'updatedAt'" + + # Validate data types + assert isinstance(result["id"], str), f"Expected string, got {type(result['id'])}" + assert isinstance(result["name"], str), f"Expected string, got {type(result['name'])}" + assert isinstance(result["accountId"], str), ( + f"Expected string, got {type(result['accountId'])}" + ) + assert isinstance(result["createdAt"], int), ( + f"Expected int, got {type(result['createdAt'])}" + ) + assert isinstance(result["updatedAt"], int), ( + f"Expected int, got {type(result['updatedAt'])}" + ) + + # Validate values match what we sent + assert result["name"] == unique_test_name, ( + f"Expected {unique_test_name}, got {result['name']}" + ) + assert result["accountId"] == vercel_team_id, ( + f"Expected {vercel_team_id}, got {result['accountId']}" + ) + + # Validate ID format + assert result["id"].startswith("prj_"), ( + f"Project ID should start with 'prj_', got: {result['id']}" + ) + + # Validate timestamps are recent (within last minute) + current_time = int(time.time() * 1000) + assert result["createdAt"] > current_time - 60000, ( + f"Created timestamp too old: {result['createdAt']}" + ) + assert result["updatedAt"] > current_time - 60000, ( + f"Updated timestamp too old: {result['updatedAt']}" + ) + + finally: + # Clean up - delete the project + delete_project(project_id, token=vercel_token, team_id=vercel_team_id) + + def test_update_project_real_api( + self, vercel_token, vercel_team_id, unique_test_name, cleanup_registry + ): + """Test update_project with real API and validate actual response.""" + # First create a project + project_body = {"name": unique_test_name, "framework": "nextjs"} + created_project = create_project( + body=project_body, token=vercel_token, 
    # NOTE(review): the opening lines of this method fall outside the visible
    # chunk; the signature, docstring and create_project call are reconstructed
    # from the parallel tests below — confirm against the full file.
    def test_update_project_real_api(
        self, vercel_token, vercel_team_id, unique_test_name, cleanup_registry
    ):
        """Test update_project with real API."""
        project_body = {"name": unique_test_name, "framework": "nextjs"}
        created_project = create_project(
            body=project_body, token=vercel_token, team_id=vercel_team_id
        )
        project_id = created_project["id"]
        # Registered so the session-level cleanup can reap it if the test dies
        # before the finally-block delete runs.
        cleanup_registry.register("project", project_id)

        try:
            # Update the project
            update_body = {"framework": "svelte"}
            result = update_project(
                project_id, body=update_body, token=vercel_token, team_id=vercel_team_id
            )

            # Validate response structure
            assert isinstance(result, dict), f"Expected dict, got {type(result)}"

            # Validate core fields
            assert "id" in result, "Response missing 'id'"
            assert "name" in result, "Response missing 'name'"
            assert "accountId" in result, "Response missing 'accountId'"
            assert "updatedAt" in result, "Response missing 'updatedAt'"

            # Validate values
            assert result["id"] == project_id, f"Expected {project_id}, got {result['id']}"
            assert result["name"] == unique_test_name, (
                f"Expected {unique_test_name}, got {result['name']}"
            )
            assert result["accountId"] == vercel_team_id, (
                f"Expected {vercel_team_id}, got {result['accountId']}"
            )

            # Validate updatedAt is newer than createdAt
            assert result["updatedAt"] >= created_project["createdAt"], (
                "UpdatedAt should be >= createdAt"
            )

        finally:
            # Clean up - delete the project
            delete_project(project_id, token=vercel_token, team_id=vercel_team_id)

    def test_delete_project_real_api(
        self, vercel_token, vercel_team_id, unique_test_name, cleanup_registry
    ):
        """Test delete_project with real API."""
        # First create a project
        project_body = {"name": unique_test_name, "framework": "nextjs"}
        created_project = create_project(
            body=project_body, token=vercel_token, team_id=vercel_team_id
        )
        project_id = created_project["id"]

        # Delete the project (don't register for cleanup since we're deleting immediately)
        delete_project(project_id, token=vercel_token, team_id=vercel_team_id)

        # Verify project is actually deleted by checking it's not in the list
        try:
            projects_list = get_projects(token=vercel_token, team_id=vercel_team_id)
            project_ids = [p["id"] for p in projects_list["projects"]]
            assert project_id not in project_ids, (
                f"Project {project_id} still exists after deletion"
            )
        # NOTE(review): this also catches the AssertionError above, so a failed
        # assertion is reported via pytest.fail with the same message — confirm
        # that is intentional.
        except Exception as e:
            pytest.fail(f"Failed to verify project deletion: {e}")

    @pytest.mark.asyncio
    async def test_get_projects_async_real_api(self, vercel_token, vercel_team_id):
        """Test get_projects_async with real API and validate actual response."""
        result = await get_projects_async(token=vercel_token, team_id=vercel_team_id)

        # Same validation as sync version
        assert isinstance(result, dict), f"Expected dict, got {type(result)}"
        assert "projects" in result, "Response missing 'projects' key"
        assert "pagination" in result, "Response missing 'pagination' key"

        projects = result["projects"]
        assert isinstance(projects, list), f"Expected list, got {type(projects)}"

        # Validate first project structure if projects exist
        if len(projects) > 0:
            project = projects[0]
            assert isinstance(project, dict), f"Expected dict, got {type(project)}"

            # Validate core fields
            required_fields = ["id", "name", "accountId", "createdAt", "updatedAt"]
            for field in required_fields:
                assert field in project, f"Missing required field: {field}"

            # Validate data types (timestamps are epoch milliseconds as ints)
            assert isinstance(project["id"], str), f"Expected string, got {type(project['id'])}"
            assert isinstance(project["name"], str), f"Expected string, got {type(project['name'])}"
            assert isinstance(project["accountId"], str), (
                f"Expected string, got {type(project['accountId'])}"
            )
            assert isinstance(project["createdAt"], int), (
                f"Expected int, got {type(project['createdAt'])}"
            )
            assert isinstance(project["updatedAt"], int), (
                f"Expected int, got {type(project['updatedAt'])}"
            )

            # Validate ID formats
            assert project["id"].startswith("prj_"), (
                f"Project ID should start with 'prj_', got: {project['id']}"
            )
            assert project["accountId"].startswith("team_"), (
                f"Account ID should start with 'team_', got: {project['accountId']}"
            )

    @pytest.mark.asyncio
    async def test_create_project_async_real_api(
        self, vercel_token, vercel_team_id, unique_test_name, cleanup_registry
    ):
        """Test create_project_async with real API and validate actual response."""
        project_body = {"name": f"{unique_test_name}-async", "framework": "nextjs"}

        result = await create_project_async(
            body=project_body, token=vercel_token, team_id=vercel_team_id
        )
        cleanup_registry.register("project", result["id"])

        # Same validation as sync version
        assert isinstance(result, dict), f"Expected dict, got {type(result)}"

        # Validate core fields
        assert "id" in result, "Response missing 'id'"
        assert "name" in result, "Response missing 'name'"
        assert "accountId" in result, "Response missing 'accountId'"
        assert "createdAt" in result, "Response missing 'createdAt'"
        assert "updatedAt" in result, "Response missing 'updatedAt'"

        # Validate data types
        assert isinstance(result["id"], str), f"Expected string, got {type(result['id'])}"
        assert isinstance(result["name"], str), f"Expected string, got {type(result['name'])}"
        assert isinstance(result["accountId"], str), (
            f"Expected string, got {type(result['accountId'])}"
        )
        assert isinstance(result["createdAt"], int), (
            f"Expected int, got {type(result['createdAt'])}"
        )
        assert isinstance(result["updatedAt"], int), (
            f"Expected int, got {type(result['updatedAt'])}"
        )

        # Validate values
        assert result["name"] == f"{unique_test_name}-async", (
            f"Expected {unique_test_name}-async, got {result['name']}"
        )
        assert result["accountId"] == vercel_team_id, (
            f"Expected {vercel_team_id}, got {result['accountId']}"
        )

        # Validate ID format
        assert result["id"].startswith("prj_"), (
            f"Project ID should start with 'prj_', got: {result['id']}"
        )

        # Clean up - delete the project
        await delete_project_async(result["id"], token=vercel_token, team_id=vercel_team_id)

    def test_error_handling_real_api(self, vercel_token, vercel_team_id):
        """Test error handling with real API."""
        # Test with invalid project ID
        with pytest.raises(RuntimeError) as exc_info:
            delete_project("prj_invalid", token=vercel_token, team_id=vercel_team_id)

        error_message = str(exc_info.value)
        assert "Failed to delete project" in error_message, (
            f"Expected 'Failed to delete project' in: {error_message}"
        )

    def test_full_crud_workflow_real_api(
        self, vercel_token, vercel_team_id, unique_test_name, cleanup_registry
    ):
        """Test complete CRUD workflow with real API."""
        project_body = {"name": unique_test_name, "framework": "nextjs"}
        project_id = None

        try:
            # CREATE
            created = create_project(body=project_body, token=vercel_token, team_id=vercel_team_id)
            project_id = created["id"]
            cleanup_registry.register("project", project_id)

            # READ - verify project exists in list
            projects = get_projects(token=vercel_token, team_id=vercel_team_id)
            project_names = [p["name"] for p in projects["projects"]]
            assert unique_test_name in project_names, (
                f"Project {unique_test_name} not found in list"
            )

            # UPDATE
            update_body = {"framework": "svelte"}
            updated = update_project(
                project_id, body=update_body, token=vercel_token, team_id=vercel_team_id
            )
            assert updated["id"] == project_id, "Project ID changed after update"

            # DELETE
            delete_project(project_id, token=vercel_token, team_id=vercel_team_id)

            # VERIFY DELETION - project should not be in list anymore
            projects_after_delete = get_projects(token=vercel_token, team_id=vercel_team_id)
            project_names_after = [p["name"] for p in projects_after_delete["projects"]]
            assert unique_test_name not in project_names_after, (
                f"Project {unique_test_name} still exists after deletion"
            )

        except Exception as e:
            # Clean up on error (delete may legitimately fail if creation never
            # completed, hence the inner best-effort swallow)
            if project_id:
                try:
                    delete_project(project_id, token=vercel_token, team_id=vercel_team_id)
                except Exception:
                    pass
            raise e
"""Live API tests for Vercel Sandbox.

These tests make real API calls and require VERCEL_TOKEN and VERCEL_TEAM_ID environment variables.
Run with: pytest tests/live/test_sandbox_live.py -v
"""

import time

import pytest

from .conftest import requires_sandbox_credentials


def wait_for_sandbox_running(sandbox, timeout: float = 30.0, poll_interval: float = 0.5):
    """Wait for sandbox to reach 'running' status.

    Args:
        sandbox: The sandbox instance to wait for.
        timeout: Maximum time to wait in seconds.
        poll_interval: Time between status checks in seconds.

    Raises:
        TimeoutError: If sandbox doesn't reach 'running' status within timeout.
    """
    # Imported lazily (function scope, not module scope) so test collection
    # does not hard-require the sandbox extra; hoisted out of the loop so the
    # import statement is not re-executed on every poll iteration.
    from vercel.sandbox import Sandbox

    start = time.time()
    while time.time() - start < timeout:
        if sandbox.status == "running":
            return
        # Re-fetch sandbox to get updated status, then swap the fresh model
        # into the caller's instance and release the throwaway client.
        updated = Sandbox.get(
            sandbox_id=sandbox.sandbox_id,
            token=sandbox.client._token,
            team_id=sandbox.client._team_id,
        )
        sandbox.sandbox = updated.sandbox
        updated.client.close()
        if sandbox.status == "running":
            return
        time.sleep(poll_interval)
    raise TimeoutError(f"Sandbox did not reach 'running' status within {timeout}s")


async def wait_for_sandbox_running_async(
    sandbox, timeout: float = 30.0, poll_interval: float = 0.5
):
    """Wait for async sandbox to reach 'running' status.

    Args:
        sandbox: The async sandbox instance to wait for.
        timeout: Maximum time to wait in seconds.
        poll_interval: Time between status checks in seconds.

    Raises:
        TimeoutError: If sandbox doesn't reach 'running' status within timeout.
    """
    # Same lazy-import rationale as the sync helper; both imports hoisted out
    # of the polling loop.
    import asyncio

    from vercel.sandbox import AsyncSandbox

    start = time.time()
    while time.time() - start < timeout:
        if sandbox.status == "running":
            return
        # Re-fetch sandbox to get updated status
        updated = await AsyncSandbox.get(
            sandbox_id=sandbox.sandbox_id,
            token=sandbox.client._token,
            team_id=sandbox.client._team_id,
        )
        sandbox.sandbox = updated.sandbox
        await updated.client.aclose()
        if sandbox.status == "running":
            return
        await asyncio.sleep(poll_interval)
    raise TimeoutError(f"Sandbox did not reach 'running' status within {timeout}s")


@requires_sandbox_credentials
@pytest.mark.live
class TestSandboxLive:
    """Live tests for Sandbox API operations."""

    def test_create_run_stop_lifecycle(self, vercel_token, vercel_team_id, cleanup_registry):
        """Test complete sandbox create -> run command -> stop lifecycle."""
        from vercel.sandbox import Sandbox

        # Create sandbox
        sandbox = Sandbox.create(
            token=vercel_token,
            team_id=vercel_team_id,
        )
        cleanup_registry.register("sandbox", sandbox.sandbox_id)

        try:
            # Verify creation
            assert sandbox.sandbox_id is not None
            # Wait for sandbox to be running (may start in 'pending' state)
            wait_for_sandbox_running(sandbox)
            assert sandbox.status == "running"

            # Run a simple command
            result = sandbox.run_command("echo", ["Hello from sandbox"])

            assert result.exit_code == 0
            assert "Hello from sandbox" in result.stdout()

            # Stop the sandbox
            sandbox.stop()
        finally:
            # Ensure cleanup (stop is idempotent from the test's point of view)
            try:
                sandbox.stop()
            except Exception:
                # Sandbox may already be stopped or unreachable
                pass
            sandbox.client.close()

    @pytest.mark.asyncio
    async def test_async_sandbox_lifecycle(self, vercel_token, vercel_team_id, cleanup_registry):
        """Test async sandbox create -> run command -> stop lifecycle."""
        from vercel.sandbox import AsyncSandbox

        # Create sandbox using async context manager
        async with await AsyncSandbox.create(
            token=vercel_token,
            team_id=vercel_team_id,
        ) as sandbox:
            cleanup_registry.register("sandbox", sandbox.sandbox_id)

            # Verify creation
            assert sandbox.sandbox_id is not None
            # Wait for sandbox to be running (may start in 'pending' state)
            await wait_for_sandbox_running_async(sandbox)
            assert sandbox.status == "running"

            # Run a simple command
            result = await sandbox.run_command("echo", ["Async hello"])

            assert result.exit_code == 0
            assert "Async hello" in await result.stdout()

        # Context manager should have stopped the sandbox

    def test_file_operations(self, vercel_token, vercel_team_id, cleanup_registry):
        """Test sandbox file write and read operations."""
        from vercel.sandbox import Sandbox
        from vercel.sandbox.models import WriteFile

        sandbox = Sandbox.create(
            token=vercel_token,
            team_id=vercel_team_id,
        )
        cleanup_registry.register("sandbox", sandbox.sandbox_id)

        try:
            # Write a file
            test_content = "Hello, this is test content!"
            sandbox.write_files([WriteFile(path="/tmp/test.txt", content=test_content.encode())])

            # Read the file back
            content = sandbox.read_file("/tmp/test.txt")

            assert content is not None
            assert test_content in content.decode()

            # Read a non-existent file
            missing = sandbox.read_file("/tmp/nonexistent.txt")
            assert missing is None

        finally:
            try:
                sandbox.stop()
            except Exception:
                # Sandbox may already be stopped or unreachable
                pass
            sandbox.client.close()

    def test_run_command_with_env(self, vercel_token, vercel_team_id, cleanup_registry):
        """Test running command with environment variables."""
        from vercel.sandbox import Sandbox

        sandbox = Sandbox.create(
            token=vercel_token,
            team_id=vercel_team_id,
        )
        cleanup_registry.register("sandbox", sandbox.sandbox_id)

        try:
            # Run command with custom env
            result = sandbox.run_command(
                "sh",
                ["-c", "echo $MY_VAR"],
                env={"MY_VAR": "test_value_123"},
            )

            assert result.exit_code == 0
            assert "test_value_123" in result.stdout()

        finally:
            try:
                sandbox.stop()
            except Exception:
                # Sandbox may already be stopped or unreachable
                pass
            sandbox.client.close()

    def test_run_command_detached(self, vercel_token, vercel_team_id, cleanup_registry):
        """Test running a detached command."""
        from vercel.sandbox import Sandbox

        sandbox = Sandbox.create(
            token=vercel_token,
            team_id=vercel_team_id,
        )
        cleanup_registry.register("sandbox", sandbox.sandbox_id)

        try:
            # Run a detached command (doesn't wait for completion)
            command = sandbox.run_command_detached("sleep", ["1"])

            assert command.cmd_id is not None

            # Wait for it to complete
            finished = command.wait()
            assert finished.exit_code == 0

        finally:
            try:
                sandbox.stop()
            except Exception:
                # Sandbox may already be stopped or unreachable
                pass
            sandbox.client.close()

    def test_context_manager(self, vercel_token, vercel_team_id, cleanup_registry):
        """Test sandbox context manager cleanup."""
        from vercel.sandbox import Sandbox

        with Sandbox.create(
            token=vercel_token,
            team_id=vercel_team_id,
        ) as sandbox:
            cleanup_registry.register("sandbox", sandbox.sandbox_id)

            # Run a command inside context
            result = sandbox.run_command("whoami")
            assert result.exit_code == 0

        # Context manager should have stopped the sandbox

    def test_get_existing_sandbox(self, vercel_token, vercel_team_id, cleanup_registry):
        """Test getting an existing sandbox by ID."""
        from vercel.sandbox import Sandbox

        # Create a sandbox
        original = Sandbox.create(
            token=vercel_token,
            team_id=vercel_team_id,
        )
        cleanup_registry.register("sandbox", original.sandbox_id)

        try:
            # Wait for sandbox to be running before fetching
            wait_for_sandbox_running(original)

            # Get the same sandbox by ID
            fetched = Sandbox.get(
                sandbox_id=original.sandbox_id,
                token=vercel_token,
                team_id=vercel_team_id,
            )

            assert fetched.sandbox_id == original.sandbox_id
            assert fetched.status == "running"

            fetched.client.close()
        finally:
            try:
                original.stop()
            except Exception:
                # Sandbox may already be stopped or unreachable
                pass
            original.client.close()

    def test_mk_dir(self, vercel_token, vercel_team_id, cleanup_registry):
        """Test creating a directory in the sandbox."""
        from vercel.sandbox import Sandbox

        sandbox = Sandbox.create(
            token=vercel_token,
            team_id=vercel_team_id,
        )
        cleanup_registry.register("sandbox", sandbox.sandbox_id)

        try:
            # Create a directory
            sandbox.mk_dir("/tmp/test-dir")

            # Verify it exists by running ls
            result = sandbox.run_command("ls", ["-la", "/tmp/test-dir"])
            assert result.exit_code == 0

        finally:
            try:
                sandbox.stop()
            except Exception:
                # Sandbox may already be stopped or unreachable
                pass
            sandbox.client.close()
# ---------------------------------------------------------------------------
# parse_last_modified — pure logic
# ---------------------------------------------------------------------------
class TestParseLastModified:
    """Unit tests for the shared parse_last_modified header parser."""

    def test_rfc7231_date(self):
        # Classic HTTP-date (RFC 7231) must parse to an aware UTC datetime.
        dt = parse_last_modified("Tue, 15 Nov 1994 08:12:31 GMT")
        assert dt == datetime(1994, 11, 15, 8, 12, 31, tzinfo=timezone.utc)

    def test_iso8601_date(self):
        # ISO-8601 fallback path.
        # NOTE(review): the source chunk elides unchanged lines between hunks
        # here — the original test may carry further assertions; confirm
        # against the repository.
        dt = parse_last_modified("2024-01-15T10:30:00+00:00")
        assert dt.year == 2024
        assert dt.month == 1
        assert dt.day == 15

    def test_none_returns_approx_now(self):
        # A missing header falls back to "now"; bracket the call so the
        # assertion tolerates clock ticks during the test.
        before = datetime.now(tz=timezone.utc)
        dt = parse_last_modified(None)
        after = datetime.now(tz=timezone.utc)
        assert before <= dt <= after

    def test_invalid_string_returns_approx_now(self):
        # An unparseable header likewise falls back to "now".
        before = datetime.now(tz=timezone.utc)
        dt = parse_last_modified("not-a-date")
        after = datetime.now(tz=timezone.utc)
        assert before <= dt <= after


# ---------------------------------------------------------------------------
# validate_access — pure logic
# ---------------------------------------------------------------------------
# NOTE(review): the TestValidateAccess class body is unchanged context elided
# from this chunk (only the hunk header naming test_invalid_raises_blob_error
# is visible); restore it from the repository.


# ---------------------------------------------------------------------------
# download_file (sync) — wrapper delegation
# ---------------------------------------------------------------------------
class TestDownloadFile:
    """Verify the sync wrapper delegates to the shared async core client."""

    def test_pathname_delegates_to_core_client(self, tmp_path):
        dest = tmp_path / "downloaded.txt"
        mock_core_client = MagicMock()
        mock_core_client.download_file = AsyncMock(return_value=str(dest))

        # Drive the wrapped coroutine to completion synchronously, mirroring
        # what the real _run_sync_blob_operation does via iter_coroutine.
        def _run(operation):
            return iter_coroutine(operation(mock_core_client))

        with patch("vercel.blob.ops._run_sync_blob_operation", side_effect=_run):
            result = download_file("my/file.txt", str(dest), token=TOKEN, access="public")

        assert result == str(dest)
        # The wrapper must forward every argument (including defaults) to the
        # async core method unchanged.
        mock_core_client.download_file.assert_awaited_once_with(
            "my/file.txt",
            str(dest),
            access="public",
            token=TOKEN,
            timeout=None,
            overwrite=True,
            create_parents=True,
            progress=None,
        )

    def test_private_access_passes_access_to_core_client(self, tmp_path):
        dest = tmp_path / "private.txt"
        mock_core_client = MagicMock()
        mock_core_client.download_file = AsyncMock(return_value=str(dest))

        def _run(operation):
            return iter_coroutine(operation(mock_core_client))

        with patch("vercel.blob.ops._run_sync_blob_operation", side_effect=_run):
            download_file("my/secret.txt", str(dest), token=TOKEN, access="private")

        kwargs = mock_core_client.download_file.await_args.kwargs
        assert kwargs["access"] == "private"
        assert kwargs["token"] == TOKEN


# ---------------------------------------------------------------------------
# download_file_async — wrapper delegation
# ---------------------------------------------------------------------------
class TestDownloadFileAsync:
    """Verify the async wrapper delegates to AsyncBlobOpsClient."""

    async def test_pathname_delegates_to_core_client(self, tmp_path):
        dest = tmp_path / "downloaded_async.txt"
        mock_core_client = MagicMock()
        mock_core_client.download_file = AsyncMock(return_value=str(dest))
        # The wrapper enters the client as an async context manager, so the
        # mock must support __aenter__/__aexit__.
        mock_core_client.__aenter__ = AsyncMock(return_value=mock_core_client)
        mock_core_client.__aexit__ = AsyncMock(return_value=False)

        with patch("vercel.blob.ops.AsyncBlobOpsClient", return_value=mock_core_client):
            result = await download_file_async(
                "my/file.txt", str(dest), token=TOKEN, access="public"
            )

        assert result == str(dest)
        mock_core_client.download_file.assert_awaited_once_with(
            "my/file.txt",
            str(dest),
            access="public",
            token=TOKEN,
            timeout=None,
            overwrite=True,
            create_parents=True,
            progress=None,
        )

    async def test_private_access_passes_access_to_core_client(self, tmp_path):
        dest = tmp_path / "private_async.txt"
        mock_core_client = MagicMock()
        mock_core_client.download_file = AsyncMock(return_value=str(dest))
        mock_core_client.__aenter__ = AsyncMock(return_value=mock_core_client)
        mock_core_client.__aexit__ = AsyncMock(return_value=False)

        with patch("vercel.blob.ops.AsyncBlobOpsClient", return_value=mock_core_client):
            await download_file_async("my/secret.txt", str(dest), token=TOKEN, access="private")

        kwargs = mock_core_client.download_file.await_args.kwargs
        assert kwargs["access"] == "private"
        assert kwargs["token"] == TOKEN
+""" + +import inspect +from collections.abc import Callable +from typing import Any + + +def get_param_names(func: Callable) -> list[str]: + """Extract parameter names from a function signature.""" + sig = inspect.signature(func) + return [ + name + for name, param in sig.parameters.items() + if param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + ] + + +def get_param_defaults(func: Callable) -> dict[str, Any]: + """Extract parameter defaults from a function signature.""" + sig = inspect.signature(func) + return { + name: param.default + for name, param in sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + + +def compare_signatures(sync_func: Callable, async_func: Callable) -> list[str]: + """Compare signatures of sync and async functions. + + Returns a list of differences (empty if signatures match). + """ + differences = [] + + sync_params = get_param_names(sync_func) + async_params = get_param_names(async_func) + + if sync_params != async_params: + differences.append(f"Parameter names differ: sync={sync_params}, async={async_params}") + + sync_defaults = get_param_defaults(sync_func) + async_defaults = get_param_defaults(async_func) + + # Check that defaults match for common parameters + for name in set(sync_defaults.keys()) & set(async_defaults.keys()): + if sync_defaults[name] != async_defaults[name]: + differences.append( + f"Default for '{name}' differs: " + f"sync={sync_defaults[name]}, async={async_defaults[name]}" + ) + + return differences + + +class TestBlobSignatureParity: + """Test blob module sync/async signature parity.""" + + def test_put_signatures_match(self): + """Test put and put_async have matching signatures.""" + from vercel.blob import put, put_async + + differences = compare_signatures(put, put_async) + assert not differences, f"Signature differences: {differences}" + + def test_delete_signatures_match(self): + """Test delete and delete_async have matching signatures.""" + from 
vercel.blob import delete, delete_async + + differences = compare_signatures(delete, delete_async) + assert not differences, f"Signature differences: {differences}" + + def test_head_signatures_match(self): + """Test head and head_async have matching signatures.""" + from vercel.blob import head, head_async + + differences = compare_signatures(head, head_async) + assert not differences, f"Signature differences: {differences}" + + def test_list_objects_signatures_match(self): + """Test list_objects and list_objects_async have matching signatures.""" + from vercel.blob import list_objects, list_objects_async + + differences = compare_signatures(list_objects, list_objects_async) + assert not differences, f"Signature differences: {differences}" + + def test_iter_objects_signatures_match(self): + """Test iter_objects and iter_objects_async have matching signatures.""" + from vercel.blob import iter_objects, iter_objects_async + + differences = compare_signatures(iter_objects, iter_objects_async) + assert not differences, f"Signature differences: {differences}" + + def test_copy_signatures_match(self): + """Test copy and copy_async have matching signatures.""" + from vercel.blob import copy, copy_async + + differences = compare_signatures(copy, copy_async) + assert not differences, f"Signature differences: {differences}" + + def test_create_folder_signatures_match(self): + """Test create_folder and create_folder_async have matching signatures.""" + from vercel.blob import create_folder, create_folder_async + + differences = compare_signatures(create_folder, create_folder_async) + assert not differences, f"Signature differences: {differences}" + + def test_upload_file_signatures_match(self): + """Test upload_file and upload_file_async have matching signatures.""" + from vercel.blob import upload_file, upload_file_async + + differences = compare_signatures(upload_file, upload_file_async) + assert not differences, f"Signature differences: {differences}" + + def 
test_download_file_signatures_match(self): + """Test download_file and download_file_async have matching signatures.""" + from vercel.blob import download_file, download_file_async + + differences = compare_signatures(download_file, download_file_async) + assert not differences, f"Signature differences: {differences}" + + +class TestBlobMultipartSignatureParity: + """Test blob multipart sync/async signature parity.""" + + def test_create_multipart_upload_signatures_match(self): + """Test create_multipart_upload signatures match.""" + from vercel.blob import create_multipart_upload, create_multipart_upload_async + + differences = compare_signatures(create_multipart_upload, create_multipart_upload_async) + assert not differences, f"Signature differences: {differences}" + + def test_upload_part_signatures_match(self): + """Test upload_part signatures match.""" + from vercel.blob import upload_part, upload_part_async + + differences = compare_signatures(upload_part, upload_part_async) + assert not differences, f"Signature differences: {differences}" + + def test_complete_multipart_upload_signatures_match(self): + """Test complete_multipart_upload signatures match.""" + from vercel.blob import complete_multipart_upload, complete_multipart_upload_async + + differences = compare_signatures(complete_multipart_upload, complete_multipart_upload_async) + assert not differences, f"Signature differences: {differences}" + + def test_create_multipart_uploader_signatures_match(self): + """Test create_multipart_uploader signatures match.""" + from vercel.blob import create_multipart_uploader, create_multipart_uploader_async + + differences = compare_signatures(create_multipart_uploader, create_multipart_uploader_async) + assert not differences, f"Signature differences: {differences}" + + +class TestBlobClientClassParity: + """Test BlobClient and AsyncBlobClient method parity.""" + + def test_client_methods_exist(self): + """Test that both client classes have the same methods.""" + 
from vercel.blob import AsyncBlobClient, BlobClient + + sync_methods = { + m for m in dir(BlobClient) if not m.startswith("_") and callable(getattr(BlobClient, m)) + } + async_methods = { + m + for m in dir(AsyncBlobClient) + if not m.startswith("_") and callable(getattr(AsyncBlobClient, m)) + } + + # Lifecycle naming intentionally differs by runtime. + assert "close" in sync_methods + assert "aclose" in async_methods + sync_methods.discard("close") + async_methods.discard("aclose") + + assert sync_methods == async_methods, ( + f"Method mismatch: sync_only={sync_methods - async_methods}, " + f"async_only={async_methods - sync_methods}" + ) + + +class TestSandboxClassParity: + """Test Sandbox and AsyncSandbox method parity.""" + + def test_sandbox_methods_exist(self): + """Test that both sandbox classes have equivalent methods.""" + from vercel.sandbox import AsyncSandbox, Sandbox + + # Get public methods (excluding dunder methods) + sync_methods = { + m for m in dir(Sandbox) if not m.startswith("_") and callable(getattr(Sandbox, m)) + } + async_methods = { + m + for m in dir(AsyncSandbox) + if not m.startswith("_") and callable(getattr(AsyncSandbox, m)) + } + + # AsyncSandbox has 'shell' method that Sandbox doesn't have (interactive only) + # So we check that all sync methods exist in async + missing_in_async = sync_methods - async_methods + assert not missing_in_async, f"Methods missing in AsyncSandbox: {missing_in_async}" + + +class TestCacheClassParity: + """Test RuntimeCache and AsyncRuntimeCache method parity.""" + + def test_cache_methods_exist(self): + """Test that cache classes have equivalent methods.""" + from vercel.cache import AsyncRuntimeCache, RuntimeCache + + # Core methods that should exist in both + expected_methods = {"get", "set", "delete", "expire_tag"} + + sync_methods = { + m + for m in dir(RuntimeCache) + if not m.startswith("_") and callable(getattr(RuntimeCache, m)) + } + async_methods = { + m + for m in dir(AsyncRuntimeCache) + if not 
m.startswith("_") and callable(getattr(AsyncRuntimeCache, m)) + } + + assert expected_methods.issubset(sync_methods), ( + f"Missing sync methods: {expected_methods - sync_methods}" + ) + assert expected_methods.issubset(async_methods), ( + f"Missing async methods: {expected_methods - async_methods}" + ) + + +class TestProjectsSignatureParity: + """Test projects module sync/async signature parity.""" + + def test_get_projects_signatures_match(self): + """Test get_projects and get_projects_async have matching signatures.""" + from vercel.projects import get_projects + from vercel.projects.projects import get_projects_async + + differences = compare_signatures(get_projects, get_projects_async) + assert not differences, f"Signature differences: {differences}" + + def test_create_project_signatures_match(self): + """Test create_project and create_project_async have matching signatures.""" + from vercel.projects import create_project + from vercel.projects.projects import create_project_async + + differences = compare_signatures(create_project, create_project_async) + assert not differences, f"Signature differences: {differences}" + + def test_update_project_signatures_match(self): + """Test update_project and update_project_async have matching signatures.""" + from vercel.projects import update_project + from vercel.projects.projects import update_project_async + + differences = compare_signatures(update_project, update_project_async) + assert not differences, f"Signature differences: {differences}" + + def test_delete_project_signatures_match(self): + """Test delete_project and delete_project_async have matching signatures.""" + from vercel.projects import delete_project + from vercel.projects.projects import delete_project_async + + differences = compare_signatures(delete_project, delete_project_async) + assert not differences, f"Signature differences: {differences}" + + +class TestResultTypeParity: + """Test that sync and async functions return the same result 
types.""" + + def test_blob_put_returns_same_type(self): + """Test put and put_async return the same result type.""" + from vercel.blob import put, put_async + from vercel.blob.types import PutBlobResult + + sync_annotation = inspect.signature(put).return_annotation + async_annotation = inspect.signature(put_async).return_annotation + + # Sync should return PutBlobResult directly + assert sync_annotation == PutBlobResult or "PutBlobResult" in str(sync_annotation) + + # Async should return Coroutine[..., PutBlobResult] - verify inner type matches + async_str = str(async_annotation) + assert "PutBlobResult" in async_str, f"Async should return PutBlobResult, got {async_str}" + + def test_blob_head_returns_same_type(self): + """Test head and head_async return the same result type.""" + from vercel.blob import head, head_async + from vercel.blob.types import HeadBlobResult + + sync_annotation = inspect.signature(head).return_annotation + async_annotation = inspect.signature(head_async).return_annotation + + # Sync should return HeadBlobResult directly + assert sync_annotation == HeadBlobResult or "HeadBlobResult" in str(sync_annotation) + + # Async should return Coroutine[..., HeadBlobResult] - verify inner type matches + async_str = str(async_annotation) + assert "HeadBlobResult" in async_str, f"Async should return HeadBlobResult, got {async_str}" + + def test_blob_list_returns_same_type(self): + """Test list_objects and list_objects_async return the same result type.""" + from vercel.blob import list_objects, list_objects_async + from vercel.blob.types import ListBlobResult + + sync_annotation = inspect.signature(list_objects).return_annotation + async_annotation = inspect.signature(list_objects_async).return_annotation + + # Sync should return ListBlobResult directly + assert sync_annotation == ListBlobResult or "ListBlobResult" in str(sync_annotation) + + # Async should return Coroutine[..., ListBlobResult] - verify inner type matches + async_str = 
str(async_annotation) + assert "ListBlobResult" in async_str, f"Async should return ListBlobResult, got {async_str}" + + def test_blob_iter_returns_iterator_types(self): + """Test iter_objects and iter_objects_async expose iterator return types.""" + from vercel.blob import iter_objects, iter_objects_async + + sync_annotation = inspect.signature(iter_objects).return_annotation + async_annotation = inspect.signature(iter_objects_async).return_annotation + + assert "Iterator" in str(sync_annotation), ( + f"Sync should return Iterator, got {sync_annotation}" + ) + assert "AsyncIterator" in str(async_annotation), ( + f"Async should return AsyncIterator, got {async_annotation}" + ) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_blob_client_lifecycle.py b/tests/unit/test_blob_client_lifecycle.py new file mode 100644 index 0000000..edb6b4c --- /dev/null +++ b/tests/unit/test_blob_client_lifecycle.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from vercel.blob.client import AsyncBlobClient, BlobClient +from vercel.blob.errors import BlobError +from vercel.blob.types import HeadBlobResult, ListBlobResult + + +def _head_result() -> HeadBlobResult: + return HeadBlobResult( + size=1, + uploaded_at=datetime.now(timezone.utc), + pathname="file.txt", + content_type="text/plain", + content_disposition="inline", + url="https://blob.vercel-storage.com/file.txt", + download_url="https://blob.vercel-storage.com/file.txt?download=1", + cache_control="public, max-age=3600", + ) + + +def _list_result() -> ListBlobResult: + return ListBlobResult(blobs=[], cursor=None, has_more=False) + + +class TestBlobClientLifecycle: + def test_sync_client_reuses_owned_ops_client(self) -> None: + mock_ops_client = MagicMock() + mock_ops_client.head_blob = 
AsyncMock(return_value=_head_result()) + mock_ops_client.list_objects = MagicMock(return_value=_list_result()) + + with patch("vercel.blob.client.SyncBlobOpsClient", return_value=mock_ops_client) as ctor: + client = BlobClient(token="test_token") + client.head("file.txt") + client.list_objects() + + assert ctor.call_count == 1 + mock_ops_client.head_blob.assert_awaited_once() + mock_ops_client.list_objects.assert_called_once() + + def test_sync_close_is_idempotent_and_blocks_use_after_close(self) -> None: + mock_ops_client = MagicMock() + mock_ops_client.head_blob = AsyncMock(return_value=_head_result()) + + with patch("vercel.blob.client.SyncBlobOpsClient", return_value=mock_ops_client): + client = BlobClient(token="test_token") + client.close() + client.close() + + with pytest.raises(BlobError, match="Client is closed"): + client.head("file.txt") + + mock_ops_client.close.assert_called_once() + mock_ops_client.head_blob.assert_not_called() + + def test_sync_client_multipart_uploader_uses_ownedrequest_api(self) -> None: + actions: list[str] = [] + + async def request_api(**kwargs): + action = kwargs["headers"]["x-mpu-action"] + actions.append(action) + if action == "create": + return {"uploadId": "upload-id", "key": "blob-key"} + if action == "upload": + return {"etag": "etag-1"} + return { + "url": "https://blob.vercel-storage.com/test-abc123/folder/client-mpu.bin", + "downloadUrl": ( + "https://blob.vercel-storage.com/test-abc123/folder/client-mpu.bin?download=1" + ), + "pathname": "folder/client-mpu.bin", + "contentType": "application/octet-stream", + "contentDisposition": 'inline; filename="client-mpu.bin"', + } + + mock_request_client = MagicMock() + mock_request_client.request_api = AsyncMock(side_effect=request_api) + mock_ops_client = MagicMock() + mock_ops_client._request_client = mock_request_client + + with patch("vercel.blob.client.SyncBlobOpsClient", return_value=mock_ops_client): + client = BlobClient(token="test_token") + uploader = 
client.create_multipart_uploader("folder/client-mpu.bin") + part = uploader.upload_part(1, b"chunk") + result = uploader.complete([part]) + + assert actions == ["create", "upload", "complete"] + assert mock_request_client.request_api.await_count == 3 + assert result.pathname == "folder/client-mpu.bin" + + @pytest.mark.asyncio + async def test_async_client_reuses_owned_ops_client(self) -> None: + mock_ops_client = MagicMock() + mock_ops_client.head_blob = AsyncMock(return_value=_head_result()) + mock_ops_client.list_objects = AsyncMock(return_value=_list_result()) + + with patch("vercel.blob.client.AsyncBlobOpsClient", return_value=mock_ops_client) as ctor: + client = AsyncBlobClient(token="test_token") + await client.head("file.txt") + await client.list_objects() + + assert ctor.call_count == 1 + mock_ops_client.head_blob.assert_awaited_once() + mock_ops_client.list_objects.assert_awaited_once() + + @pytest.mark.asyncio + async def test_async_close_is_idempotent_and_blocks_use_after_close(self) -> None: + mock_ops_client = MagicMock() + mock_ops_client.aclose = AsyncMock() + mock_ops_client.head_blob = AsyncMock(return_value=_head_result()) + + with patch("vercel.blob.client.AsyncBlobOpsClient", return_value=mock_ops_client): + client = AsyncBlobClient(token="test_token") + await client.aclose() + await client.aclose() + + with pytest.raises(BlobError, match="Client is closed"): + await client.head("file.txt") + + mock_ops_client.aclose.assert_awaited_once() + mock_ops_client.head_blob.assert_not_called() + + @pytest.mark.asyncio + async def test_async_client_multipart_uploader_uses_ownedrequest_api(self) -> None: + actions: list[str] = [] + + async def request_api(**kwargs): + action = kwargs["headers"]["x-mpu-action"] + actions.append(action) + if action == "create": + return {"uploadId": "upload-id", "key": "blob-key"} + if action == "upload": + return {"etag": "etag-1"} + return { + "url": 
"https://blob.vercel-storage.com/test-abc123/folder/client-mpu-async.bin", + "downloadUrl": ( + "https://blob.vercel-storage.com/test-abc123/folder/" + "client-mpu-async.bin?download=1" + ), + "pathname": "folder/client-mpu-async.bin", + "contentType": "application/octet-stream", + "contentDisposition": 'inline; filename="client-mpu-async.bin"', + } + + mock_request_client = MagicMock() + mock_request_client.request_api = AsyncMock(side_effect=request_api) + mock_ops_client = MagicMock() + mock_ops_client._request_client = mock_request_client + + with patch("vercel.blob.client.AsyncBlobOpsClient", return_value=mock_ops_client): + client = AsyncBlobClient(token="test_token") + uploader = await client.create_multipart_uploader("folder/client-mpu-async.bin") + part = await uploader.upload_part(1, b"chunk") + result = await uploader.complete([part]) + + assert actions == ["create", "upload", "complete"] + assert mock_request_client.request_api.await_count == 3 + assert result.pathname == "folder/client-mpu-async.bin" diff --git a/tests/unit/test_client_instantiation.py b/tests/unit/test_client_instantiation.py new file mode 100644 index 0000000..54a79e8 --- /dev/null +++ b/tests/unit/test_client_instantiation.py @@ -0,0 +1,101 @@ +""" +Unit tests for client class instantiation. + +These tests verify that client classes can be instantiated without errors. 
+""" + +import os +from unittest.mock import patch + +import pytest + + +class TestClientInstantiation: + """Test that all client classes can be instantiated.""" + + @pytest.fixture + def mock_env_token(self): + """Provide a mock token via environment variable.""" + with patch.dict(os.environ, {"VERCEL_TOKEN": "test_token"}): + yield + + def test_projects_client_instantiation(self, mock_env_token): + """Test ProjectsClient can be instantiated.""" + from vercel.projects.client import ProjectsClient + + client = ProjectsClient() + assert client is not None + assert hasattr(client, "_access_token") + assert hasattr(client, "_base_url") + assert hasattr(client, "_timeout") + + def test_projects_client_with_token(self): + """Test ProjectsClient can be instantiated with explicit token.""" + from vercel.projects.client import ProjectsClient + + client = ProjectsClient(access_token="explicit_token") + assert client is not None + assert client._access_token == "explicit_token" + + def test_async_projects_client_instantiation(self, mock_env_token): + """Test AsyncProjectsClient can be instantiated.""" + from vercel.projects.client import AsyncProjectsClient + + client = AsyncProjectsClient() + assert client is not None + assert hasattr(client, "_access_token") + assert hasattr(client, "_base_url") + assert hasattr(client, "_timeout") + + def test_deployments_client_instantiation(self, mock_env_token): + """Test DeploymentsClient can be instantiated.""" + from vercel.deployments.client import DeploymentsClient + + client = DeploymentsClient() + assert client is not None + assert hasattr(client, "_access_token") + assert hasattr(client, "_base_url") + assert hasattr(client, "_timeout") + + def test_deployments_client_with_token(self): + """Test DeploymentsClient can be instantiated with explicit token.""" + from vercel.deployments.client import DeploymentsClient + + client = DeploymentsClient(access_token="explicit_token") + assert client is not None + assert 
client._access_token == "explicit_token" + + def test_async_deployments_client_instantiation(self, mock_env_token): + """Test AsyncDeploymentsClient can be instantiated.""" + from vercel.deployments.client import AsyncDeploymentsClient + + client = AsyncDeploymentsClient() + assert client is not None + assert hasattr(client, "_access_token") + assert hasattr(client, "_base_url") + assert hasattr(client, "_timeout") + + def test_build_cache_instantiation(self): + """Test BuildCache can be instantiated.""" + from vercel.cache.cache_build import BuildCache + + client = BuildCache( + endpoint="https://cache.example.com", + headers={"Authorization": "Bearer test"}, + ) + assert client is not None + assert hasattr(client, "_endpoint") + assert hasattr(client, "_headers") + assert hasattr(client, "_client") + + def test_async_build_cache_instantiation(self): + """Test AsyncBuildCache can be instantiated.""" + from vercel.cache.cache_build import AsyncBuildCache + + client = AsyncBuildCache( + endpoint="https://cache.example.com", + headers={"Authorization": "Bearer test"}, + ) + assert client is not None + assert hasattr(client, "_endpoint") + assert hasattr(client, "_headers") diff --git a/tests/unit/test_iter_coroutine.py b/tests/unit/test_iter_coroutine.py new file mode 100644 index 0000000..409e768 --- /dev/null +++ b/tests/unit/test_iter_coroutine.py @@ -0,0 +1,39 @@ +import pytest + +from vercel._internal.iter_coroutine import iter_coroutine + + +class _SuspendingAwaitable: + def __await__(self): + yield None + return None + + +def test_iter_coroutine_returns_result_and_closes_coroutine() -> None: + closed = False + + async def coro() -> str: + nonlocal closed + try: + return "ok" + finally: + closed = True + + assert iter_coroutine(coro()) == "ok" + assert closed + + +def test_iter_coroutine_raises_on_suspending_coroutine_and_closes_coroutine() -> None: + closed = False + + async def coro() -> None: + nonlocal closed + try: + await _SuspendingAwaitable() + finally: 
+ closed = True + + with pytest.raises(RuntimeError, match="did not stop after one iteration"): + iter_coroutine(coro()) + + assert closed