From 87b5e07499ea56e58d559608c029989a19d4051b Mon Sep 17 00:00:00 2001
From: Aron Carroll <aron@replicate.com>
Date: Thu, 2 May 2024 12:53:05 +0100
Subject: [PATCH 1/2] Extract `signal` property from the `run()` options

This prevents it being passed to the backend as part of the body. The
backend has recently started validating the body payload so this is now
resulting as an API error.

Fixes #249
---
 index.js      |  4 +---
 index.test.ts | 10 ++++++++--
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/index.js b/index.js
index ce5d502..74f667b 100644
--- a/index.js
+++ b/index.js
@@ -129,7 +129,7 @@ class Replicate {
    * @returns {Promise<object>} - Resolves with the output of running the model
    */
   async run(ref, options, progress) {
-    const { wait, ...data } = options;
+    const { wait, signal, ...data } = options;
 
     const identifier = ModelVersionIdentifier.parse(ref);
 
@@ -153,8 +153,6 @@ class Replicate {
       progress(prediction);
     }
 
-    const { signal } = options;
-
     prediction = await this.wait(
       prediction,
       wait || {},
diff --git a/index.test.ts b/index.test.ts
index d2f064c..f62c7bc 100644
--- a/index.test.ts
+++ b/index.test.ts
@@ -1233,11 +1233,14 @@ describe("Replicate client", () => {
     test("Aborts the operation when abort signal is invoked", async () => {
       const controller = new AbortController();
       const { signal } = controller;
+      let body: Record<string, unknown> | undefined;
 
       const scope = nock(BASE_URL)
-        .post("/predictions", (body) => {
+        .post("/predictions", (_body) => {
+          // Should not pass the signal object in the body.
+          body = _body;
           controller.abort();
-          return body;
+          return _body;
         })
         .reply(201, {
           id: "ufawqhfynnddngldkgtslldrkq",
@@ -1263,7 +1266,10 @@ describe("Replicate client", () => {
         }
       );
 
+      expect(body).toBeDefined();
+      expect(body?.["signal"]).toBeUndefined();
       expect(signal.aborted).toBe(true);
+      expect(output).toBeUndefined();
 
       scope.done();
     });

From 2984e54977dc95455162975c68428213a8cb1e8b Mon Sep 17 00:00:00 2001
From: Aron Carroll <aron@replicate.com>
Date: Thu, 2 May 2024 12:54:41 +0100
Subject: [PATCH 2/2] Call the `onProgress` handler with the canceled
 prediction

Previously when aborting a `run()` request we were dropping the final
canceled prediction object and calling the `onProgress` callback with a
stale "processing" object.
---
 index.js      |  8 ++++++--
 index.test.ts | 25 +++++++++++++++++++++++--
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/index.js b/index.js
index 74f667b..79b8f4d 100644
--- a/index.js
+++ b/index.js
@@ -162,8 +162,8 @@ class Replicate {
           progress(updatedPrediction);
         }
 
-        if (signal && signal.aborted) {
-          await this.predictions.cancel(updatedPrediction.id);
+        // We handle the cancel later in the function.
+        if (signal?.aborted) {
           return true; // stop polling
         }
 
@@ -171,6 +171,10 @@ class Replicate {
       }
     );
 
+    if (signal?.aborted) {
+      prediction = await this.predictions.cancel(prediction.id);
+    }
+
     // Call progress callback with the completed prediction object
     if (progress) {
       progress(prediction);
diff --git a/index.test.ts b/index.test.ts
index f62c7bc..53737e0 100644
--- a/index.test.ts
+++ b/index.test.ts
@@ -1258,12 +1258,14 @@ describe("Replicate client", () => {
           status: "canceled",
         });
 
-      await client.run(
+      const onProgress = jest.fn();
+      const output = await client.run(
         "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
         {
           input: { text: "Hello, world!" },
           signal,
-        }
+        },
+        onProgress
       );
 
       expect(body).toBeDefined();
@@ -1271,6 +1273,25 @@ describe("Replicate client", () => {
       expect(signal.aborted).toBe(true);
       expect(output).toBeUndefined();
 
+      expect(onProgress).toHaveBeenNthCalledWith(
+        1,
+        expect.objectContaining({
+          status: "processing",
+        })
+      );
+      expect(onProgress).toHaveBeenNthCalledWith(
+        2,
+        expect.objectContaining({
+          status: "processing",
+        })
+      );
+      expect(onProgress).toHaveBeenNthCalledWith(
+        3,
+        expect.objectContaining({
+          status: "canceled",
+        })
+      );
+
       scope.done();
     });
   });