diff --git a/index.js b/index.js
index ce5d502..79b8f4d 100644
--- a/index.js
+++ b/index.js
@@ -129,7 +129,7 @@ class Replicate {
    * @returns {Promise<object>} - Resolves with the output of running the model
    */
   async run(ref, options, progress) {
-    const { wait, ...data } = options;
+    const { wait, signal, ...data } = options;
 
     const identifier = ModelVersionIdentifier.parse(ref);
 
@@ -153,8 +153,6 @@ class Replicate {
       progress(prediction);
     }
 
-    const { signal } = options;
-
     prediction = await this.wait(
       prediction,
       wait || {},
@@ -164,8 +162,8 @@ class Replicate {
           progress(updatedPrediction);
         }
 
-        if (signal && signal.aborted) {
-          await this.predictions.cancel(updatedPrediction.id);
+        // We handle the cancel later in the function.
+        if (signal?.aborted) {
           return true; // stop polling
         }
 
@@ -173,6 +171,10 @@ class Replicate {
       }
     );
 
+    if (signal?.aborted) {
+      prediction = await this.predictions.cancel(prediction.id);
+    }
+
     // Call progress callback with the completed prediction object
     if (progress) {
       progress(prediction);
diff --git a/index.test.ts b/index.test.ts
index d2f064c..53737e0 100644
--- a/index.test.ts
+++ b/index.test.ts
@@ -1233,11 +1233,14 @@ describe("Replicate client", () => {
     test("Aborts the operation when abort signal is invoked", async () => {
       const controller = new AbortController();
       const { signal } = controller;
+      let body: Record<string, unknown> | undefined;
 
       const scope = nock(BASE_URL)
-        .post("/predictions", (body) => {
+        .post("/predictions", (_body) => {
+          // Should not pass the signal object in the body.
+          body = _body;
           controller.abort();
-          return body;
+          return _body;
         })
         .reply(201, {
           id: "ufawqhfynnddngldkgtslldrkq",
@@ -1255,15 +1258,39 @@ describe("Replicate client", () => {
           status: "canceled",
         });
 
-      await client.run(
+      const onProgress = jest.fn();
+      const output = await client.run(
         "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
         {
           input: { text: "Hello, world!" },
           signal,
-        }
+        },
+        onProgress
       );
 
+      expect(body).toBeDefined();
+      expect(body?.["signal"]).toBeUndefined();
       expect(signal.aborted).toBe(true);
+      expect(output).toBeUndefined();
+
+      expect(onProgress).toHaveBeenNthCalledWith(
+        1,
+        expect.objectContaining({
+          status: "processing",
+        })
+      );
+      expect(onProgress).toHaveBeenNthCalledWith(
+        2,
+        expect.objectContaining({
+          status: "processing",
+        })
+      );
+      expect(onProgress).toHaveBeenNthCalledWith(
+        3,
+        expect.objectContaining({
+          status: "canceled",
+        })
+      );
 
       scope.done();
     });