Skip to content

Commit d4474a4

Browse files
erbridgezeke
andauthored
Wrap required arguments up together (#20)
For now we only have `input`, but this allows us to change the API in future and matches the API more directly. Co-authored-by: Zeke Sikelianos <[email protected]>
1 parent 96b8e5e commit d4474a4

File tree

3 files changed

+25
-13
lines changed

3 files changed

+25
-13
lines changed

README.md

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ const prediction = await replicate
2626
"stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
2727
)
2828
.predict({
29-
prompt: "an astronaut riding on a horse",
29+
input: {
30+
prompt: "an astronaut riding on a horse",
31+
},
3032
});
3133

3234
console.log(prediction.output);
@@ -45,7 +47,9 @@ await replicate
4547
)
4648
.predict(
4749
{
48-
prompt: "an astronaut riding on a horse",
50+
input: {
51+
prompt: "an astronaut riding on a horse",
52+
},
4953
},
5054
{
5155
onUpdate: (prediction) => {
@@ -66,7 +70,9 @@ const prediction = await replicate
6670
"stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
6771
)
6872
.createPrediction({
69-
prompt: "an astronaut riding on a horse",
73+
input: {
74+
prompt: "an astronaut riding on a horse",
75+
},
7076
});
7177

7278
console.log(prediction.status); // "starting"
@@ -89,7 +95,9 @@ await replicate
8995
)
9096
.createPrediction(
9197
{
92-
prompt: "an astronaut riding on a horse",
98+
input: {
99+
prompt: "an astronaut riding on a horse",
100+
},
93101
},
94102
{
95103
// See https://replicate.com/docs/reference/http#create-prediction--webhook

lib/Model.js

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ export default class Model extends ReplicateObject {
4747
}
4848

4949
async predict(
50-
input,
50+
{ input },
5151
{
5252
onUpdate = noop,
5353
onTemporaryError = noop,
@@ -122,12 +122,16 @@ export default class Model extends ReplicateObject {
122122
return prediction;
123123
}
124124

125-
async createPrediction(input, { webhook, webhookEventsFilter } = {}) {
125+
async createPrediction({ input }, { webhook, webhookEventsFilter } = {}) {
126126
// This is here and not on `Prediction` because conceptually, a prediction
127127
// from a model "belongs" to the model. It's an odd feature of the API that
128128
// the prediction creation isn't an action on the model (or that it doesn't
129129
// actually use the model information, only the version), but we don't need
130130
// to expose that to users of this library.
131+
if (!input) {
132+
throw new ReplicateError("input is required");
133+
}
134+
131135
const predictionData = await this.client.request("POST /v1/predictions", {
132136
version: this.version,
133137
input,

lib/Model.test.js

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ describe("predict()", () => {
8686
);
8787

8888
await model.predict(
89-
{ text: "test text" },
89+
{ input: { text: "test text" } },
9090
{},
9191
{ defaultPollingInterval: 0 }
9292
);
@@ -128,7 +128,7 @@ describe("predict()", () => {
128128
.mockImplementation((action) => requestMockReturnValues[action]);
129129

130130
await model.predict(
131-
{ text: "test text" },
131+
{ input: { text: "test text" } },
132132
{},
133133
{ defaultPollingInterval: 0 }
134134
);
@@ -182,7 +182,7 @@ describe("predict()", () => {
182182
});
183183

184184
const prediction = await model.predict(
185-
{ text: "test text" },
185+
{ input: { text: "test text" } },
186186
{},
187187
{ defaultPollingInterval: 0 }
188188
);
@@ -237,7 +237,7 @@ describe("predict()", () => {
237237
const backoffFn = jest.fn(() => 0);
238238

239239
const prediction = await model.predict(
240-
{ text: "test text" },
240+
{ input: { text: "test text" } },
241241
{},
242242
{ defaultPollingInterval: 0, backoffFn }
243243
);
@@ -255,7 +255,7 @@ describe("createPrediction()", () => {
255255
status: PredictionStatus.SUCCEEDED,
256256
});
257257

258-
await model.createPrediction({ text: "test text" });
258+
await model.createPrediction({ input: { text: "test text" } });
259259

260260
expect(client.request).toHaveBeenCalledWith("POST /v1/predictions", {
261261
version: "testversion",
@@ -270,7 +270,7 @@ describe("createPrediction()", () => {
270270
});
271271

272272
await model.createPrediction(
273-
{ text: "test text" },
273+
{ input: { text: "test text" } },
274274
{ webhook: "http://test.host/webhook" }
275275
);
276276

@@ -288,7 +288,7 @@ describe("createPrediction()", () => {
288288
});
289289

290290
await model.createPrediction(
291-
{ text: "test text" },
291+
{ input: { text: "test text" } },
292292
{
293293
webhook: "http://test.host/webhook",
294294
webhookEventsFilter: ["output", "completed"],

0 commit comments

Comments
 (0)