diff --git a/index.js b/index.js index ce5d502..79b8f4d 100644 --- a/index.js +++ b/index.js @@ -129,7 +129,7 @@ class Replicate { * @returns {Promise<object>} - Resolves with the output of running the model */ async run(ref, options, progress) { - const { wait, ...data } = options; + const { wait, signal, ...data } = options; const identifier = ModelVersionIdentifier.parse(ref); @@ -153,8 +153,6 @@ class Replicate { progress(prediction); } - const { signal } = options; - prediction = await this.wait( prediction, wait || {}, @@ -164,8 +162,8 @@ class Replicate { progress(updatedPrediction); } - if (signal && signal.aborted) { - await this.predictions.cancel(updatedPrediction.id); + // We handle the cancel later in the function. + if (signal?.aborted) { return true; // stop polling } @@ -173,6 +171,10 @@ class Replicate { } ); + if (signal?.aborted) { + prediction = await this.predictions.cancel(prediction.id); + } + // Call progress callback with the completed prediction object if (progress) { progress(prediction); diff --git a/index.test.ts b/index.test.ts index d2f064c..53737e0 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1233,11 +1233,14 @@ describe("Replicate client", () => { test("Aborts the operation when abort signal is invoked", async () => { const controller = new AbortController(); const { signal } = controller; + let body: Record<string, unknown> | undefined; const scope = nock(BASE_URL) - .post("/predictions", (body) => { + .post("/predictions", (_body) => { + // Should not pass the signal object in the body. + body = _body; controller.abort(); - return body; + return _body; }) .reply(201, { id: "ufawqhfynnddngldkgtslldrkq", @@ -1255,15 +1258,39 @@ describe("Replicate client", () => { status: "canceled", }); - await client.run( + const onProgress = jest.fn(); + const output = await client.run( "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { input: { text: "Hello, world!" }, signal, - } + }, + onProgress ); + expect(body).toBeDefined(); + expect(body?.["signal"]).toBeUndefined(); expect(signal.aborted).toBe(true); + expect(output).toBeUndefined(); + + expect(onProgress).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ + status: "processing", + }) + ); + expect(onProgress).toHaveBeenNthCalledWith( + 2, + expect.objectContaining({ + status: "processing", + }) + ); + expect(onProgress).toHaveBeenNthCalledWith( + 3, + expect.objectContaining({ + status: "canceled", + }) + ); scope.done(); });