Skip to content

Generate TypeScript definitions from source #189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@ jobs:
- name: Build tarball
id: pack
run: |
npm clean-install
echo "tarball-name=$(npm --loglevel error pack)" >> $GITHUB_OUTPUT
- uses: actions/upload-artifact@v3
with:
315 changes: 0 additions & 315 deletions index.d.ts

This file was deleted.

76 changes: 65 additions & 11 deletions index.js
Original file line number Diff line number Diff line change
@@ -41,30 +41,54 @@ class Replicate {
/**
* Create a new Replicate API client instance.
*
* @param {object} options - Configuration options for the client
* @param {string} options.auth - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable.
* @param {string} options.userAgent - Identifier of your app
* @example
* // Create a new Replicate API client instance
* const Replicate = require("replicate");
* const replicate = new Replicate({
* // get your token from https://replicate.com/account
* auth: process.env.REPLICATE_API_TOKEN,
* userAgent: "my-app/1.2.3"
* });
*
* // Run a model and await the result:
* const model = 'owner/model:version-id'
* const input = {text: 'Hello, world!'}
* const output = await replicate.run(model, { input });
*
* @param {Object} [options] - Configuration options for the client
* @param {string} [options.auth] - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable.
* @param {string} [options.userAgent] - Identifier of your app
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
*/
constructor(options = {}) {
/** @type {string} */
this.auth =
options.auth ||
(typeof process !== "undefined" ? process.env.REPLICATE_API_TOKEN : null);

/** @type {string} */
this.userAgent =
options.userAgent || `replicate-javascript/${packageJSON.version}`;

/** @type {string} */
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";

/** @type {fetch} */
this.fetch = options.fetch || globalThis.fetch;

/** @type {accounts} */
this.accounts = {
current: accounts.current.bind(this),
};

/** @type {collections} */
this.collections = {
list: collections.list.bind(this),
get: collections.get.bind(this),
};

/** @type {deployments} */
this.deployments = {
get: deployments.get.bind(this),
create: deployments.create.bind(this),
@@ -75,10 +99,12 @@ class Replicate {
},
};

/** @type {hardware} */
this.hardware = {
list: hardware.list.bind(this),
};

/** @type {models} */
this.models = {
get: models.get.bind(this),
list: models.list.bind(this),
@@ -89,20 +115,23 @@ class Replicate {
},
};

/** @type {predictions} */
this.predictions = {
create: predictions.create.bind(this),
get: predictions.get.bind(this),
cancel: predictions.cancel.bind(this),
list: predictions.list.bind(this),
};

/** @type {trainings} */
this.trainings = {
create: trainings.create.bind(this),
get: trainings.get.bind(this),
cancel: trainings.cancel.bind(this),
list: trainings.list.bind(this),
};

/** @type {webhooks} */
this.webhooks = {
default: {
secret: {
@@ -115,18 +144,18 @@ class Replicate {
/**
* Run a model and wait for its output.
*
* @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version"
* @param {`${string}/${string}` | `${string}/${string}:${string}`} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version"
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {object} [options.wait] - Options for waiting for the prediction to finish
* @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {WebhookEventType[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
* @param {Function} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time its updated while polling for completion, and when it's completed.
* @param {(p: Prediction) => void} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time its updated while polling for completion, and when it's completed.
* @throws {Error} If the reference is invalid
* @throws {Error} If the prediction failed
* @returns {Promise<object>} - Resolves with the output of running the model
* @returns {Promise<Prediction>} - Resolves with the output of running the model
*/
async run(ref, options, progress) {
const { wait, ...data } = options;
@@ -262,7 +291,7 @@ class Replicate {
/**
* Stream a model and wait for its output.
*
* @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}"
* @param {string} ref - Required. The model version identifier in the format "{owner}/{name}:{version}"
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
@@ -315,8 +344,10 @@ class Replicate {
* for await (const page of replicate.paginate(replicate.predictions.list) {
* console.log(page);
* }
* @param {Function} endpoint - Function that returns a promise for the next page of results
* @yields {object[]} Each page of results
* @template T
* @param {() => Promise<Page<T>>} endpoint - Function that returns a promise for the next page of results
* @yields {T[]} Each page of results
* @returns {AsyncGenerator<T[], void, unknown>}
*/
async *paginate(endpoint) {
const response = await endpoint();
@@ -342,7 +373,7 @@ class Replicate {
* @param {Function} [stop] - Async callback function that is called after each polling attempt. Receives the prediction object as an argument. Return false to cancel polling.
* @throws {Error} If the prediction doesn't complete within the maximum number of attempts
* @throws {Error} If the prediction failed
* @returns {Promise<object>} Resolves with the completed prediction object
* @returns {Promise<Prediction>} Resolves with the completed prediction object
*/
async wait(prediction, options, stop) {
const { id } = prediction;
@@ -391,3 +422,26 @@ class Replicate {
module.exports = Replicate;
module.exports.validateWebhook = validateWebhook;
module.exports.parseProgressFromLogs = parseProgressFromLogs;

// - Type Definitions

/**
* @typedef {import("./lib/error")} ApiError
* @typedef {import("./lib/types").Account} Account
* @typedef {import("./lib/types").Collection} Collection
* @typedef {import("./lib/types").Deployment} Deployment
* @typedef {import("./lib/types").ModelVersion} ModelVersion
* @typedef {import("./lib/types").Hardware} Hardware
* @typedef {import("./lib/types").Model} Model
* @typedef {import("./lib/types").Prediction} Prediction
* @typedef {import("./lib/types").Training} Training
* @typedef {import("./lib/types").ServerSentEvent} ServerSentEvent
* @typedef {import("./lib/types").Status} Status
* @typedef {import("./lib/types").Visibility} Visibility
* @typedef {import("./lib/types").WebhookEventType} WebhookEventType
*/

/**
* @template T
* @typedef {import("./lib/types").Page<T>} Page
*/
31 changes: 10 additions & 21 deletions index.test.ts
Original file line number Diff line number Diff line change
@@ -5,9 +5,8 @@ import Replicate, {
Prediction,
validateWebhook,
parseProgressFromLogs,
} from "replicate";
} from "./";
import nock from "nock";
import { Readable } from "node:stream";
import { createReadableStream } from "./lib/stream";

let client: Replicate;
@@ -791,10 +790,8 @@ describe("Replicate client", () => {
},
configuration: {
hardware: "gpu-t4",
scaling: {
min_instances: 1,
max_instances: 5,
},
min_instances: 1,
max_instances: 5,
},
},
});
@@ -832,10 +829,8 @@ describe("Replicate client", () => {
},
configuration: {
hardware: "gpu-t4",
scaling: {
min_instances: 1,
max_instances: 5,
},
min_instances: 1,
max_instances: 5,
},
},
});
@@ -878,10 +873,8 @@ describe("Replicate client", () => {
},
configuration: {
hardware: "gpu-a40-large",
scaling: {
min_instances: 3,
max_instances: 10,
},
min_instances: 3,
max_instances: 10,
},
},
});
@@ -905,12 +898,8 @@ describe("Replicate client", () => {
expect(deployment.current_release.configuration.hardware).toBe(
"gpu-a40-large"
);
expect(
deployment.current_release.configuration.scaling?.min_instances
).toBe(3);
expect(
deployment.current_release.configuration.scaling?.max_instances
).toBe(10);
expect(deployment.current_release.configuration.min_instances).toBe(3);
expect(deployment.current_release.configuration.max_instances).toBe(10);
});
// Add more tests for error handling, edge cases, etc.
});
@@ -935,7 +924,7 @@ describe("Replicate client", () => {
});

const deployments = await client.deployments.list();
expect(deployments.results.length).toBe(1)
expect(deployments.results.length).toBe(1);
});
// Add more tests for pagination, error handling, edge cases, etc.
});
135 changes: 135 additions & 0 deletions integration/typescript/types.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import {
Account,
ApiError,
Collection,
Deployment,
Hardware,
Model,
ModelVersion,
Page,
Prediction,
Status,
Training,
Visibility,
WebhookEventType,
} from "replicate";

