Skip to content

Commit 97d7aaa

Browse files
authored
Add deployment endpoints (#131)
Document replicate.deployments.predictions.create in README
1 parent c1e838a commit 97d7aaa

File tree

5 files changed

+114
-1
lines changed

5 files changed

+114
-1
lines changed

README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,23 @@ const response = await replicate.trainings.list();
552552
}
553553
```
554554

555+
### `replicate.deployments.predictions.create`
556+
557+
```js
558+
const response = await replicate.deployments.predictions.create(deployment_owner, deployment_name, options);
559+
```
560+
561+
| name | type | description |
562+
| ------------------------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------- |
563+
| `deployment_owner` | string | **Required**. The name of the user or organization that owns the deployment |
564+
| `deployment_name` | string | **Required**. The name of the deployment |
565+
| `options.input` | object | **Required**. An object with the model's inputs |
566+
| `options.webhook` | string | An HTTPS URL for receiving a webhook when the prediction has new output |
567+
| `options.webhook_events_filter` | string[] | You can change which events trigger webhook requests by specifying webhook events (`start` \| `output` \| `logs` \| `completed`) |
568+
569+
Use `replicate.wait` to wait for a prediction to finish,
570+
or `replicate.predictions.cancel` to cancel a prediction before it finishes.
571+
555572
### `replicate.paginate`
556573

557574
Pass another method as an argument to iterate over results

index.d.ts

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ declare module 'replicate' {
4747
logs?: string;
4848
metrics?: {
4949
predict_time?: number;
50-
}
50+
};
5151
webhook?: string;
5252
webhook_events_filter?: WebhookEventType[];
5353
created_at: string;
@@ -156,5 +156,20 @@ declare module 'replicate' {
156156
cancel(training_id: string): Promise<Training>;
157157
list(): Promise<Page<Training>>;
158158
};
159+
160+
deployments: {
161+
predictions: {
162+
create(
163+
deployment_name: string,
164+
deployment_owner: string,
165+
options: {
166+
input: object;
167+
stream?: boolean;
168+
webhook?: string;
169+
webhook_events_filter?: WebhookEventType[];
170+
}
171+
): Promise<Prediction>;
172+
};
173+
};
159174
}
160175
}

index.js

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ const ApiError = require('./lib/error');
22
const { withAutomaticRetries } = require('./lib/util');
33

44
const collections = require('./lib/collections');
5+
const deployments = require('./lib/deployments');
56
const models = require('./lib/models');
67
const predictions = require('./lib/predictions');
78
const trainings = require('./lib/trainings');
@@ -69,6 +70,12 @@ class Replicate {
6970
cancel: trainings.cancel.bind(this),
7071
list: trainings.list.bind(this),
7172
};
73+
74+
this.deployments = {
75+
predictions: {
76+
create: deployments.predictions.create.bind(this),
77+
}
78+
};
7279
}
7380

7481
/**

index.test.ts

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,43 @@ describe('Replicate client', () => {
582582
});
583583
});
584584

585+
describe('deployments.predictions.create', () => {
586+
test('Calls the correct API route with the correct payload', async () => {
587+
nock(BASE_URL)
588+
.post('/deployments/replicate/greeter/predictions')
589+
.reply(200, {
590+
id: 'mfrgcyzzme2wkmbwgzrgmntcg',
591+
version:
592+
'5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa',
593+
urls: {
594+
get: 'https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq',
595+
cancel:
596+
'https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel',
597+
},
598+
created_at: '2022-09-10T09:44:22.165836Z',
599+
started_at: null,
600+
completed_at: null,
601+
status: 'starting',
602+
input: {
603+
text: 'Alice',
604+
},
605+
output: null,
606+
error: null,
607+
logs: null,
608+
metrics: {},
609+
});
610+
const prediction = await client.deployments.predictions.create("replicate", "greeter", {
611+
input: {
612+
text: 'Alice',
613+
},
614+
webhook: 'http://test.host/webhook',
615+
webhook_events_filter: [ 'output', 'completed' ],
616+
});
617+
expect(prediction.id).toBe('mfrgcyzzme2wkmbwgzrgmntcg');
618+
});
619+
// Add more tests for error handling, edge cases, etc.
620+
});
621+
585622
describe('run', () => {
586623
test('Calls the correct API routes', async () => {
587624
let firstPollingRequest = true;

lib/deployments.js

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/**
2+
* Create a new prediction with a deployment
3+
*
4+
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
5+
* @param {string} deployment_name - Required. The name of the deployment
6+
* @param {object} options
7+
* @param {object} options.input - Required. An object with the model inputs
8+
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
9+
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
10+
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
11+
* @returns {Promise<object>} Resolves with the created prediction data
12+
*/
13+
async function createPrediction(deployment_owner, deployment_name, options) {
14+
const { stream, ...data } = options;
15+
16+
if (data.webhook) {
17+
try {
18+
// eslint-disable-next-line no-new
19+
new URL(data.webhook);
20+
} catch (err) {
21+
throw new Error('Invalid webhook URL');
22+
}
23+
}
24+
25+
const response = await this.request(`/deployments/${deployment_owner}/${deployment_name}/predictions`, {
26+
method: 'POST',
27+
data: { ...data, stream },
28+
});
29+
30+
return response.json();
31+
}
32+
33+
module.exports = {
34+
predictions: {
35+
create: createPrediction,
36+
}
37+
};

0 commit comments

Comments
 (0)