@@ -48,10 +48,16 @@ export default class Model extends ReplicateObject {
48
48
49
49
async predict (
50
50
input ,
51
- { onUpdate = noop , onTemporaryError = noop } = { } ,
51
+ {
52
+ onUpdate = noop ,
53
+ onTemporaryError = noop ,
54
+ onCancel = noop ,
55
+ onCancelError = noop ,
56
+ } = { } ,
52
57
{
53
58
defaultPollingInterval = 500 ,
54
59
backoffFn = ( errorCount ) => Math . pow ( 2 , errorCount ) * 100 ,
60
+ cancelOnFatalError = false ,
55
61
} = { }
56
62
) {
57
63
if ( ! input ) {
@@ -60,39 +66,57 @@ export default class Model extends ReplicateObject {
60
66
61
67
let prediction = await this . createPrediction ( input ) ;
62
68
63
- onUpdate ( prediction ) ;
69
+ try {
70
+ onUpdate ( prediction ) ;
64
71
65
- let pollingInterval = defaultPollingInterval ;
66
- let errorCount = 0 ;
72
+ let pollingInterval = defaultPollingInterval ;
73
+ let errorCount = 0 ;
67
74
68
- while ( ! prediction . hasTerminalStatus ( ) ) {
69
- await sleep ( pollingInterval ) ;
70
- pollingInterval = defaultPollingInterval ; // Reset to default each time.
75
+ while ( ! prediction . hasTerminalStatus ( ) ) {
76
+ await sleep ( pollingInterval ) ;
77
+ pollingInterval = defaultPollingInterval ; // Reset to default each time.
71
78
72
- try {
73
- prediction = await this . client . prediction ( prediction . id ) . load ( ) ;
79
+ try {
80
+ prediction = await this . client . prediction ( prediction . id ) . load ( ) ;
74
81
75
- onUpdate ( prediction ) ;
82
+ onUpdate ( prediction ) ;
76
83
77
- errorCount = 0 ; // Reset because we've had a non-error response.
78
- } catch ( err ) {
79
- if ( ! err instanceof ReplicateResponseError ) {
80
- throw err ;
81
- }
84
+ errorCount = 0 ; // Reset because we've had a non-error response.
85
+ } catch ( err ) {
86
+ if ( ! err instanceof ReplicateResponseError ) {
87
+ throw err ;
88
+ }
82
89
83
- if (
84
- ! err . status ||
85
- ( Math . floor ( err . status / 100 ) !== 5 && err . status !== 429 )
86
- ) {
87
- throw err ;
88
- }
90
+ if (
91
+ ! err . status ||
92
+ ( Math . floor ( err . status / 100 ) !== 5 && err . status !== 429 )
93
+ ) {
94
+ throw err ;
95
+ }
89
96
90
- errorCount += 1 ;
97
+ errorCount += 1 ;
91
98
92
- onTemporaryError ( err ) ;
99
+ onTemporaryError ( err ) ;
93
100
94
- pollingInterval = backoffFn ( errorCount ) ;
101
+ pollingInterval = backoffFn ( errorCount ) ;
102
+ }
103
+ }
104
+ } catch ( err ) {
105
+ if ( cancelOnFatalError ) {
106
+ // We intentionally don't await this, so we don't block.
107
+ prediction
108
+ . cancel ( )
109
+ . catch ( ( e ) => {
110
+ onCancelError ( e ) ;
111
+
112
+ throw e ;
113
+ } )
114
+ . then ( ( ) => {
115
+ onCancel ( ) ;
116
+ } ) ;
95
117
}
118
+
119
+ throw err ;
96
120
}
97
121
98
122
return prediction ;
0 commit comments