Skip to content

Commit e8acdb2

Browse files
committed
Use new URLPath instead of PathProxy
1 parent 4111e82 commit e8acdb2

File tree

4 files changed

+125
-111
lines changed

4 files changed

+125
-111
lines changed

README.md

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -525,9 +525,9 @@ To use a model:
525525
> For now `use()` MUST be called in the top level module scope. We may relax this in future.
526526
527527
```py
528-
from replicate import use
528+
import replicate
529529

530-
flux_dev = use("black-forest-labs/flux-dev")
530+
flux_dev = replicate.use("black-forest-labs/flux-dev")
531531
outputs = flux_dev(prompt="a cat wearing an amusing hat")
532532

533533
for output in outputs:
@@ -538,7 +538,7 @@ Models that output iterators will return iterators:
538538

539539

540540
```py
541-
claude = use("anthropic/claude-4-sonnet")
541+
claude = replicate.use("anthropic/claude-4-sonnet")
542542

543543
output = claude(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.")
544544

@@ -555,10 +555,10 @@ str(output) # "Here's a recipe to feed all of California (about 39 million peopl
555555
You can pass the results of one model directly into another:
556556

557557
```py
558-
from replicate import use
558+
import replicate
559559

560-
flux_dev = use("black-forest-labs/flux-dev")
561-
claude = use("anthropic/claude-4-sonnet")
560+
flux_dev = replicate.use("black-forest-labs/flux-dev")
561+
claude = replicate.use("anthropic/claude-4-sonnet")
562562

563563
images = flux_dev(prompt="a cat wearing an amusing hat")
564564

@@ -570,7 +570,7 @@ print(str(result)) # "This shows an image of a cat wearing a hat ..."
570570
To create an individual prediction that has not yet resolved, use the `create()` method:
571571

572572
```
573-
claude = use("anthropic/claude-4-sonnet")
573+
claude = replicate.use("anthropic/claude-4-sonnet")
574574
575575
prediction = claude.create(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.")
576576
@@ -579,13 +579,49 @@ prediction.logs() # get current logs (WIP)
579579
prediction.output() # get the output
580580
```
581581

582-
You can access the underlying URL for a Path object returned from a model call by using the `get_path_url()` helper.
582+
### Downloading file outputs
583+
584+
Output files are provided as Python [os.PathLike](https://docs.python.org/3.12/library/os.html#os.PathLike) objects. These are supported by most of the Python standard library like `open()` and `Path`, as well as third-party libraries like `pillow` and `ffmpeg-python`.
585+
586+
The first time the file is accessed it will be downloaded to a temporary directory on disk ready for use.
587+
588+
Here's an example of how to use the `pillow` package to convert file outputs:
583589

584590
```py
585-
from replicate import use
586-
from replicate.use import get_url_path
591+
import replicate
592+
from PIL import Image
593+
594+
flux_dev = replicate.use("black-forest-labs/flux-dev")
595+
596+
images = flux_dev(prompt="a cat wearing an amusing hat")
597+
for i, path in enumerate(images):
598+
with Image.open(path) as img:
599+
img.save(f"./output_{i}.png", format="PNG")
600+
```
601+
602+
For libraries that do not support `Path` or `PathLike` instances you can use `open()` as you would with any other file. For example to use `requests` to upload the file to a different location:
603+
604+
```py
605+
import replicate
606+
import requests
607+
608+
flux_dev = replicate.use("black-forest-labs/flux-dev")
609+
610+
images = flux_dev(prompt="a cat wearing an amusing hat")
611+
for path in images:
612+
with open(path, "rb") as f:
613+
r = requests.post("https://api.example.com/upload", files={"file": f})
614+
```
615+
616+
### Accessing outputs as HTTPS URLs
617+
618+
If you do not need to download the output to disk. You can access the underlying URL for a Path object returned from a model call by using the `get_path_url()` helper.
619+
620+
```py
621+
import replicate
622+
from replicate import get_url_path
587623

588-
flux_dev = use("black-forest-labs/flux-dev")
624+
flux_dev = replicate.use("black-forest-labs/flux-dev")
589625
outputs = flux_dev(prompt="a cat wearing an amusing hat")
590626

591627
for output in outputs:

replicate/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from replicate.client import Client
22
from replicate.pagination import async_paginate as _async_paginate
33
from replicate.pagination import paginate as _paginate
4-
from replicate.use import use
4+
from replicate.use import get_path_url, use
55

66
__all__ = [
77
"Client",
@@ -21,6 +21,7 @@
2121
"trainings",
2222
"webhooks",
2323
"default_client",
24+
"get_path_url",
2425
]
2526

2627
default_client = Client()

replicate/use.py

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# - [ ] Support text streaming
33
# - [ ] Support file streaming
44
# - [ ] Support asyncio variant
5+
import hashlib
56
import inspect
67
import os
78
import sys
@@ -138,7 +139,7 @@ def _process_iterator_item(item: Any, openapi_schema: dict) -> Any:
138139
# If items are file URLs, download them
139140
if items_schema.get("type") == "string" and items_schema.get("format") == "uri":
140141
if isinstance(item, str) and item.startswith(("http://", "https://")):
141-
return PathProxy(item)
142+
return URLPath(item)
142143

143144
return item
144145

@@ -154,7 +155,7 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
154155
# Handle direct string with format=uri
155156
if output_schema.get("type") == "string" and output_schema.get("format") == "uri":
156157
if isinstance(output, str) and output.startswith(("http://", "https://")):
157-
return PathProxy(output)
158+
return URLPath(output)
158159
return output
159160

160161
# Handle array of strings with format=uri
@@ -163,7 +164,7 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
163164
if items.get("type") == "string" and items.get("format") == "uri":
164165
if isinstance(output, list):
165166
return [
166-
PathProxy(url)
167+
URLPath(url)
167168
if isinstance(url, str) and url.startswith(("http://", "https://"))
168169
else url
169170
for url in output
@@ -187,15 +188,15 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
187188
if isinstance(value, str) and value.startswith(
188189
("http://", "https://")
189190
):
190-
result[prop_name] = PathProxy(value)
191+
result[prop_name] = URLPath(value)
191192

192193
# Array of files property
193194
elif prop_schema.get("type") == "array":
194195
items = prop_schema.get("items", {})
195196
if items.get("type") == "string" and items.get("format") == "uri":
196197
if isinstance(value, list):
197198
result[prop_name] = [
198-
PathProxy(url)
199+
URLPath(url)
199200
if isinstance(url, str)
200201
and url.startswith(("http://", "https://"))
201202
else url
@@ -233,49 +234,52 @@ def __str__(self) -> str:
233234
return str(self.iterator_factory())
234235

235236

236-
class PathProxy(Path):
237-
def __init__(self, target: str) -> None:
238-
path: Path | None = None
239-
240-
def ensure_path() -> Path:
241-
nonlocal path
242-
if path is None:
243-
path = _download_file(target)
244-
return path
237+
class URLPath(os.PathLike):
238+
"""
239+
A PathLike that defers filesystem ops until first use. Can be used with
240+
most Python file interfaces like `open()` and `pathlib.Path()`.
241+
See: https://docs.python.org/3.12/library/os.html#os.PathLike
242+
"""
245243

246-
object.__setattr__(self, "__replicate_target__", target)
247-
object.__setattr__(self, "__replicate_path__", ensure_path)
244+
def __init__(self, url: str) -> None:
245+
# store the original URL
246+
self.__url__ = url
248247

249-
def __getattribute__(self, name) -> Any:
250-
if name in ("__replicate_path__", "__replicate_target__"):
251-
return object.__getattribute__(self, name)
248+
# compute target path without touching the filesystem
249+
base = Path(tempfile.gettempdir())
250+
h = hashlib.sha256(self.__url__.encode("utf-8")).hexdigest()[:16]
251+
name = Path(httpx.URL(self.__url__).path).name or h
252+
self.__path__ = base / h / name
252253

253-
# TODO: We should cover other common properties on Path...
254-
if name == "__class__":
255-
return Path
254+
def __fspath__(self) -> str:
255+
# on first access, create dirs and download if missing
256+
if not self.__path__.exists():
257+
subdir = self.__path__.parent
258+
subdir.mkdir(parents=True, exist_ok=True)
259+
if not os.access(subdir, os.W_OK):
260+
raise PermissionError(f"Cannot write to {subdir!r}")
256261

257-
return getattr(object.__getattribute__(self, "__replicate_path__")(), name)
262+
with httpx.Client() as client, client.stream("GET", self.__url__) as resp:
263+
resp.raise_for_status()
264+
with open(self.__path__, "wb") as f:
265+
for chunk in resp.iter_bytes(chunk_size=16_384):
266+
f.write(chunk)
258267

259-
def __setattr__(self, name, value) -> None:
260-
if name in ("__replicate_path__", "__replicate_target__"):
261-
raise ValueError()
268+
return str(self.__path__)
262269

263-
object.__setattr__(
264-
object.__getattribute__(self, "__replicate_path__")(), name, value
265-
)
270+
def __str__(self) -> str:
271+
return str(self.__path__)
266272

267-
def __delattr__(self, name) -> None:
268-
if name in ("__replicate_path__", "__replicate_target__"):
269-
raise ValueError()
270-
delattr(object.__getattribute__(self, "__replicate_path__")(), name)
273+
def __repr__(self) -> str:
274+
return f"<URLPath url={self.__url__!r} path={self.__path__!r}>"
271275

272276

273277
def get_path_url(path: Any) -> str | None:
274278
"""
275279
Return the remote URL (if any) for a Path output from a model.
276280
"""
277281
try:
278-
return object.__getattribute__(path, "__replicate_target__")
282+
return object.__getattribute__(path, "__url__")
279283
except AttributeError:
280284
return None
281285

@@ -385,7 +389,7 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
385389
"""
386390
Start a prediction with the specified inputs.
387391
"""
388-
# Process inputs to convert concatenate OutputIterators to strings and PathProxy to URLs
392+
# Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs
389393
processed_inputs = {}
390394
for key, value in inputs.items():
391395
if isinstance(value, OutputIterator) and value.is_concatenate:

0 commit comments

Comments
 (0)