Skip to content

Commit 709e73e

Browse files
committed
Optionally automatically cancel predictions with fatal errors
If our application relies on the all-in-one predict method, there's a good chance we don't want to be stuck with a running prediction we're paying for but won't ever use.
1 parent c56c35b commit 709e73e

File tree

1 file changed

+48
-24
lines changed

1 file changed

+48
-24
lines changed

lib/Model.js

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,16 @@ export default class Model extends ReplicateObject {
4848

4949
async predict(
5050
input,
51-
{ onUpdate = noop, onTemporaryError = noop } = {},
51+
{
52+
onUpdate = noop,
53+
onTemporaryError = noop,
54+
onCancel = noop,
55+
onCancelError = noop,
56+
} = {},
5257
{
5358
defaultPollingInterval = 500,
5459
backoffFn = (errorCount) => Math.pow(2, errorCount) * 100,
60+
cancelOnFatalError = false,
5561
} = {}
5662
) {
5763
if (!input) {
@@ -60,39 +66,57 @@ export default class Model extends ReplicateObject {
6066

6167
let prediction = await this.createPrediction(input);
6268

63-
onUpdate(prediction);
69+
try {
70+
onUpdate(prediction);
6471

65-
let pollingInterval = defaultPollingInterval;
66-
let errorCount = 0;
72+
let pollingInterval = defaultPollingInterval;
73+
let errorCount = 0;
6774

68-
while (!prediction.hasTerminalStatus()) {
69-
await sleep(pollingInterval);
70-
pollingInterval = defaultPollingInterval; // Reset to default each time.
75+
while (!prediction.hasTerminalStatus()) {
76+
await sleep(pollingInterval);
77+
pollingInterval = defaultPollingInterval; // Reset to default each time.
7178

72-
try {
73-
prediction = await this.client.prediction(prediction.id).load();
79+
try {
80+
prediction = await this.client.prediction(prediction.id).load();
7481

75-
onUpdate(prediction);
82+
onUpdate(prediction);
7683

77-
errorCount = 0; // Reset because we've had a non-error response.
78-
} catch (err) {
79-
if (!err instanceof ReplicateResponseError) {
80-
throw err;
81-
}
84+
errorCount = 0; // Reset because we've had a non-error response.
85+
} catch (err) {
86+
if (!err instanceof ReplicateResponseError) {
87+
throw err;
88+
}
8289

83-
if (
84-
!err.status ||
85-
(Math.floor(err.status / 100) !== 5 && err.status !== 429)
86-
) {
87-
throw err;
88-
}
90+
if (
91+
!err.status ||
92+
(Math.floor(err.status / 100) !== 5 && err.status !== 429)
93+
) {
94+
throw err;
95+
}
8996

90-
errorCount += 1;
97+
errorCount += 1;
9198

92-
onTemporaryError(err);
99+
onTemporaryError(err);
93100

94-
pollingInterval = backoffFn(errorCount);
101+
pollingInterval = backoffFn(errorCount);
102+
}
103+
}
104+
} catch (err) {
105+
if (cancelOnFatalError) {
106+
// We intentionally don't await this, so we don't block.
107+
prediction
108+
.cancel()
109+
.catch((e) => {
110+
onCancelError(e);
111+
112+
throw e;
113+
})
114+
.then(() => {
115+
onCancel();
116+
});
95117
}
118+
119+
throw err;
96120
}
97121

98122
return prediction;

0 commit comments

Comments
 (0)