export type Equals<X, Y> = (<T>() => T extends X ? 1 : 2) extends <
T,
>() => T extends Y ? 1 : 2
? true
: false;

type AssertFalse<A extends false> = A;

// @ts-expect-error
export type TestAssertion = AssertFalse<Equals<any, any>>;

export type TestAccount = AssertFalse<Equals<Account, any>>;
export type TestApiError = AssertFalse<Equals<ApiError, any>>;
export type TestCollection = AssertFalse<Equals<Collection, any>>;
export type TestDeployment = AssertFalse<Equals<Deployment, any>>;
export type TestHardware = AssertFalse<Equals<Hardware, any>>;
export type TestModel = AssertFalse<Equals<Model, any>>;
export type TestModelVersion = AssertFalse<Equals<ModelVersion, any>>;
export type TestPage = AssertFalse<Equals<Page<unknown>, any>>;
export type TestPrediction = AssertFalse<Equals<Prediction, any>>;
export type TestStatus = AssertFalse<Equals<Status, any>>;
export type TestTraining = AssertFalse<Equals<Training, any>>;
export type TestVisibility = AssertFalse<Equals<Visibility, any>>;
export type TestWebhookEventType = AssertFalse<Equals<WebhookEventType, any>>;

// NOTE: We export the constants to avoid unused varaible issues.

