diff --git a/README.md b/README.md
index fbc8526..86b0380 100644
--- a/README.md
+++ b/README.md
@@ -503,6 +503,369 @@ replicate = Client(
 > Never hardcode authentication credentials like API tokens into your code.
 > Instead, pass them as environment variables when running your program.
 
+## Experimental `use()` interface
+
+Version 1.0.8 and later of the `replicate` package include a new experimental `use()` function that is intended to make running a model feel closer to calling a function than making an API request.
+
+Some key differences from `replicate.run()`:
+
+ 1. You "import" the model using the `use()` syntax; after that you call the model like a function.
+ 2. The output type matches the model definition.
+ 3. Built-in support for streaming for all models.
+ 4. File outputs will be represented as `PathLike` objects and downloaded to disk when used*.
+
+> [!NOTE]
+> \* We've replaced the `FileOutput` implementation with `Path` objects. However, to avoid downloading files before they are needed, we've implemented a `URLPath` class that defers the download until the first time the object is used. If you need the underlying URL of the `Path` object, you can use the `get_path_url(path: Path) -> str` helper.
+
+### Examples
+
+To use a model:
+
+> [!IMPORTANT]
+> For now `use()` MUST be called in the top-level module scope. We may relax this in the future.
+
+```py
+import replicate
+
+flux_dev = replicate.use("black-forest-labs/flux-dev")
+outputs = flux_dev(prompt="a cat wearing an amusing hat")
+
+for output in outputs:
+    print(output) # Path(/tmp/output.webp)
+```
+
+Models that implement iterators will return the output of the completed run as a list unless explicitly streaming (see the Streaming section below). Language models that define `x-cog-array-display: concatenate` will return strings:
+
+```py
+claude = replicate.use("anthropic/claude-4-sonnet")
+
+output = claude(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.")
+
+print(output) # "Here's a recipe to feed all of California (about 39 million people)! ..."
+```
+
+You can pass the results of one model directly into another:
+
+```py
+import replicate
+
+flux_dev = replicate.use("black-forest-labs/flux-dev")
+claude = replicate.use("anthropic/claude-4-sonnet")
+
+images = flux_dev(prompt="a cat wearing an amusing hat")
+
+result = claude(prompt="describe this image for me", image=images[0])
+
+print(str(result)) # "This shows an image of a cat wearing a hat ..."
+```
+
+To create an individual prediction that has not yet resolved, use the `create()` method:
+
+```py
+claude = replicate.use("anthropic/claude-4-sonnet")
+
+prediction = claude.create(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.")
+
+prediction.logs() # get current logs (WIP)
+
+prediction.output() # get the output
+```
+
+### Streaming
+
+Many models, particularly large language models (LLMs), will yield partial results as the model is running. To consume outputs from these models as they run, you can pass the `streaming` argument to `use()`:
+
+```py
+claude = replicate.use("anthropic/claude-4-sonnet", streaming=True)
+
+output = claude(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.")
+
+for chunk in output:
+    print(chunk) # "Here's a recipe ", "to feed all", " of California"
+```
+
+### Downloading file outputs
+
+Output files are provided as Python [os.PathLike](https://docs.python.org/3.12/library/os.html#os.PathLike) objects. These are supported by most of the Python standard library, including `open()` and `Path`, as well as third-party libraries like `pillow` and `ffmpeg-python`.
+
+The first time the file is accessed, it will be downloaded to a temporary directory on disk, ready for use.
+
+Here's an example of how to use the `pillow` package to convert file outputs:
+
+```py
+import replicate
+from PIL import Image
+
+flux_dev = replicate.use("black-forest-labs/flux-dev")
+
+images = flux_dev(prompt="a cat wearing an amusing hat")
+for i, path in enumerate(images):
+    with Image.open(path) as img:
+        img.save(f"./output_{i}.png", format="PNG")
+```
+
+For libraries that do not support `Path` or `PathLike` instances, you can use `open()` as you would with any other file. For example, to use `requests` to upload the file to a different location:
+
+```py
+import replicate
+import requests
+
+flux_dev = replicate.use("black-forest-labs/flux-dev")
+
+images = flux_dev(prompt="a cat wearing an amusing hat")
+for path in images:
+    with open(path, "rb") as f:
+        r = requests.post("https://api.example.com/upload", files={"file": f})
+```
+
+### Accessing outputs as HTTPS URLs
+
+If you do not need to download the output to disk, you can access the underlying URL for a `Path` object returned from a model call by using the `get_path_url()` helper.
+
+```py
+import replicate
+from replicate import get_path_url
+
+flux_dev = replicate.use("black-forest-labs/flux-dev")
+outputs = flux_dev(prompt="a cat wearing an amusing hat")
+
+for output in outputs:
+    print(get_path_url(output)) # "https://replicate.delivery/xyz"
+```
+
+### Async Mode
+
+By default `use()` will return a function instance with a sync interface. You can pass `use_async=True` to have it return an `AsyncFunction` that provides an async interface.
+
+```py
+import asyncio
+import replicate
+
+async def main():
+    flux_dev = replicate.use("black-forest-labs/flux-dev", use_async=True)
+    outputs = await flux_dev(prompt="a cat wearing an amusing hat")
+
+    for output in outputs:
+        print(output)
+
+asyncio.run(main())
+```
+
+When used in streaming mode, an `AsyncIterator` will be returned.
+
+```py
+import asyncio
+import replicate
+
+async def main():
+    claude = replicate.use("anthropic/claude-3.5-haiku", streaming=True, use_async=True)
+    output = await claude(prompt="say hello")
+
+    # Stream the response as it comes in.
+    async for token in output:
+        print(token)
+
+    # Wait until the model has completed. This will return either a `list` or a `str` depending
+    # on whether the model uses AsyncIterator or AsyncConcatenateIterator. You can check this
+    # on the model schema by looking for `x-cog-array-display: concatenate`.
+    print(await output)
+
+asyncio.run(main())
+```
+
+### Typing
+
+By default `use()` knows nothing about the interface of the model. To provide a better developer experience, we provide two methods to add type annotations to the function returned by the `use()` helper.
+
+**1. Provide a function signature**
+
+The `use()` function accepts a function signature as an additional `hint` keyword argument. When provided, this signature will be used for the `model()` and `model.create()` functions.
+
+```py
+from pathlib import Path
+
+from replicate import use
+
+# Flux takes a required prompt string and optional image and seed.
+def hint(*, prompt: str, image: Path | None = None, seed: int | None = None) -> str: ...
+
+flux_dev = use("black-forest-labs/flux-dev", hint=hint)
+output1 = flux_dev() # will warn that `prompt` is missing
+output2 = flux_dev(prompt="str") # output2 will be typed as `str`
+```
+
+**2. Provide a class**
+
+The second method requires creating a callable class with a `name` field. The name will be used as the function reference when passed to `use()`.
+
+```py
+from pathlib import Path
+
+from replicate import use
+
+class FluxDev:
+    name = "black-forest-labs/flux-dev"
+
+    def __call__(self, *, prompt: str, image: Path | None = None, seed: int | None = None) -> str: ...
+
+flux_dev = use(FluxDev)
+output1 = flux_dev() # will warn that `prompt` is missing
+output2 = flux_dev(prompt="str") # output2 will be typed as `str`
+```
+
+> [!WARNING]
+> Currently the typing system doesn't correctly support the `streaming` flag for models that return lists or use iterators. We're working on improvements here.
+
+In the future we hope to provide tooling to generate and provide these models as packages to make working with them easier. For now you may wish to create your own.
+
+### API Reference
+
+The Replicate Python Library provides several key classes and functions for working with models in pipelines:
+
+#### `use()` Function
+
+Creates a callable function wrapper for a Replicate model.
+
+```py
+def use(
+    ref: FunctionRef,
+    *,
+    streaming: bool = False,
+    use_async: bool = False
+) -> Function | AsyncFunction
+
+def use(
+    ref: str,
+    *,
+    hint: Callable | None = None,
+    streaming: bool = False,
+    use_async: bool = False
+) -> Function | AsyncFunction
+```
+
+**Parameters:**
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `ref` | `str \| FunctionRef` | Required | Model reference (e.g., "owner/model" or "owner/model:version") |
+| `hint` | `Callable \| None` | `None` | Function signature for type hints |
+| `streaming` | `bool` | `False` | Return an `OutputIterator` for streaming results |
+| `use_async` | `bool` | `False` | Return an `AsyncFunction` instead of a `Function` |
+
+**Returns:**
+- `Function` - Synchronous model wrapper (default)
+- `AsyncFunction` - Asynchronous model wrapper (when `use_async=True`)
+
+#### `Function` Class
+
+A synchronous wrapper for calling Replicate models.
+
+**Methods:**
+
+| Method | Signature | Description |
+|--------|-----------|-------------|
+| `__call__()` | `(*args, **inputs) -> Output` | Execute the model and return the final output |
+| `create()` | `(*args, **inputs) -> Run` | Start a prediction and return a `Run` object |
+
+**Properties:**
+
+| Property | Type | Description |
+|----------|------|-------------|
+| `openapi_schema` | `dict` | Model's OpenAPI schema for inputs/outputs |
+| `default_example` | `dict \| None` | Default example inputs (not yet implemented) |
+
+#### `AsyncFunction` Class
+
+An asynchronous wrapper for calling Replicate models.
+ +**Methods:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `__call__()` | `async (*args, **inputs) -> Output` | Execute the model and return final output | +| `create()` | `async (*args, **inputs) -> AsyncRun` | Start a prediction and return AsyncRun object | + +**Properties:** + +| Property | Type | Description | +|----------|------|-------------| +| `openapi_schema()` | `async () -> dict` | Model's OpenAPI schema for inputs/outputs | +| `default_example` | `dict \| None` | Default example inputs (not yet implemented) | + +#### `Run` Class + +Represents a running prediction with access to output and logs. + +**Methods:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `output()` | `() -> Output` | Get prediction output (blocks until complete) | +| `logs()` | `() -> str \| None` | Get current prediction logs | + +**Behavior:** +- When `streaming=True`: Returns `OutputIterator` immediately +- When `streaming=False`: Waits for completion and returns final result + +#### `AsyncRun` Class + +Asynchronous version of Run for async model calls. + +**Methods:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `output()` | `async () -> Output` | Get prediction output (awaits completion) | +| `logs()` | `async () -> str \| None` | Get current prediction logs | + +#### `OutputIterator` Class + +Iterator wrapper for streaming model outputs. + +**Methods:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `__iter__()` | `() -> Iterator[T]` | Synchronous iteration over output chunks | +| `__aiter__()` | `() -> AsyncIterator[T]` | Asynchronous iteration over output chunks | +| `__str__()` | `() -> str` | Convert to string (concatenated or list representation) | +| `__await__()` | `() -> List[T] \| str` | Await all results (string for concatenate, list otherwise) | + +#### `URLPath` Class + +A path-like object that downloads files on first access. + +**Methods:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `__fspath__()` | `() -> str` | Get local file path (downloads if needed) | +| `__str__()` | `() -> str` | String representation of local path | + +**Usage:** +- Compatible with `open()`, `pathlib.Path()`, and most file operations +- Downloads file automatically on first filesystem access +- Cached locally in temporary directory + +#### `get_path_url()` Function + +Helper function to extract original URLs from `URLPath` objects. + +```py +def get_path_url(path: Any) -> str | None +``` + +**Parameters:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `path` | `Any` | Path object (typically `URLPath`) | + +**Returns:** +- `str` - Original URL if path is a `URLPath` +- `None` - If path is not a `URLPath` or has no URL + +### TODO + +There are several key things still outstanding: + + 1. Support for streaming text when available (rather than polling) + 2. Support for streaming files when available (rather than polling) + 3. Support for cleaning up downloaded files. + 4. Support for streaming logs using `OutputIterator`. 
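+
+### Putting it together
+
+To recap the pieces above, here is a minimal sketch that combines `create()`, `logs()`, and `output()` as described in the API reference. The model and prompt are illustrative, and `logs()` is still a work in progress, so treat its output as best-effort.
+
+```py
+import replicate
+
+claude = replicate.use("anthropic/claude-4-sonnet")
+
+# Start the prediction without blocking on the result.
+run = claude.create(prompt="say hello")
+
+# logs() reloads the prediction and returns whatever logs exist so far.
+print(run.logs())
+
+# output() blocks until the prediction completes and returns the final result.
+print(run.output())
+```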
+ ## Development See [CONTRIBUTING.md](CONTRIBUTING.md) diff --git a/pyproject.toml b/pyproject.toml index 586d919..dfea1c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dev-dependencies = [ [tool.pytest.ini_options] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" testpaths = "tests/" [tool.setuptools] @@ -73,8 +74,6 @@ ignore = [ "ANN001", # Missing type annotation for function argument "ANN002", # Missing type annotation for `*args` "ANN003", # Missing type annotation for `**kwargs` - "ANN101", # Missing type annotation for self in method - "ANN102", # Missing type annotation for cls in classmethod "ANN401", # Dynamically typed expressions (typing.Any) are disallowed in {name} "W191", # Indentation contains tabs "UP037", # Remove quotes from type annotation diff --git a/replicate/__init__.py b/replicate/__init__.py index 0e6838d..a4a1a86 100644 --- a/replicate/__init__.py +++ b/replicate/__init__.py @@ -1,6 +1,28 @@ from replicate.client import Client from replicate.pagination import async_paginate as _async_paginate from replicate.pagination import paginate as _paginate +from replicate.use import get_path_url, use + +__all__ = [ + "Client", + "use", + "run", + "async_run", + "stream", + "async_stream", + "paginate", + "async_paginate", + "collections", + "deployments", + "files", + "hardware", + "models", + "predictions", + "trainings", + "webhooks", + "default_client", + "get_path_url", +] default_client = Client() diff --git a/replicate/client.py b/replicate/client.py index 6a79813..e4e0e9e 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -352,6 +352,11 @@ def _get_api_token_from_environment() -> Optional[str]: """Get API token from cog current scope if available, otherwise from environment.""" try: import cog # noqa: I001 # pyright: ignore [reportMissingImports] + import warnings + + warnings.filterwarnings( + "ignore", message="current_scope", category=cog.ExperimentalFeatureWarning + ) for key, value in cog.current_scope().context.items(): if key.upper() == "REPLICATE_API_TOKEN": diff --git a/replicate/prediction.py b/replicate/prediction.py index b4ff047..5cee42f 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -248,6 +248,11 @@ def output_iterator(self) -> Iterator[Any]: """ Return an iterator of the prediction output. """ + if ( + self.status in ["succeeded", "failed", "canceled"] + and self.output is not None + ): + yield from self.output # TODO: check output is list previous_output = self.output or [] @@ -270,6 +275,12 @@ async def async_output_iterator(self) -> AsyncIterator[Any]: """ Return an asynchronous iterator of the prediction output. 
""" + if ( + self.status in ["succeeded", "failed", "canceled"] + and self.output is not None + ): + for item in self.output: + yield item # TODO: check output is list previous_output = self.output or [] diff --git a/replicate/schema.py b/replicate/schema.py index 06f9f05..82f06c3 100644 --- a/replicate/schema.py +++ b/replicate/schema.py @@ -15,12 +15,12 @@ def version_has_no_array_type(cog_version: str) -> Optional[bool]: def make_schema_backwards_compatible( schema: dict, - cog_version: str, + cog_version: str | None, ) -> dict: """A place to add backwards compatibility logic for our openapi schema""" # If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type - if version_has_no_array_type(cog_version): + if cog_version and version_has_no_array_type(cog_version): output = schema["components"]["schemas"]["Output"] if output.get("type") == "array": output["x-cog-array-type"] = "iterator" diff --git a/replicate/use.py b/replicate/use.py new file mode 100644 index 0000000..50c9ca6 --- /dev/null +++ b/replicate/use.py @@ -0,0 +1,753 @@ +# TODO +# - [ ] Support text streaming +# - [ ] Support file streaming +import hashlib +import os +import tempfile +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path +from typing import ( + Any, + AsyncIterator, + Callable, + Generator, + Generic, + Iterator, + List, + Literal, + Optional, + ParamSpec, + Protocol, + Tuple, + TypeVar, + cast, + overload, +) +from urllib.parse import urlparse + +import httpx + +from replicate.client import Client +from replicate.exceptions import ModelError, ReplicateError +from replicate.identifier import ModelVersionIdentifier +from replicate.model import Model +from replicate.prediction import Prediction +from replicate.run import make_schema_backwards_compatible +from replicate.version import Version + +__all__ = ["use", "get_path_url"] + + +def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool: + """ + Returns true if the model output type is ConcatenateIterator or + AsyncConcatenateIterator. + """ + output = openapi_schema.get("components", {}).get("schemas", {}).get("Output", {}) + + if output.get("type") != "array": + return False + + if output.get("items", {}).get("type") != "string": + return False + + if output.get("x-cog-array-type") != "iterator": + return False + + if output.get("x-cog-array-display") != "concatenate": + return False + + return True + + +def _has_iterator_output_type(openapi_schema: dict) -> bool: + """ + Returns true if the model output type is an iterator (non-concatenate). + """ + output = openapi_schema.get("components", {}).get("schemas", {}).get("Output", {}) + return ( + output.get("type") == "array" and output.get("x-cog-array-type") == "iterator" + ) + + +def _download_file(url: str) -> Path: + """ + Download a file from URL to a temporary location and return the Path. + """ + parsed_url = urlparse(url) + filename = os.path.basename(parsed_url.path) + + if not filename or "." not in filename: + filename = "download" + + _, ext = os.path.splitext(filename) + with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file: + with httpx.stream("GET", url) as response: + response.raise_for_status() + for chunk in response.iter_bytes(): + temp_file.write(chunk) + + return Path(temp_file.name) + + +def _process_iterator_item(item: Any, openapi_schema: dict) -> Any: + """ + Process a single item from an iterator output based on schema. 
+ """ + output_schema = ( + openapi_schema.get("components", {}).get("schemas", {}).get("Output", {}) + ) + + # For array/iterator types, check the items schema + if ( + output_schema.get("type") == "array" + and output_schema.get("x-cog-array-type") == "iterator" + ): + items_schema = output_schema.get("items", {}) + # If items are file URLs, download them + if items_schema.get("type") == "string" and items_schema.get("format") == "uri": + if isinstance(item, str) and item.startswith(("http://", "https://")): + return URLPath(item) + + return item + + +def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: # pylint: disable=too-many-branches,too-many-nested-blocks + """ + Process output data, downloading files based on OpenAPI schema. + """ + output_schema = ( + openapi_schema.get("components", {}).get("schemas", {}).get("Output", {}) + ) + + # Handle direct string with format=uri + if output_schema.get("type") == "string" and output_schema.get("format") == "uri": + if isinstance(output, str) and output.startswith(("http://", "https://")): + return URLPath(output) + return output + + # Handle array of strings with format=uri + if output_schema.get("type") == "array": + items = output_schema.get("items", {}) + if items.get("type") == "string" and items.get("format") == "uri": + if isinstance(output, list): + return [ + URLPath(url) + if isinstance(url, str) and url.startswith(("http://", "https://")) + else url + for url in output + ] + return output + + # Handle object with properties + if output_schema.get("type") == "object" and isinstance(output, dict): # pylint: disable=too-many-nested-blocks + properties = output_schema.get("properties", {}) + result = output.copy() + + for prop_name, prop_schema in properties.items(): + if prop_name in result: + value = result[prop_name] + + # Direct file property + if ( + prop_schema.get("type") == "string" + and prop_schema.get("format") == "uri" + ): + if isinstance(value, str) and value.startswith( + ("http://", "https://") + ): + result[prop_name] = URLPath(value) + + # Array of files property + elif prop_schema.get("type") == "array": + items = prop_schema.get("items", {}) + if items.get("type") == "string" and items.get("format") == "uri": + if isinstance(value, list): + result[prop_name] = [ + URLPath(url) + if isinstance(url, str) + and url.startswith(("http://", "https://")) + else url + for url in value + ] + + return result + + return output + + +T = TypeVar("T") + + +class OutputIterator[T]: + """ + An iterator wrapper that handles both regular iteration and string conversion. + Supports both sync and async iteration patterns. 
+ """ + + def __init__( + self, + iterator_factory: Callable[[], Iterator[T]], + async_iterator_factory: Callable[[], AsyncIterator[T]], + schema: dict, + *, + is_concatenate: bool, + ) -> None: + self.iterator_factory = iterator_factory + self.async_iterator_factory = async_iterator_factory + self.schema = schema + self.is_concatenate = is_concatenate + + def __iter__(self) -> Iterator[T]: + """Iterate over output items synchronously.""" + for chunk in self.iterator_factory(): + if self.is_concatenate: + yield chunk + else: + yield _process_iterator_item(chunk, self.schema) + + async def __aiter__(self) -> AsyncIterator[T]: + """Iterate over output items asynchronously.""" + async for chunk in self.async_iterator_factory(): + if self.is_concatenate: + yield chunk + else: + yield _process_iterator_item(chunk, self.schema) + + def __str__(self) -> str: + """Convert to string by joining segments with empty string.""" + if self.is_concatenate: + return "".join([str(segment) for segment in self.iterator_factory()]) + return str(list(self.iterator_factory())) + + def __await__(self) -> Generator[Any, None, List[T] | str]: + """Make OutputIterator awaitable, returning appropriate result based on concatenate mode.""" + + async def _collect_result() -> List[T] | str: + if self.is_concatenate: + # For concatenate iterators, return the joined string + segments = [] + async for segment in self: + segments.append(segment) + return "".join(segments) + # For regular iterators, return the list of items + items = [] + async for item in self: + items.append(item) + return items + + return _collect_result().__await__() # pylint: disable=no-member # return type confuses pylint + + +class URLPath(os.PathLike): + """ + A PathLike that defers filesystem ops until first use. Can be used with + most Python file interfaces like `open()` and `pathlib.Path()`. + See: https://docs.python.org/3.12/library/os.html#os.PathLike + """ + + def __init__(self, url: str) -> None: + # store the original URL + self.__url__ = url + + # compute target path without touching the filesystem + base = Path(tempfile.gettempdir()) + h = hashlib.sha256(self.__url__.encode("utf-8")).hexdigest()[:16] + name = Path(httpx.URL(self.__url__).path).name or h + self.__path__ = base / h / name + + def __fspath__(self) -> str: + # on first access, create dirs and download if missing + if not self.__path__.exists(): + subdir = self.__path__.parent + subdir.mkdir(parents=True, exist_ok=True) + if not os.access(subdir, os.W_OK): + raise PermissionError(f"Cannot write to {subdir!r}") + + with httpx.Client() as client, client.stream("GET", self.__url__) as resp: + resp.raise_for_status() + with open(self.__path__, "wb") as f: + for chunk in resp.iter_bytes(chunk_size=16_384): + f.write(chunk) + + return str(self.__path__) + + def __str__(self) -> str: + return self.__fspath__() + + def __repr__(self) -> str: + return f"" + + +def get_path_url(path: Any) -> str | None: + """ + Return the remote URL (if any) for a Path output from a model. + """ + try: + return object.__getattribute__(path, "__url__") + except AttributeError: + return None + + +Input = ParamSpec("Input") +Output = TypeVar("Output") + + +class FunctionRef(Protocol, Generic[Input, Output]): + """Represents a Replicate model, providing the model identifier and interface.""" + + name: str + + __call__: Callable[Input, Output] + + +@dataclass +class Run[O]: + """ + Represents a running prediction with access to the underlying schema. 
+ """ + + _prediction: Prediction + _schema: dict + + def __init__( + self, *, prediction: Prediction, schema: dict, streaming: bool + ) -> None: + self._prediction = prediction + self._schema = schema + self._streaming = streaming + + def output(self) -> O: + """ + Return the output. For iterator types, returns immediately without waiting. + For non-iterator types, waits for completion. + """ + # Return an OutputIterator immediately when streaming, we do this for all + # model return types regardless of whether they return an iterator. + if self._streaming: + is_concatenate = _has_concatenate_iterator_output_type(self._schema) + return cast( + O, + OutputIterator( + self._prediction.output_iterator, + self._prediction.async_output_iterator, + self._schema, + is_concatenate=is_concatenate, + ), + ) + + # For non-streaming, wait for completion and process output + self._prediction.wait() + + if self._prediction.status == "failed": + raise ModelError(self._prediction) + + # Handle concatenate iterators - return joined string + if _has_concatenate_iterator_output_type(self._schema): + if isinstance(self._prediction.output, list): + return cast(O, "".join(str(item) for item in self._prediction.output)) + return self._prediction.output + + # Process output for file downloads based on schema + return _process_output_with_schema(self._prediction.output, self._schema) + + def logs(self) -> Optional[str]: + """ + Fetch and return the logs from the prediction. + """ + self._prediction.reload() + + return self._prediction.logs + + +@dataclass +class Function(Generic[Input, Output]): + """ + A wrapper for a Replicate model that can be called as a function. + """ + + _ref: str + + def __init__(self, ref: str, *, streaming: bool) -> None: + self._ref = ref + self._streaming = streaming + + def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output: + return self.create(*args, **inputs).output() + + def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]: + """ + Start a prediction with the specified inputs. + """ + # Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs + processed_inputs = {} + for key, value in inputs.items(): + if isinstance(value, OutputIterator): + if value.is_concatenate: + processed_inputs[key] = str(value) + else: + processed_inputs[key] = list(value) + elif url := get_path_url(value): + processed_inputs[key] = url + else: + processed_inputs[key] = value + + version = self._version + + if version: + prediction = self._client().predictions.create( + version=version, input=processed_inputs + ) + else: + prediction = self._client().models.predictions.create( + model=self._model, input=processed_inputs + ) + + return Run( + prediction=prediction, schema=self.openapi_schema, streaming=self._streaming + ) + + @property + def default_example(self) -> Optional[dict[str, Any]]: + """ + Get the default example for this model. + """ + raise NotImplementedError("This property has not yet been implemented") + + @cached_property + def openapi_schema(self) -> dict[str, Any]: + """ + Get the OpenAPI schema for this model version. 
+ """ + latest_version = self._model.latest_version + if latest_version is None: + msg = f"Model {self._model.owner}/{self._model.name} has no latest version" + raise ValueError(msg) + + schema = latest_version.openapi_schema + if cog_version := latest_version.cog_version: + schema = make_schema_backwards_compatible(schema, cog_version) + return schema + + def _client(self) -> Client: + return Client() + + @cached_property + def _parsed_ref(self) -> Tuple[str, str, Optional[str]]: + return ModelVersionIdentifier.parse(self._ref) + + @cached_property + def _model(self) -> Model: + client = self._client() + model_owner, model_name, _ = self._parsed_ref + return client.models.get(f"{model_owner}/{model_name}") + + @cached_property + def _version(self) -> Version | None: + _, _, model_version = self._parsed_ref + model = self._model + try: + versions = model.versions.list() + if len(versions) == 0: + # if we got an empty list when getting model versions, this + # model is possibly a procedure instead and should be called via + # the versionless API + return None + except ReplicateError as e: + if e.status == 404: + # if we get a 404 when getting model versions, this is an official + # model and doesn't have addressable versions (despite what + # latest_version might tell us) + return None + raise + + version = ( + model.versions.get(model_version) if model_version else model.latest_version + ) + + return version + + +@dataclass +class AsyncRun[O]: + """ + Represents a running prediction with access to its version (async version). + """ + + _prediction: Prediction + _schema: dict + + def __init__( + self, *, prediction: Prediction, schema: dict, streaming: bool + ) -> None: + self._prediction = prediction + self._schema = schema + self._streaming = streaming + + async def output(self) -> O: + """ + Return the output. For iterator types, returns immediately without waiting. + For non-iterator types, waits for completion. + """ + # Return an OutputIterator immediately when streaming, we do this for all + # model return types regardless of whether they return an iterator. + if self._streaming: + is_concatenate = _has_concatenate_iterator_output_type(self._schema) + return cast( + O, + OutputIterator( + self._prediction.output_iterator, + self._prediction.async_output_iterator, + self._schema, + is_concatenate=is_concatenate, + ), + ) + + # For non-streaming, wait for completion and process output + await self._prediction.async_wait() + + if self._prediction.status == "failed": + raise ModelError(self._prediction) + + # Handle concatenate iterators - return joined string + if _has_concatenate_iterator_output_type(self._schema): + if isinstance(self._prediction.output, list): + return cast(O, "".join(str(item) for item in self._prediction.output)) + return self._prediction.output + + # Process output for file downloads based on schema + return _process_output_with_schema(self._prediction.output, self._schema) + + async def logs(self) -> Optional[str]: + """ + Fetch and return the logs from the prediction asynchronously. + """ + await self._prediction.async_reload() + + return self._prediction.logs + + +@dataclass +class AsyncFunction(Generic[Input, Output]): + """ + An async wrapper for a Replicate model that can be called as a function. 
+ """ + + function_ref: str + streaming: bool + + def _client(self) -> Client: + return Client() + + @cached_property + def _parsed_ref(self) -> Tuple[str, str, Optional[str]]: + return ModelVersionIdentifier.parse(self.function_ref) + + async def _model(self) -> Model: + client = self._client() + model_owner, model_name, _ = self._parsed_ref + return await client.models.async_get(f"{model_owner}/{model_name}") + + async def _version(self) -> Version | None: + _, _, model_version = self._parsed_ref + model = await self._model() + try: + versions = await model.versions.async_list() + if len(versions) == 0: + # if we got an empty list when getting model versions, this + # model is possibly a procedure instead and should be called via + # the versionless API + return None + except ReplicateError as e: + if e.status == 404: + # if we get a 404 when getting model versions, this is an official + # model and doesn't have addressable versions (despite what + # latest_version might tell us) + return None + raise + + if model_version: + version = await model.versions.async_get(model_version) + else: + version = model.latest_version + + return version + + async def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output: + run = await self.create(*args, **inputs) + return await run.output() + + async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Output]: + """ + Start a prediction with the specified inputs asynchronously. + """ + # Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs + processed_inputs = {} + for key, value in inputs.items(): + if isinstance(value, OutputIterator): + processed_inputs[key] = await value + elif url := get_path_url(value): + processed_inputs[key] = url + else: + processed_inputs[key] = value + + version = await self._version() + + if version: + prediction = await self._client().predictions.async_create( + version=version, input=processed_inputs + ) + else: + model = await self._model() + prediction = await self._client().models.predictions.async_create( + model=model, input=processed_inputs + ) + + return AsyncRun( + prediction=prediction, + schema=await self.openapi_schema(), + streaming=self.streaming, + ) + + @property + def default_example(self) -> Optional[dict[str, Any]]: + """ + Get the default example for this model. + """ + raise NotImplementedError("This property has not yet been implemented") + + async def openapi_schema(self) -> dict[str, Any]: + """ + Get the OpenAPI schema for this model version asynchronously. + """ + model = await self._model() + latest_version = model.latest_version + if latest_version is None: + msg = f"Model {model.owner}/{model.name} has no latest version" + raise ValueError(msg) + + schema = latest_version.openapi_schema + if cog_version := latest_version.cog_version: + schema = make_schema_backwards_compatible(schema, cog_version) + return schema + + +@overload +def use(ref: FunctionRef[Input, Output]) -> Function[Input, Output]: ... + + +@overload +def use( + ref: FunctionRef[Input, Output], *, streaming: Literal[False] +) -> Function[Input, Output]: ... + + +@overload +def use( + ref: FunctionRef[Input, Output], *, use_async: Literal[False] +) -> Function[Input, Output]: ... + + +@overload +def use( + ref: FunctionRef[Input, Output], *, use_async: Literal[True] +) -> AsyncFunction[Input, Output]: ... + + +@overload +def use( + ref: FunctionRef[Input, Output], + *, + streaming: Literal[False], + use_async: Literal[True], +) -> AsyncFunction[Input, Output]: ... 
+
+
+@overload
+def use(
+    ref: FunctionRef[Input, Output],
+    *,
+    streaming: Literal[True],
+    use_async: Literal[True],
+) -> AsyncFunction[Input, AsyncIterator[Output]]: ...
+
+
+@overload
+def use(
+    ref: FunctionRef[Input, Output],
+    *,
+    streaming: Literal[True],
+    use_async: Literal[False],
+) -> Function[Input, Iterator[Output]]: ...
+
+
+@overload
+def use(
+    ref: str,
+    *,
+    hint: Callable[Input, Output] | None = None,  # pylint: disable=unused-argument
+    streaming: Literal[False] = False,
+    use_async: Literal[False] = False,
+) -> Function[Input, Output]: ...
+
+
+@overload
+def use(
+    ref: str,
+    *,
+    hint: Callable[Input, Output] | None = None,  # pylint: disable=unused-argument
+    streaming: Literal[True],
+    use_async: Literal[False] = False,
+) -> Function[Input, Iterator[Output]]: ...
+
+
+@overload
+def use(
+    ref: str,
+    *,
+    hint: Callable[Input, Output] | None = None,  # pylint: disable=unused-argument
+    use_async: Literal[True],
+) -> AsyncFunction[Input, Output]: ...
+
+
+@overload
+def use(
+    ref: str,
+    *,
+    hint: Callable[Input, Output] | None = None,  # pylint: disable=unused-argument
+    streaming: Literal[True],
+    use_async: Literal[True],
+) -> AsyncFunction[Input, AsyncIterator[Output]]: ...
+
+
+def use(
+    ref: str | FunctionRef[Input, Output],
+    *,
+    hint: Callable[Input, Output] | None = None,  # pylint: disable=unused-argument # required for type inference
+    streaming: bool = False,
+    use_async: bool = False,
+) -> (
+    Function[Input, Output]
+    | AsyncFunction[Input, Output]
+    | Function[Input, Iterator[Output]]
+    | AsyncFunction[Input, AsyncIterator[Output]]
+):
+    """
+    Use a Replicate model as a function.
+
+    Example:
+
+        flux_dev = replicate.use("black-forest-labs/flux-dev")
+        output = flux_dev(prompt="make me a sandwich")
+
+    """
+    try:
+        ref = ref.name  # type: ignore
+    except AttributeError:
+        pass
+
+    if use_async:
+        return AsyncFunction(str(ref), streaming=streaming)
+
+    return Function(str(ref), streaming=streaming)
diff --git a/tests/test_client.py b/tests/test_client.py
index 6ba6aea..0ea505d 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -119,6 +119,9 @@ def mock_send(request):
     mock_send_wrapper.assert_called_once()
 
+
+class ExperimentalFeatureWarning(Warning): ...
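+# (Stand-in for cog.ExperimentalFeatureWarning: the real cog package is mocked
+# out in these tests, and client.py ignores this warning category when reading
+# the API token from cog's current_scope().)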
+ + class TestGetApiToken: """Test cases for _get_api_token_from_environment function covering all import paths.""" @@ -142,6 +145,7 @@ def test_cog_import_error_falls_back_to_env(self): def test_cog_no_current_scope_method_falls_back_to_env(self): """Test fallback when cog exists but has no current_scope method.""" mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning del mock_cog.current_scope # Remove the method with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -152,6 +156,7 @@ def test_cog_no_current_scope_method_falls_back_to_env(self): def test_cog_current_scope_returns_none_falls_back_to_env(self): """Test fallback when current_scope() returns None.""" mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = None with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -165,6 +170,7 @@ def test_cog_scope_no_context_attr_falls_back_to_env(self): del mock_scope.context # Remove the context attribute mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -178,6 +184,7 @@ def test_cog_scope_context_not_dict_falls_back_to_env(self): mock_scope.context = "not a dict" mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -191,6 +198,7 @@ def test_cog_scope_no_replicate_api_token_key_falls_back_to_env(self): mock_scope.context = {"other_key": "other_value"} # Missing replicate_api_token mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -204,6 +212,7 @@ def test_cog_scope_replicate_api_token_valid_string(self): mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"} mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -217,6 +226,7 @@ def test_cog_scope_replicate_api_token_case_insensitive(self): mock_scope.context = {"replicate_api_token": "cog-token"} mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -230,6 +240,7 @@ def test_cog_scope_replicate_api_token_empty_string(self): mock_scope.context = {"replicate_api_token": ""} # Empty string mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -243,6 +254,7 @@ def test_cog_scope_replicate_api_token_none(self): mock_scope.context = {"replicate_api_token": None} mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -253,6 +265,7 @@ def test_cog_scope_replicate_api_token_none(self): def 
test_cog_current_scope_raises_exception_falls_back_to_env(self): """Test fallback when current_scope() raises an exception.""" mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.side_effect = RuntimeError("Scope error") with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): diff --git a/tests/test_use.py b/tests/test_use.py new file mode 100644 index 0000000..70270f7 --- /dev/null +++ b/tests/test_use.py @@ -0,0 +1,1445 @@ +import json +import os +from enum import Enum +from pathlib import Path +from typing import Literal, Union + +import httpx +import pytest +import respx + +import replicate +from replicate.use import get_path_url + + +class ClientMode(str, Enum): + DEFAULT = "default" + ASYNC = "async" + + +# Allow use() to be called in test context +os.environ["REPLICATE_ALWAYS_ALLOW_USE"] = "1" +os.environ["REPLICATE_POLL_INTERVAL"] = "0" + + +def _deep_merge(base, override): + if override is None: + return base + + result = base.copy() + for key, value in override.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = _deep_merge(result[key], value) + else: + result[key] = value + return result + + +def create_mock_version(version_overrides=None, version_id="xyz123"): + default_version = { + "id": version_id, + "created_at": "2024-01-01T00:00:00Z", + "cog_version": "0.8.0", + "openapi_schema": { + "openapi": "3.0.2", + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": { + "/": { + "post": { + "summary": "Make a prediction", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PredictionRequest" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PredictionResponse" + } + } + } + } + }, + } + } + }, + "components": { + "schemas": { + "Input": { + "type": "object", + "properties": {"prompt": {"type": "string", "title": "Prompt"}}, + "required": ["prompt"], + }, + "Output": {"type": "string", "title": "Output"}, + } + }, + }, + } + + return _deep_merge(default_version, version_overrides) + + +def create_mock_prediction( + prediction_overrides=None, prediction_id="pred123", uses_versionless_api=None +): + default_prediction = { + "id": prediction_id, + "model": "acme/hotdog-detector", + "version": "hidden" + if uses_versionless_api in ("notfound", "empty") + else "xyz123", + "urls": { + "get": f"https://api.replicate.com/v1/predictions/{prediction_id}", + "cancel": f"https://api.replicate.com/v1/predictions/{prediction_id}/cancel", + }, + "created_at": "2024-01-01T00:00:00Z", + "source": "api", + "status": "processing", + "input": {"prompt": "hello world"}, + "output": None, + "error": None, + "logs": "Starting prediction...", + } + + return _deep_merge(default_prediction, prediction_overrides) + + +def mock_model_endpoints( + versions=None, + *, + # This is a workaround while we have a bug in the api + uses_versionless_api: Union[Literal["notfound"], Literal["empty"], None] = None, +): + if versions is None: + versions = [create_mock_version()] + + # Get the latest version (first in list) for the model endpoint + latest_version = versions[0] if versions else None + respx.get("https://api.replicate.com/v1/models/acme/hotdog-detector").mock( + return_value=httpx.Response( + 200, + json={ + "url": "https://replicate.com/acme/hotdog-detector", + "owner": "acme", + "name": "hotdog-detector", + "description": "A model to 
detect hotdogs", + "visibility": "public", + "github_url": "https://github.com/acme/hotdog-detector", + "paper_url": None, + "license_url": None, + "run_count": 42, + "cover_image_url": None, + "default_example": None, + "latest_version": latest_version, + }, + ) + ) + + versions_results = versions + if uses_versionless_api == "empty": + versions_results = [] + + if uses_versionless_api == "notfound": + respx.get( + "https://api.replicate.com/v1/models/acme/hotdog-detector/versions" + ).mock(return_value=httpx.Response(404, json={"detail": "Not found"})) + else: + respx.get( + "https://api.replicate.com/v1/models/acme/hotdog-detector/versions" + ).mock(return_value=httpx.Response(200, json={"results": versions_results})) + + for version_obj in versions_results: + if uses_versionless_api == "notfound": + respx.get( + f"https://api.replicate.com/v1/models/acme/hotdog-detector/versions/{version_obj['id']}" + ).mock(return_value=httpx.Response(404, json={})) + else: + respx.get( + f"https://api.replicate.com/v1/models/acme/hotdog-detector/versions/{version_obj['id']}" + ).mock(return_value=httpx.Response(200, json=version_obj)) + + +def mock_prediction_endpoints( + predictions=None, + *, + uses_versionless_api=None, +): + if predictions is None: + # Create default two-step prediction flow (processing -> succeeded) + predictions = [ + create_mock_prediction( + { + "status": "processing", + "output": None, + "logs": "", + }, + uses_versionless_api=uses_versionless_api, + ), + create_mock_prediction( + { + "status": "succeeded", + "output": "not hotdog", + "logs": "Starting prediction...\nPrediction completed.", + }, + uses_versionless_api=uses_versionless_api, + ), + ] + + initial_prediction = predictions[0] + if uses_versionless_api in ("notfound", "empty"): + respx.post( + "https://api.replicate.com/v1/models/acme/hotdog-detector/predictions" + ).mock(return_value=httpx.Response(201, json=initial_prediction)) + else: + respx.post("https://api.replicate.com/v1/predictions").mock( + return_value=httpx.Response(201, json=initial_prediction) + ) + + prediction_id = initial_prediction["id"] + respx.get(f"https://api.replicate.com/v1/predictions/{prediction_id}").mock( + side_effect=[httpx.Response(200, json=response) for response in predictions] + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_use(client_mode): + mock_model_endpoints() + mock_prediction_endpoints() + + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") + + assert output == "not hotdog" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_use_with_version_identifier(client_mode): + mock_model_endpoints() + mock_prediction_endpoints() + + hotdog_detector = replicate.use( + "acme/hotdog-detector:xyz123", use_async=client_mode == ClientMode.ASYNC + ) + + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") + + assert output == "not hotdog" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_use_with_function_ref(client_mode): + mock_model_endpoints() + mock_prediction_endpoints() + + 
class HotdogDetector: + name = "acme/hotdog-detector:xyz123" + + def __call__(self, prompt: str) -> str: ... + + hotdog_detector = replicate.use( + HotdogDetector(), use_async=client_mode == ClientMode.ASYNC + ) + + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") + + assert output == "not hotdog" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_use_versionless_empty_versions_list(client_mode): + mock_model_endpoints(uses_versionless_api="empty") + mock_prediction_endpoints(uses_versionless_api="empty") + + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") + + assert output == "not hotdog" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_use_versionless_404_versions_list(client_mode): + mock_model_endpoints(uses_versionless_api="notfound") + mock_prediction_endpoints(uses_versionless_api="notfound") + + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") + + assert output == "not hotdog" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_use_function_create_method(client_mode): + mock_model_endpoints() + mock_prediction_endpoints() + + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + if client_mode == ClientMode.ASYNC: + run = await hotdog_detector.create(prompt="hello world") + else: + run = hotdog_detector.create(prompt="hello world") + + from replicate.use import AsyncRun, Run + + if client_mode == ClientMode.ASYNC: + assert isinstance(run, AsyncRun) + else: + assert isinstance(run, Run) + assert run._prediction.id == "pred123" + assert run._prediction.status == "processing" + assert run._prediction.input == {"prompt": "hello world"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_use_concatenate_iterator_output(client_mode): + mock_model_endpoints( + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "x-cog-array-display": "concatenate", + } + } + } + } + } + ) + ] + ) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction(), + create_mock_prediction( + {"status": "succeeded", "output": ["Hello", " ", "world", "!"]} + ), + ] + ) + + hotdog_detector = replicate.use( + "acme/hotdog-detector", + use_async=client_mode == ClientMode.ASYNC, + streaming=True, + ) + + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") + + from replicate.use import OutputIterator + + assert isinstance(output, OutputIterator) + assert str(output) == "Hello world!" 
+ + # Also test that it's iterable + output_list = list(output) + assert output_list == ["Hello", " ", "world", "!"] + + # Test that concatenate OutputIterators are stringified when passed to create() + # Set up a mock for the prediction creation to capture the request + request_body = None + + def capture_request(request): + nonlocal request_body + request_body = request.read() + return httpx.Response( + 201, + json={ + "id": "pred456", + "model": "acme/hotdog-detector", + "version": "xyz123", + "urls": { + "get": "https://api.replicate.com/v1/predictions/pred456", + "cancel": "https://api.replicate.com/v1/predictions/pred456/cancel", + }, + "created_at": "2024-01-01T00:00:00Z", + "source": "api", + "status": "processing", + "input": {"text_input": "Hello world!"}, + "output": None, + "error": None, + "logs": "", + }, + ) + + respx.post("https://api.replicate.com/v1/predictions").mock( + side_effect=capture_request + ) + + # Pass the OutputIterator as input to create() + if client_mode == ClientMode.ASYNC: + await hotdog_detector.create(text_input=output) + else: + hotdog_detector.create(text_input=output) + + # Verify the request body contains the stringified version + assert request_body + parsed_body = json.loads(request_body) + assert parsed_body["input"]["text_input"] == "Hello world!" + + +@pytest.mark.asyncio +async def test_output_iterator_async_iteration(): + """Test OutputIterator async iteration capabilities.""" + from replicate.use import OutputIterator + + # Create mock sync and async iterators + def sync_iterator(): + return iter(["Hello", " ", "world", "!"]) + + async def async_iterator(): + for item in ["Hello", " ", "world", "!"]: + yield item + + # Test concatenate iterator + concatenate_output = OutputIterator( + sync_iterator, async_iterator, {}, is_concatenate=True + ) + + # Test sync iteration + sync_result = list(concatenate_output) + assert sync_result == ["Hello", " ", "world", "!"] + + # Test async iteration + async_result = [] + async for item in concatenate_output: + async_result.append(item) + assert async_result == ["Hello", " ", "world", "!"] + + # Test sync string conversion + assert str(concatenate_output) == "Hello world!" + + # Test async await (should return joined string for concatenate) + async_result = await concatenate_output + assert async_result == "Hello world!" 
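+
+
+def test_get_path_url_round_trip(tmp_path):
+    """Illustrative sketch (not part of the original suite): get_path_url()
+    should return the original URL for a URLPath and None for anything else."""
+    from replicate.use import URLPath
+
+    url = "https://replicate.delivery/xyz/output.webp"
+    path = URLPath(url)
+
+    # Constructing a URLPath computes a target path but does not download;
+    # the URL remains recoverable via get_path_url().
+    assert get_path_url(path) == url
+
+    # Plain paths carry no URL, so the helper returns None.
+    assert get_path_url(tmp_path) is None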
+ + +@pytest.mark.asyncio +async def test_output_iterator_async_non_concatenate(): + """Test OutputIterator async iteration for non-concatenate iterators.""" + from replicate.use import OutputIterator + + # Create mock sync and async iterators for non-concatenate case + test_items = ["item1", "item2", "item3"] + + def sync_iterator(): + return iter(test_items) + + async def async_iterator(): + for item in test_items: + yield item + + # Test non-concatenate iterator + regular_output = OutputIterator( + sync_iterator, async_iterator, {}, is_concatenate=False + ) + + # Test sync iteration + sync_result = list(regular_output) + assert sync_result == test_items + + # Test async iteration + async_result = [] + async for item in regular_output: + async_result.append(item) + assert async_result == test_items + + # Test sync string conversion + assert str(regular_output) == str(test_items) + + # Test async await (should return list for non-concatenate) + async_result = await regular_output + assert async_result == test_items + + +@pytest.mark.asyncio +@respx.mock +async def test_async_function_concatenate_iterator_output(): + """Test AsyncFunction with concatenate iterator output.""" + mock_model_endpoints( + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "x-cog-array-display": "concatenate", + } + } + } + } + } + ) + ] + ) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction(), + create_mock_prediction( + {"status": "succeeded", "output": ["Async", " ", "Hello", " ", "World"]} + ), + ] + ) + + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=True, streaming=True + ) + + run = await hotdog_detector.create(prompt="hello world") + output = await run.output() + + from replicate.use import OutputIterator + + assert isinstance(output, OutputIterator) + assert str(output) == "Async Hello World" + + # Test async await (should return joined string for concatenate) + async_result = await output + assert async_result == "Async Hello World" + + # Test async iteration + async_result = [] + async for item in output: + async_result.append(item) + assert async_result == ["Async", " ", "Hello", " ", "World"] + + # Also test that it's still sync iterable + sync_result = list(output) + assert sync_result == ["Async", " ", "Hello", " ", "World"] + + +@pytest.mark.asyncio +async def test_output_iterator_await_syntax_demo(): + """Demonstrate the clean await syntax for OutputIterator.""" + from replicate.use import OutputIterator + + # Create mock iterators + def sync_iterator(): + return iter(["Hello", " ", "World"]) + + async def async_iterator(): + for item in ["Hello", " ", "World"]: + yield item + + # Test concatenate mode - await returns string + concatenate_output = OutputIterator( + sync_iterator, async_iterator, {}, is_concatenate=True + ) + + # This is the clean syntax we wanted: str(await iterator) + result = await concatenate_output + assert result == "Hello World" + assert str(result) == "Hello World" # Can use str() on the result + + # Test non-concatenate mode - await returns list + regular_output = OutputIterator( + sync_iterator, async_iterator, {}, is_concatenate=False + ) + + result = await regular_output + assert result == ["Hello", " ", "World"] + assert str(result) == "['Hello', ' ', 'World']" # str() gives list representation + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, 
ClientMode.ASYNC])
+@respx.mock
+async def test_use_concatenate_iterator_without_streaming_returns_string(client_mode):
+    """Test that concatenate iterator models without streaming=True return the final concatenated string."""
+    mock_model_endpoints(
+        versions=[
+            create_mock_version(
+                {
+                    "openapi_schema": {
+                        "components": {
+                            "schemas": {
+                                "Output": {
+                                    "type": "array",
+                                    "items": {"type": "string"},
+                                    "x-cog-array-type": "iterator",
+                                    "x-cog-array-display": "concatenate",
+                                }
+                            }
+                        }
+                    }
+                }
+            )
+        ]
+    )
+    mock_prediction_endpoints(
+        predictions=[
+            create_mock_prediction(),
+            create_mock_prediction(
+                {"status": "succeeded", "output": ["Hello", " ", "world", "!"]}
+            ),
+        ]
+    )
+
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
+    )
+
+    if client_mode == ClientMode.ASYNC:
+        output = await hotdog_detector(prompt="hello world")
+    else:
+        output = hotdog_detector(prompt="hello world")
+
+    assert output == "Hello world!"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_iterator_output_returns_immediately(client_mode):
+    """Test that OutputIterator is returned immediately without waiting for completion."""
+    mock_model_endpoints(
+        versions=[
+            create_mock_version(
+                {
+                    "openapi_schema": {
+                        "components": {
+                            "schemas": {
+                                "Output": {
+                                    "type": "array",
+                                    "items": {"type": "string"},
+                                    "x-cog-array-type": "iterator",
+                                    "x-cog-array-display": "concatenate",
+                                }
+                            }
+                        }
+                    }
+                }
+            )
+        ]
+    )
+
+    mock_prediction_endpoints(
+        predictions=[
+            create_mock_prediction({"status": "processing", "output": []}),
+            create_mock_prediction({"status": "processing", "output": ["Hello"]}),
+            create_mock_prediction(
+                {"status": "succeeded", "output": ["Hello", " ", "World"]}
+            ),
+        ]
+    )
+
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector",
+        use_async=client_mode == ClientMode.ASYNC,
+        streaming=True,
+    )
+
+    # Get the output iterator - this should return immediately even though the prediction is still processing
+    if client_mode == ClientMode.ASYNC:
+        run = await hotdog_detector.create(prompt="hello world")
+        output_iterator = await run.output()
+    else:
+        run = hotdog_detector.create(prompt="hello world")
+        output_iterator = run.output()
+
+    from replicate.use import OutputIterator
+
+    assert isinstance(output_iterator, OutputIterator)
+
+    # Verify the prediction is still processing when we get the iterator
+    assert run._prediction.status == "processing"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_streaming_output_yields_incrementally(client_mode):
+    """Test that OutputIterator yields results incrementally during polling."""
+    mock_model_endpoints(
+        versions=[
+            create_mock_version(
+                {
+                    "openapi_schema": {
+                        "components": {
+                            "schemas": {
+                                "Output": {
+                                    "type": "array",
+                                    "items": {"type": "string"},
+                                    "x-cog-array-type": "iterator",
+                                    "x-cog-array-display": "concatenate",
+                                }
+                            }
+                        }
+                    }
+                }
+            )
+        ]
+    )
+
+    # Create a prediction that will be polled multiple times
+    prediction_id = "pred123"
+
+    initial_prediction = create_mock_prediction(
+        {"id": prediction_id, "status": "processing", "output": []},
+        prediction_id=prediction_id,
+    )
+
+    respx.post("https://api.replicate.com/v1/predictions").mock(
+        return_value=httpx.Response(201, json=initial_prediction)
+    )
+
+    poll_responses = [
+        create_mock_prediction(
+            {"status": "processing", "output": ["Hello"]}, prediction_id=prediction_id
+        ),
+        create_mock_prediction(
+            {"status": "processing", "output": ["Hello", " "]},
+            prediction_id=prediction_id,
+        ),
+        create_mock_prediction(
+            {"status": "processing", "output": ["Hello", " ", "streaming"]},
+            prediction_id=prediction_id,
+        ),
+        create_mock_prediction(
+            {"status": "processing", "output": ["Hello", " ", "streaming", " "]},
+            prediction_id=prediction_id,
+        ),
+        create_mock_prediction(
+            {
+                "status": "succeeded",
+                "output": ["Hello", " ", "streaming", " ", "world!"],
+            },
+            prediction_id=prediction_id,
+        ),
+    ]
+
+    respx.get(f"https://api.replicate.com/v1/predictions/{prediction_id}").mock(
+        side_effect=[httpx.Response(200, json=resp) for resp in poll_responses]
+    )
+
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector",
+        use_async=client_mode == ClientMode.ASYNC,
+        streaming=True,
+    )
+
+    # Get the output iterator immediately
+    if client_mode == ClientMode.ASYNC:
+        run = await hotdog_detector.create(prompt="hello world")
+        output_iterator = await run.output()
+    else:
+        run = hotdog_detector.create(prompt="hello world")
+        output_iterator = run.output()
+
+    from replicate.use import OutputIterator
+
+    assert isinstance(output_iterator, OutputIterator)
+
+    # Track when we receive each item to verify incremental delivery
+    collected_items = []
+
+    if client_mode == ClientMode.ASYNC:
+        async for item in output_iterator:
+            collected_items.append(item)
+            # Break after we get some incremental results to verify polling works
+            if len(collected_items) >= 3:
+                break
+    else:
+        for item in output_iterator:
+            collected_items.append(item)
+            # Break after we get some incremental results to verify polling works
+            if len(collected_items) >= 3:
+                break
+
+    # Verify we got incremental streaming results
+    assert len(collected_items) >= 3
+    # The items should be the concatenated string parts from the incremental output
+    result = "".join(collected_items)
+    assert "Hello" in result  # Should contain the first part we streamed
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_non_streaming_output_waits_for_completion(client_mode):
+    """Test that non-iterator outputs still wait for completion."""
+    mock_model_endpoints(
+        versions=[
+            create_mock_version(
+                {
+                    "openapi_schema": {
+                        "components": {
+                            "schemas": {
+                                "Output": {"type": "string"}  # Non-iterator output
+                            }
+                        }
+                    }
+                }
+            )
+        ]
+    )
+
+    mock_prediction_endpoints(
+        predictions=[
+            create_mock_prediction({"status": "processing", "output": None}),
+            create_mock_prediction({"status": "succeeded", "output": "Final result"}),
+        ]
+    )
+
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
+    )
+
+    # For non-iterator output, this should wait for completion
+    if client_mode == ClientMode.ASYNC:
+        run = await hotdog_detector.create(prompt="hello world")
+        output = await run.output()
+    else:
+        run = hotdog_detector.create(prompt="hello world")
+        output = run.output()
+
+    # Should get the final result directly
+    assert output == "Final result"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_use_list_of_strings_output(client_mode):
+    mock_model_endpoints(
+        versions=[
+            create_mock_version(
+                {
+                    "openapi_schema": {
+ "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string"}, + } + } + } + } + } + ) + ] + ) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction(), + create_mock_prediction( + {"status": "succeeded", "output": ["hello", "world", "test"]} + ), + ] + ) + + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") + + assert output == ["hello", "world", "test"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_use_iterator_of_strings_output(client_mode): + mock_model_endpoints( + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + } + } + } + } + } + ) + ] + ) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction(), + create_mock_prediction( + {"status": "succeeded", "output": ["hello", "world", "test"]} + ), + ] + ) + + hotdog_detector = replicate.use( + "acme/hotdog-detector", + use_async=client_mode == ClientMode.ASYNC, + streaming=True, + ) + + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") + + from replicate.use import OutputIterator + + assert isinstance(output, OutputIterator) + # Convert to list to check contents + output_list = list(output) + assert output_list == ["hello", "world", "test"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_use_path_output(client_mode): + mock_model_endpoints( + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "string", + "format": "uri", + } + } + } + } + } + ) + ] + ) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction(), + create_mock_prediction( + {"status": "succeeded", "output": "https://example.com/output.jpg"} + ), + ] + ) + + respx.get("https://example.com/output.jpg").mock( + return_value=httpx.Response(200, content=b"fake image data") + ) + + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") + + assert isinstance(output, os.PathLike) + assert get_path_url(output) == "https://example.com/output.jpg" + assert os.path.exists(output) + assert open(output, "rb").read() == b"fake image data" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_use_list_of_paths_output(client_mode): + mock_model_endpoints( + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string", "format": "uri"}, + } + } + } + } + } + ) + ] + ) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction(), + create_mock_prediction( + { + "status": "succeeded", + "output": [ + "https://example.com/output1.jpg", + "https://example.com/output2.jpg", + ], + } + ), + ] + ) + + respx.get("https://example.com/output1.jpg").mock( + 
return_value=httpx.Response(200, content=b"fake image 1 data")
+    )
+    respx.get("https://example.com/output2.jpg").mock(
+        return_value=httpx.Response(200, content=b"fake image 2 data")
+    )
+
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
+    )
+
+    if client_mode == ClientMode.ASYNC:
+        output = await hotdog_detector(prompt="hello world")
+    else:
+        output = hotdog_detector(prompt="hello world")
+
+    assert isinstance(output, list)
+    assert len(output) == 2
+
+    assert all(isinstance(path, os.PathLike) for path in output)
+    assert get_path_url(output[0]) == "https://example.com/output1.jpg"
+    assert get_path_url(output[1]) == "https://example.com/output2.jpg"
+
+    assert all(os.path.exists(path) for path in output)
+    assert open(output[0], "rb").read() == b"fake image 1 data"
+    assert open(output[1], "rb").read() == b"fake image 2 data"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_use_iterator_of_paths_output(client_mode):
+    mock_model_endpoints(
+        versions=[
+            create_mock_version(
+                {
+                    "openapi_schema": {
+                        "components": {
+                            "schemas": {
+                                "Output": {
+                                    "type": "array",
+                                    "items": {"type": "string", "format": "uri"},
+                                    "x-cog-array-type": "iterator",
+                                }
+                            }
+                        }
+                    }
+                }
+            )
+        ]
+    )
+    mock_prediction_endpoints(
+        predictions=[
+            create_mock_prediction(),
+            create_mock_prediction(
+                {
+                    "status": "succeeded",
+                    "output": [
+                        "https://example.com/output1.jpg",
+                        "https://example.com/output2.jpg",
+                    ],
+                }
+            ),
+        ]
+    )
+
+    respx.get("https://example.com/output1.jpg").mock(
+        return_value=httpx.Response(200, content=b"fake image 1 data")
+    )
+    respx.get("https://example.com/output2.jpg").mock(
+        return_value=httpx.Response(200, content=b"fake image 2 data")
+    )
+
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector",
+        use_async=client_mode == ClientMode.ASYNC,
+        streaming=True,
+    )
+
+    if client_mode == ClientMode.ASYNC:
+        output = await hotdog_detector(prompt="hello world")
+    else:
+        output = hotdog_detector(prompt="hello world")
+
+    from replicate.use import OutputIterator
+
+    assert isinstance(output, OutputIterator)
+    # Convert to list to check contents
+    output_list = list(output)
+    assert len(output_list) == 2
+    assert all(isinstance(path, os.PathLike) for path in output_list)
+    assert get_path_url(output_list[0]) == "https://example.com/output1.jpg"
+    assert get_path_url(output_list[1]) == "https://example.com/output2.jpg"
+    assert all(os.path.exists(path) for path in output_list)
+    assert open(output_list[0], "rb").read() == b"fake image 1 data"
+    assert open(output_list[1], "rb").read() == b"fake image 2 data"
+
+
+def test_get_path_url_with_urlpath():
+    """Test get_path_url returns the URL for URLPath instances."""
+    from replicate.use import URLPath, get_path_url
+
+    url = "https://example.com/test.jpg"
+    path_proxy = URLPath(url)
+
+    result = get_path_url(path_proxy)
+    assert result == url
+
+
+def test_get_path_url_with_regular_path():
+    """Test get_path_url returns None for regular Path instances."""
+    from replicate.use import get_path_url
+
+    regular_path = Path("test.txt")
+
+    result = get_path_url(regular_path)
+    assert result is None
+
+
+def test_get_path_url_with_object_without_target():
+    """Test get_path_url returns None for objects without __replicate_target__."""
+    from replicate.use import get_path_url
+
+    # Test with a string
+    result = get_path_url("not a path")
+    assert result is None
+
+    # Test with a dict
+    result = get_path_url({"key": "value"})
+    assert result is None
+
+    # Test with None
+    result = get_path_url(None)
+    assert result is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_use_pathproxy_input_conversion(client_mode):
+    mock_model_endpoints()
+
+    file_request_mock = respx.get("https://example.com/input.jpg").mock(
+        return_value=httpx.Response(200, content=b"fake input image data")
+    )
+
+    # Create a URLPath instance
+    from replicate.use import URLPath
+
+    urlpath = URLPath("https://example.com/input.jpg")
+
+    # Set up a mock for the prediction creation to capture the request
+    request_body = None
+
+    def capture_request(request):
+        nonlocal request_body
+        request_body = request.read()
+        return httpx.Response(
+            201,
+            json={
+                "id": "pred789",
+                "model": "acme/hotdog-detector",
+                "version": "xyz123",
+                "urls": {
+                    "get": "https://api.replicate.com/v1/predictions/pred789",
+                    "cancel": "https://api.replicate.com/v1/predictions/pred789/cancel",
+                },
+                "created_at": "2024-01-01T00:00:00Z",
+                "source": "api",
+                "status": "processing",
+                "input": {"image": "https://example.com/input.jpg"},
+                "output": None,
+                "error": None,
+                "logs": "",
+            },
+        )
+
+    respx.post("https://api.replicate.com/v1/predictions").mock(
+        side_effect=capture_request
+    )
+
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
+    )
+    if client_mode == ClientMode.ASYNC:
+        await hotdog_detector.create(image=urlpath)
+    else:
+        hotdog_detector.create(image=urlpath)
+
+    # Verify the request body contains the URL, not the downloaded file
+    assert request_body
+    parsed_body = json.loads(request_body)
+    assert parsed_body["input"]["image"] == "https://example.com/input.jpg"
+
+    assert file_request_mock.call_count == 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_use_function_logs_method(client_mode):
+    mock_model_endpoints()
+    mock_prediction_endpoints(predictions=[create_mock_prediction()])
+
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
+    )
+    if client_mode == ClientMode.ASYNC:
+        run = await hotdog_detector.create(prompt="hello world")
+    else:
+        run = hotdog_detector.create(prompt="hello world")
+
+    if client_mode == ClientMode.ASYNC:
+        logs = await run.logs()
+    else:
+        logs = run.logs()
+
+    assert logs == "Starting prediction..."
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_use_function_logs_method_polling(client_mode):
+    mock_model_endpoints()
+
+    polling_responses = [
+        create_mock_prediction(
+            {
+                "logs": "Starting prediction...",
+            }
+        ),
+        create_mock_prediction(
+            {
+                "logs": "Starting prediction...\nProcessing input...",
+            }
+        ),
+    ]
+
+    mock_prediction_endpoints(predictions=polling_responses)
+
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
+    )
+    if client_mode == ClientMode.ASYNC:
+        run = await hotdog_detector.create(prompt="hello world")
+    else:
+        run = hotdog_detector.create(prompt="hello world")
+
+    if client_mode == ClientMode.ASYNC:
+        initial_logs = await run.logs()
+    else:
+        initial_logs = run.logs()
+    assert initial_logs == "Starting prediction..."
+ + if client_mode == ClientMode.ASYNC: + updated_logs = await run.logs() + else: + updated_logs = run.logs() + assert updated_logs == "Starting prediction...\nProcessing input..." + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_use_object_output_with_file_properties(client_mode): + mock_model_endpoints( + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "image": { + "type": "string", + "format": "uri", + }, + "count": {"type": "integer"}, + }, + } + } + } + } + } + ) + ] + ) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction(), + create_mock_prediction( + { + "status": "succeeded", + "output": { + "text": "Generated text", + "image": "https://example.com/generated.png", + "count": 42, + }, + } + ), + ] + ) + + respx.get("https://example.com/generated.png").mock( + return_value=httpx.Response(200, content=b"fake png data") + ) + + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") + + assert isinstance(output, dict) + assert output["text"] == "Generated text" + assert output["count"] == 42 + assert isinstance(output["image"], os.PathLike) + assert get_path_url(output["image"]) == "https://example.com/generated.png" + assert os.path.exists(output["image"]) + assert open(output["image"], "rb").read() == b"fake png data" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_use_object_output_with_file_list_property(client_mode): + mock_model_endpoints( + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "images": { + "type": "array", + "items": { + "type": "string", + "format": "uri", + }, + }, + }, + } + } + } + } + } + ) + ] + ) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction(), + create_mock_prediction( + { + "status": "succeeded", + "output": { + "text": "Generated text", + "images": [ + "https://example.com/image1.png", + "https://example.com/image2.png", + ], + }, + } + ), + ] + ) + + respx.get("https://example.com/image1.png").mock( + return_value=httpx.Response(200, content=b"fake png 1 data") + ) + respx.get("https://example.com/image2.png").mock( + return_value=httpx.Response(200, content=b"fake png 2 data") + ) + + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") + + assert isinstance(output, dict) + assert output["text"] == "Generated text" + assert isinstance(output["images"], list) + assert len(output["images"]) == 2 + assert all(isinstance(path, os.PathLike) for path in output["images"]) + assert get_path_url(output["images"][0]) == "https://example.com/image1.png" + assert get_path_url(output["images"][1]) == "https://example.com/image2.png" + assert all(os.path.exists(path) for path in output["images"]) + assert open(output["images"][0], "rb").read() == b"fake png 1 data" + assert open(output["images"][1], 
"rb").read() == b"fake png 2 data"