Skip to content

Commit 3aaa16e

Browse files
authored
Add support for collections.list endpoint (#190)
See https://replicate.com/docs/reference/http#collections.list --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 2f84638 commit 3aaa16e

File tree

8 files changed

+18717
-5
lines changed

8 files changed

+18717
-5
lines changed

README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,38 @@ urlretrieve(out[0], "/tmp/out.png")
189189
background = Image.open("/tmp/out.png")
190190
```
191191

192+
## List models
193+
194+
You can the models you've created:
195+
196+
```python
197+
replicate.models.list()
198+
```
199+
200+
Lists of models are paginated. You can get the next page of models by passing the `next` property as an argument to the `list` method. Here's how you can get all the models you've created:
201+
202+
```python
203+
models = []
204+
page = replicate.models.list()
205+
206+
while page:
207+
models.extend(page.results)
208+
page = replicate.models.list(page.next) if page.next else None
209+
```
210+
211+
You can also find collections of featured models on Replicate:
212+
213+
```python
214+
>>> collections = replicate.collections.list()
215+
>>> collections[0].slug
216+
"vision-models"
217+
>>> collections[0].description
218+
"Multimodal large language models with vision capabilities like object detection and optical character recognition (OCR)"
219+
220+
>>> replicate.collections.get("text-to-image").models
221+
[<Model: stability-ai/sdxl>, ...]
222+
```
223+
192224
## Create a model
193225

194226
You can create a model for a user or organization

replicate/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
default_client = Client()
44
run = default_client.run
5+
collections = default_client.collections
56
hardware = default_client.hardware
67
deployments = default_client.deployments
78
models = default_client.models

replicate/client.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import httpx
1616

1717
from replicate.__about__ import __version__
18+
from replicate.collection import Collections
1819
from replicate.deployment import Deployments
1920
from replicate.exceptions import ModelError, ReplicateError
2021
from replicate.hardware import Hardwares
@@ -84,6 +85,13 @@ def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
8485

8586
return resp
8687

88+
@property
89+
def collections(self) -> Collections:
90+
"""
91+
Namespace for operations related to collections of models.
92+
"""
93+
return Collections(client=self)
94+
8795
@property
8896
def deployments(self) -> Deployments:
8997
"""

replicate/collection.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
2+
3+
from replicate.model import Model, Models
4+
from replicate.pagination import Page
5+
from replicate.resource import Namespace, Resource
6+
7+
if TYPE_CHECKING:
8+
from replicate.client import Client
9+
10+
11+
class Collection(Resource):
12+
"""
13+
A collection of models on Replicate.
14+
"""
15+
16+
slug: str
17+
"""The slug used to identify the collection."""
18+
19+
name: str
20+
"""The name of the collection."""
21+
22+
description: str
23+
"""A description of the collection."""
24+
25+
models: Optional[List[Model]] = None
26+
"""The models in the collection."""
27+
28+
def __iter__(self): # noqa: ANN204
29+
return iter(self.models)
30+
31+
def __getitem__(self, index) -> Optional[Model]:
32+
if self.models is not None:
33+
return self.models[index]
34+
35+
return None
36+
37+
def __len__(self) -> int:
38+
if self.models is not None:
39+
return len(self.models)
40+
41+
return 0
42+
43+
44+
class Collections(Namespace):
45+
"""
46+
A namespace for operations related to collections of models.
47+
"""
48+
49+
model = Collection
50+
51+
_models: Models
52+
53+
def __init__(self, client: "Client") -> None:
54+
self._models = Models(client)
55+
super().__init__(client)
56+
57+
def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Collection]: # noqa: F821
58+
"""
59+
List collections of models.
60+
61+
Parameters:
62+
cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`.
63+
Returns:
64+
Page[Collection]: A page of of model collections.
65+
Raises:
66+
ValueError: If `cursor` is `None`.
67+
"""
68+
69+
if cursor is None:
70+
raise ValueError("cursor cannot be None")
71+
72+
resp = self._client._request(
73+
"GET", "/v1/collections" if cursor is ... else cursor
74+
)
75+
76+
return Page[Collection](self._client, self, **resp.json())
77+
78+
def get(self, slug: str) -> Collection:
79+
"""Get a model by name.
80+
81+
Args:
82+
name: The name of the model, in the format `owner/model-name`.
83+
Returns:
84+
The model.
85+
"""
86+
87+
resp = self._client._request("GET", f"/v1/collections/{slug}")
88+
89+
return self._prepare_model(resp.json())
90+
91+
def _prepare_model(self, attrs: Union[Collection, Dict]) -> Collection:
92+
if isinstance(attrs, Resource):
93+
attrs.id = attrs.slug
94+
95+
if attrs.models is not None:
96+
attrs.models = [self._models._prepare_model(m) for m in attrs.models]
97+
elif isinstance(attrs, dict):
98+
attrs["id"] = attrs["slug"]
99+
100+
if "models" in attrs:
101+
attrs["models"] = [
102+
self._models._prepare_model(m) for m in attrs["models"]
103+
]
104+
105+
return super()._prepare_model(attrs)

replicate/hardware.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,12 @@ def list(self) -> List[Hardware]:
3535
"""
3636

3737
resp = self._client._request("GET", "/v1/hardware")
38-
hardware = resp.json()
39-
return [self._prepare_model(obj) for obj in hardware]
38+
return [self._prepare_model(obj) for obj in resp.json()]
4039

4140
def _prepare_model(self, attrs: Union[Hardware, Dict]) -> Hardware:
4241
if isinstance(attrs, Resource):
4342
attrs.id = attrs.sku
4443
elif isinstance(attrs, dict):
4544
attrs["id"] = attrs["sku"]
4645

47-
hardware = super()._prepare_model(attrs)
48-
49-
return hardware
46+
return super()._prepare_model(attrs)

tests/cassettes/collections-get.yaml

Lines changed: 18285 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)