export const account: Account = {
type: "user",
name: "",
username: "",
github_url: "",
};
export const collection: Collection = {
name: "",
slug: "",
description: "",
models: [],
};
export const deployment: Deployment = {
owner: "",
name: "",
current_release: {
number: 1,
model: "",
version: "",
created_at: "",
created_by: {
type: "user",
username: "",
name: "",
github_url: "",
},
configuration: {
hardware: "gpu-a100",
min_instances: 0,
max_instances: 5,
},
},
};
export const status: Status = "starting";
export const visibility: Visibility = "public";
export const webhookType: WebhookEventType = "start";
export const err: ApiError = Object.assign(new Error(), {
request: new Request("file://"),
response: new Response(),
});
export const hardware: Hardware = { sku: "", name: "" };
export const model: Model = {
url: "",
owner: "",
name: "",
description: "",
visibility: "public",
github_url: "",
paper_url: "",
license_url: "",
run_count: 10,
cover_image_url: "",
default_example: undefined,
latest_version: undefined,
};
export const version: ModelVersion = {
id: "",
created_at: "",
cog_version: "",
openapi_schema: "",
};
export const prediction: Prediction = {
id: "",
status: "starting",
model: "",
version: "",
input: {},
output: {},
source: "api",
error: undefined,
logs: "",
metrics: {
predict_time: 100,
},
webhook: "",
webhook_events_filter: [],
created_at: "",
started_at: "",
completed_at: "",
urls: {
get: "",
cancel: "",
stream: "",
},
};
export const training: Training = prediction;

export const page: Page<ModelVersion> = {
previous: "",
next: "",
results: [version],
};
4 changes: 3 additions & 1 deletion jsconfig.json
Original file line number Diff line number Diff line change
@@ -6,9 +6,11 @@
"target": "ES2020",
"resolveJsonModule": true,
"strictNullChecks": true,
"strictFunctionTypes": true
"strictFunctionTypes": true,
"types": [],
},
"exclude": [
"dist",
"node_modules",
"**/node_modules/*"
]
4 changes: 3 additions & 1 deletion lib/accounts.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
/** @typedef {import("./types").Account} Account */

/**
* Get the current account
*
* @returns {Promise<object>} Resolves with the current account
* @returns {Promise<Account>} Resolves with the current account
*/
async function getCurrentAccount() {
const response = await this.request("/account", {
10 changes: 8 additions & 2 deletions lib/collections.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
/** @typedef {import("./types").Collection} Collection */
/**
* @template T
* @typedef {import("./types").Page<T>} Page
*/

/**
* Fetch a model collection
*
* @param {string} collection_slug - Required. The slug of the collection. See http://replicate.com/collections
* @returns {Promise<object>} - Resolves with the collection data
* @returns {Promise<Collection>} - Resolves with the collection data
*/
async function getCollection(collection_slug) {
const response = await this.request(`/collections/${collection_slug}`, {
@@ -15,7 +21,7 @@ async function getCollection(collection_slug) {
/**
* Fetch a list of model collections
*
* @returns {Promise<object>} - Resolves with the collections data
* @returns {Promise<Page<Collection>>} - Resolves with the collections data
*/
async function listCollections() {
const response = await this.request("/collections", {
37 changes: 22 additions & 15 deletions lib/deployments.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
/**
* @template T
* @typedef {import("./types").Page<T>} Page
*/
/** @typedef {import("./types").Deployment} Deployment */
/** @typedef {import("./types").Prediction} Prediction */
/** @typedef {import("./types").WebhookEventType} WebhookEventType */

const { transformFileInputs } = require("./util");

/**
@@ -6,18 +14,17 @@ const { transformFileInputs } = require("./util");
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
* @param {string} deployment_name - Required. The name of the deployment
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {unknown} options.input - Required. An object with the model inputs
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @returns {Promise<object>} Resolves with the created prediction data
* @param {WebhookEventType[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @returns {Promise<Prediction>} Resolves with the created prediction data
*/
async function createPrediction(deployment_owner, deployment_name, options) {
const { stream, input, ...data } = options;

if (data.webhook) {
try {
// eslint-disable-next-line no-new
new URL(data.webhook);
} catch (err) {
throw new Error("Invalid webhook URL");
@@ -44,7 +51,7 @@ async function createPrediction(deployment_owner, deployment_name, options) {
*
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
* @param {string} deployment_name - Required. The name of the deployment
* @returns {Promise<object>} Resolves with the deployment data
* @returns {Promise<Deployment>} Resolves with the deployment data
*/
async function getDeployment(deployment_owner, deployment_name) {
const response = await this.request(
@@ -71,32 +78,32 @@ async function getDeployment(deployment_owner, deployment_name) {
* Create a deployment
*
* @param {DeploymentCreateRequest} config - Required. The deployment config.
* @returns {Promise<object>} Resolves with the deployment data
* @returns {Promise<Deployment>} Resolves with the deployment data
*/
async function createDeployment(deployment_config) {
async function createDeployment(config) {
const response = await this.request("/deployments", {
method: "POST",
data: deployment_config,
data: config,
});

return response.json();
}

/**
* @typedef {Object} DeploymentUpdateRequest - Request body for `deployments.update`
* @property {string} version - the 64-character string ID of the model version that you want to deploy
* @property {string} hardware - the SKU for the hardware used to run the model, via `replicate.hardware.list()`
* @property {number} min_instances - the minimum number of instances for scaling
* @property {number} max_instances - the maximum number of instances for scaling
* @property {string=} version - the 64-character string ID of the model version that you want to deploy
* @property {string=} hardware - the SKU for the hardware used to run the model, via `replicate.hardware.list()`
* @property {number=} min_instances - the minimum number of instances for scaling
* @property {number=} max_instances - the maximum number of instances for scaling
*/

/**
* Update an existing deployment
*
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
* @param {string} deployment_name - Required. The name of the deployment
* @param {DeploymentUpdateRequest} deployment_config - Required. The deployment changes.
* @returns {Promise<object>} Resolves with the deployment data
* @param {DeploymentUpdateRequest | {version: string} | {hardware: string} | {min_instances: number} | {max_instance: number}} deployment_config - Required. The deployment changes.
* @returns {Promise<Deployment>} Resolves with the deployment data
*/
async function updateDeployment(
deployment_owner,
@@ -117,7 +124,7 @@ async function updateDeployment(
/**
* List all deployments
*
* @returns {Promise<object>} - Resolves with a page of deployments
* @returns {Promise<Page<Deployment>>} - Resolves with a page of deployments
*/
async function listDeployments() {
const response = await this.request("/deployments", {
3 changes: 2 additions & 1 deletion lib/hardware.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
/** @typedef {import("./types").Hardware} Hardware */
/**
* List hardware
*
* @returns {Promise<object[]>} Resolves with the array of hardware
* @returns {Promise<Hardware[]>} Resolves with the array of hardware
*/
async function listHardware() {
const response = await this.request("/hardware", {
12 changes: 6 additions & 6 deletions lib/identifier.js
Original file line number Diff line number Diff line change
@@ -2,21 +2,21 @@
* A reference to a model version in the format `owner/name` or `owner/name:version`.
*/
class ModelVersionIdentifier {
/*
* @param {string} Required. The model owner.
* @param {string} Required. The model name.
* @param {string} The model version.
/**
* @param {string} owner Required. The model owner.
* @param {string} name Required. The model name.
* @param {string | null=} version The model version.
*/
constructor(owner, name, version = null) {
this.owner = owner;
this.name = name;
this.version = version;
}

/*
/**
* Parse a reference to a model version
*
* @param {string}
* @param {string} ref
* @returns {ModelVersionIdentifier}
* @throws {Error} If the reference is invalid.
*/
29 changes: 19 additions & 10 deletions lib/models.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
/** @typedef {import("./types").Model} Model */
/** @typedef {import("./types").ModelVersion} ModelVersion */
/** @typedef {import("./types").Prediction} Prediction */
/** @typedef {import("./types").Visibility} Visibility */
/**
* @template T
* @typedef {import("./types").Page<T>} Page
*/

/**
* Get information about a model
*
* @param {string} model_owner - Required. The name of the user or organization that owns the model
* @param {string} model_name - Required. The name of the model
* @returns {Promise<object>} Resolves with the model data
* @returns {Promise<Model>} Resolves with the model data
*/
async function getModel(model_owner, model_name) {
const response = await this.request(`/models/${model_owner}/${model_name}`, {
@@ -18,7 +27,7 @@ async function getModel(model_owner, model_name) {
*
* @param {string} model_owner - Required. The name of the user or organization that owns the model
* @param {string} model_name - Required. The name of the model
* @returns {Promise<object>} Resolves with the list of model versions
* @returns {Promise<Page<ModelVersion>>} Resolves with the list of model versions
*/
async function listModelVersions(model_owner, model_name) {
const response = await this.request(
@@ -37,7 +46,7 @@ async function listModelVersions(model_owner, model_name) {
* @param {string} model_owner - Required. The name of the user or organization that owns the model
* @param {string} model_name - Required. The name of the model
* @param {string} version_id - Required. The model version
* @returns {Promise<object>} Resolves with the model version data
* @returns {Promise<ModelVersion>} Resolves with the model version data
*/
async function getModelVersion(model_owner, model_name, version_id) {
const response = await this.request(
@@ -53,7 +62,7 @@ async function getModelVersion(model_owner, model_name, version_id) {
/**
* List all public models
*
* @returns {Promise<object>} Resolves with the model version data
* @returns {Promise<Page<Model>>} Resolves with the model version data
*/
async function listModels() {
const response = await this.request("/models", {
@@ -69,14 +78,14 @@ async function listModels() {
* @param {string} model_owner - Required. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization.
* @param {string} model_name - Required. The name of the model. This must be unique among all models owned by the user or organization.
* @param {object} options
* @param {("public"|"private")} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model.
* @param {Visibility} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model.
* @param {string} options.hardware - Required. The SKU for the hardware used to run the model. Possible values can be found by calling `Replicate.hardware.list()`.
* @param {string} options.description - A description of the model.
* @param {string} options.github_url - A URL for the model's source code on GitHub.
* @param {string} options.paper_url - A URL for the model's paper.
* @param {string} options.license_url - A URL for the model's license.
* @param {string} options.cover_image_url - A URL for the model's cover image. This should be an image file.
* @returns {Promise<object>} Resolves with the model version data
* @param {string=} options.github_url - A URL for the model's source code on GitHub.
* @param {string=} options.paper_url - A URL for the model's paper.
* @param {string=} options.license_url - A URL for the model's license.
* @param {string=} options.cover_image_url - A URL for the model's cover image. This should be an image file.
* @returns {Promise<Model>} Resolves with the model version data
*/
async function createModel(model_owner, model_name, options) {
const data = { owner: model_owner, name: model_name, ...options };
40 changes: 28 additions & 12 deletions lib/predictions.js
Original file line number Diff line number Diff line change
@@ -1,16 +1,32 @@
/**
* @template T
* @typedef {import("./types").Page<T>} Page
*/

/**
* @typedef {import("./types").Prediction} Prediction
* @typedef {Object} BasePredictionOptions
* @property {unknown} input - Required. An object with the model inputs
* @property {string} [webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @property {string[]} [webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @property {boolean} [stream] - Whether to stream the prediction output. Defaults to false
*
* @typedef {Object} ModelPredictionOptions
* @property {string} model The model name (for official models)
* @property {never=} version
*
* @typedef {Object} VersionPredictionOptions
* @property {string} version The model version
* @property {never=} model
*/

const { transformFileInputs } = require("./util");

/**
* Create a new prediction
*
* @param {object} options
* @param {string} options.model - The model.
* @param {string} options.version - The model version.
* @param {object} options.input - Required. An object with the model inputs
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
* @returns {Promise<object>} Resolves with the created prediction
* @param {BasePredictionOptions & (ModelPredictionOptions | VersionPredictionOptions)} options
* @returns {Promise<Prediction>} Resolves with the created prediction
*/
async function createPrediction(options) {
const { model, version, stream, input, ...data } = options;
@@ -54,8 +70,8 @@ async function createPrediction(options) {
/**
* Fetch a prediction by ID
*
* @param {number} prediction_id - Required. The prediction ID
* @returns {Promise<object>} Resolves with the prediction data
* @param {string} prediction_id - Required. The prediction ID
* @returns {Promise<Prediction>} Resolves with the prediction data
*/
async function getPrediction(prediction_id) {
const response = await this.request(`/predictions/${prediction_id}`, {
@@ -69,7 +85,7 @@ async function getPrediction(prediction_id) {
* Cancel a prediction by ID
*
* @param {string} prediction_id - Required. The training ID
* @returns {Promise<object>} Resolves with the data for the training
* @returns {Promise<Prediction>} Resolves with the data for the training
*/
async function cancelPrediction(prediction_id) {
const response = await this.request(`/predictions/${prediction_id}/cancel`, {
@@ -82,7 +98,7 @@ async function cancelPrediction(prediction_id) {
/**
* List all predictions
*
* @returns {Promise<object>} - Resolves with a page of predictions
* @returns {Promise<Page<Prediction>>} - Resolves with a page of predictions
*/
async function listPredictions() {
const response = await this.request("/predictions", {
2 changes: 1 addition & 1 deletion lib/stream.js
Original file line number Diff line number Diff line change
@@ -46,7 +46,7 @@ class ServerSentEvent {
*
* @param {object} config
* @param {string} config.url The URL to connect to.
* @param {typeof fetch} [config.fetch] The URL to connect to.
* @param {(url: URL | RequestInfo, init?: RequestInit | undefined) => Promise<Response>} [config.fetch] The URL to connect to.
* @param {object} [config.options] The EventSource options.
* @returns {ReadableStream<ServerSentEvent> & AsyncIterable<ServerSentEvent>}
*/
16 changes: 11 additions & 5 deletions lib/trainings.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
/**
* @template T
* @typedef {import("./types").Page<T>} Page
*/
/** @typedef {import("./types").Training} Training */

/**
* Create a new training
*
@@ -6,10 +12,10 @@
* @param {string} version_id - Required. The version ID
* @param {object} options
* @param {string} options.destination - Required. The destination for the trained version in the form "{username}/{model_name}"
* @param {object} options.input - Required. An object with the model inputs
* @param {unknown} options.input - Required. An object with the model inputs
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the training updates
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @returns {Promise<object>} Resolves with the data for the created training
* @returns {Promise<Training>} Resolves with the data for the created training
*/
async function createTraining(model_owner, model_name, version_id, options) {
const { ...data } = options;
@@ -38,7 +44,7 @@ async function createTraining(model_owner, model_name, version_id, options) {
* Fetch a training by ID
*
* @param {string} training_id - Required. The training ID
* @returns {Promise<object>} Resolves with the data for the training
* @returns {Promise<Training>} Resolves with the data for the training
*/
async function getTraining(training_id) {
const response = await this.request(`/trainings/${training_id}`, {
@@ -52,7 +58,7 @@ async function getTraining(training_id) {
* Cancel a training by ID
*
* @param {string} training_id - Required. The training ID
* @returns {Promise<object>} Resolves with the data for the training
* @returns {Promise<Training>} Resolves with the data for the training
*/
async function cancelTraining(training_id) {
const response = await this.request(`/trainings/${training_id}/cancel`, {
@@ -65,7 +71,7 @@ async function cancelTraining(training_id) {
/**
* List all trainings
*
* @returns {Promise<object>} - Resolves with a page of trainings
* @returns {Promise<Page<Training>>} - Resolves with a page of trainings
*/
async function listTrainings() {
const response = await this.request("/trainings", {
91 changes: 91 additions & 0 deletions lib/types.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/**
* @typedef {"starting" | "processing" | "succeeded" | "failed" | "canceled"} Status
* @typedef {"public" | "private"} Visibility
* @typedef {"start" | "output" | "logs" | "completed"} WebhookEventType
*
* @typedef {Object} Account
* @property {"user" | "organization"} type
* @property {string} username
* @property {string} name
* @property {string=} github_url
*
* @typedef {Object} Collection
* @property {string} name
* @property {string} slug
* @property {string} description
* @property {Model[]=} models
*
* @typedef {Object} Deployment
* @property {string} owner
* @property {string} name
* @property {object} current_release
* @property {number} current_release.number
* @property {string} current_release.model
* @property {string} current_release.version
* @property {string} current_release.created_at
* @property {Account} current_release.created_by
* @property {object} current_release.configuration
* @property {string} current_release.configuration.hardware
* @property {number} current_release.configuration.min_instances
* @property {number} current_release.configuration.max_instances
*
* @typedef {Object} Hardware
* @property {string} sku
* @property {string} name
*
* @typedef {Object} Model
* @property {string} url
* @property {string} owner
* @property {string} name
* @property {string=} description
* @property {Visibility} visibility
* @property {string=} github_url
* @property {string=} paper_url
* @property {string=} license_url
* @property {number} run_count
* @property {string=} cover_image_url
* @property {Prediction=} default_example
* @property {ModelVersion=} latest_version
*
* @typedef {Object} ModelVersion
* @property {string} id
* @property {string} created_at
* @property {string} cog_version
* @property {string} openapi_schema
*
* @typedef {Object} Prediction
* @property {string} id
* @property {Status} status
* @property {string=} model
* @property {string} version
* @property {object} input
* @property {unknown=} output
* @property {"api" | "web"} source
* @property {unknown=} error
* @property {string=} logs
* @property {{predict_time?: number}=} metrics
* @property {string=} webhook
* @property {WebhookEventType[]=} webhook_events_filter
* @property {string} created_at
* @property {string=} started_at
* @property {string=} completed_at
* @property {{get: string; cancel: string; stream?: string}} urls
*
* @typedef {Prediction} Training
*
* @typedef {Object} ServerSentEvent
* @property {string} event
* @property {string} data
* @property {string=} id
* @property {number=} retry
*/

/**
* @template T
* @typedef {Object} Page
* @property {string=} previous
* @property {string=} next
* @property {T[]} results
*/

module.exports = {};
31 changes: 8 additions & 23 deletions lib/util.js
Original file line number Diff line number Diff line change
@@ -1,31 +1,16 @@
const ApiError = require("./error");

/**
* @see {@link validateWebhook}
* @overload
* @param {object} requestData - The request data
* @param {string} requestData.id - The webhook ID header from the incoming request.
* @param {string} requestData.timestamp - The webhook timestamp header from the incoming request.
* @param {string} requestData.body - The raw body of the incoming webhook request.
* @param {string} requestData.secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method.
* @param {string} requestData.signature - The webhook signature header from the incoming request, comprising one or more space-delimited signatures.
*/

/**
* @see {@link validateWebhook}
* @overload
* @param {object} requestData - The request object
* @param {object} requestData.headers - The request headers
* @param {string} requestData.headers["webhook-id"] - The webhook ID header from the incoming request
* @param {string} requestData.headers["webhook-timestamp"] - The webhook timestamp header from the incoming request
* @param {string} requestData.headers["webhook-signature"] - The webhook signature header from the incoming request, comprising one or more space-delimited signatures
* @param {string} requestData.body - The raw body of the incoming webhook request
* @param {string} secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method
*/

/**
* Validate a webhook signature
*
* @typedef {Object} WebhookPayload
* @property {string} id - The webhook ID header from the incoming request.
* @property {string} timestamp - The webhook timestamp header from the incoming request.
* @property {string} body - The raw body of the incoming webhook request.
* @property {string} signature - The webhook signature header from the incoming request, comprising one or more space-delimited signatures.
*
* @param {Request | WebhookPayload} requestData
* @param {string} secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method.
* @returns {Promise<boolean>} - True if the signature is valid
* @throws {Error} - If the request is missing required headers, body, or secret
*/
2 changes: 1 addition & 1 deletion lib/webhooks.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/**
* Get the default webhook signing secret
*
* @returns {Promise<object>} Resolves with the signing secret for the default webhook
* @returns {Promise<{key: string}>} Resolves with the signing secret for the default webhook
*/
async function getDefaultWebhookSecret() {
const response = await this.request("/webhooks/default/secret", {
20 changes: 11 additions & 9 deletions package.json
Original file line number Diff line number Diff line change
@@ -8,11 +8,12 @@
"license": "Apache-2.0",
"main": "index.js",
"type": "commonjs",
"types": "index.d.ts",
"types": "dist/types/index.d.ts",
"files": [
"CONTRIBUTING.md",
"LICENSE",
"README.md",
"dist/**/*",
"index.d.ts",
"index.js",
"lib/**/*.js",
@@ -26,21 +27,22 @@
"yarn": ">=1.7.0"
},
"scripts": {
"build": "npm run build:types && tsc --noEmit dist/types/**/*.d.ts",
"build:types": "tsc --target ES2022 --declaration --emitDeclarationOnly --allowJs --types node --outDir ./dist/types index.js",
"check": "tsc",
"format": "biome format . --write",
"lint-biome": "biome lint .",
"lint-publint": "publint",
"lint": "npm run lint-biome && npm run lint-publint",
"test": "jest"
},
"optionalDependencies": {
"readable-stream": ">=4.0.0"
"lint": "biome lint .",
"lint:integration": "npm run build; publint",
"lint:all": "npm run tsc; npm run lint; npm run lint:integration",
"prepack": "npm run build",
"test": "jest",
"test:integration": "npm run build; for x in commonjs esm typescript; do npm --prefix integration/$x install --omit=dev && npm --prefix integration/$x test; done;",
"test:all": "npm run check; npm run test; npm run test:integration"
},
"devDependencies": {
"@biomejs/biome": "^1.4.1",
"@types/jest": "^29.5.3",
"@typescript-eslint/eslint-plugin": "^5.56.0",
"cross-fetch": "^3.1.5",
"jest": "^29.6.2",
"nock": "^14.0.0-beta.4",
"publint": "^0.2.7",
7 changes: 4 additions & 3 deletions tsconfig.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
{
"compilerOptions": {
"allowJs": true,
"esModuleInterop": true,
"noEmit": true,
"strict": true,
"allowJs": true
"strict": true
},
"exclude": ["**/node_modules", "integration"]
"types": ["node"],
"exclude": ["dist", "integration", "**/node_modules"]
}