@@ -96,9 +96,10 @@ func distributedFetch[V, T any](c *Client[T], key string, fetchFn FetchFn[V]) FetchFn[V] {
 
 	return func(ctx context.Context) (V, error) {
 		stale, hasStale := *new(V), false
-		bytes, ok := c.distributedStorage.Get(ctx, key)
-		if ok {
-			c.reportDistributedCacheHit(true)
+		bytes, existsInDistributedStorage := c.distributedStorage.Get(ctx, key)
+		c.reportDistributedCacheHit(existsInDistributedStorage)
+
+		if existsInDistributedStorage {
 			record, unmarshalErr := unmarshalRecord[V](bytes, key, c.log)
 			if unmarshalErr != nil {
 				return record.Value, unmarshalErr
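The refactor above works because the storage lookup reports existence alongside the payload, so a single reportDistributedCacheHit call covers both the hit and the miss path. A minimal sketch of the dependency this appears to assume (the interface below is a hypothetical reduction, not the library's full API):

    package cache

    import "context"

    // Hypothetical sketch: Get returns the raw bytes plus an existence flag,
    // which is what lets the caller fold hit/miss reporting into a single
    // reportDistributedCacheHit(exists) call above the branch.
    type DistributedStorage interface {
    	Get(ctx context.Context, key string) ([]byte, bool)
    }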
@@ -116,8 +117,16 @@ func distributedFetch[V, T any](c *Client[T], key string, fetchFn FetchFn[V]) FetchFn[V] {
 			stale, hasStale = record.Value, true
 		}
 
-		if !ok {
-			c.reportDistributedCacheHit(false)
+		// Before we call the fetchFn, we'll do a non-blocking read to see if the
+		// context has been cancelled. If it has, we'll return a stale value if we
+		// have one available.
+		select {
+		case <-ctx.Done():
+			if hasStale {
+				return stale, errors.Join(errOnlyDistributedRecords, ctx.Err())
+			}
+			return *(new(V)), ctx.Err()
+		default:
 		}
 
 		// If it's not fresh enough, we'll retrieve it from the source.
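A select with an empty default branch is Go's idiom for a non-blocking poll: it takes the case only if the context is already cancelled and otherwise falls straight through to the fetch. A self-contained illustration:

    package main

    import (
    	"context"
    	"fmt"
    )

    func main() {
    	ctx, cancel := context.WithCancel(context.Background())
    	cancel() // simulate a caller that has already given up

    	// The default branch means this select never blocks; it only
    	// observes a cancellation that has already happened.
    	select {
    	case <-ctx.Done():
    		fmt.Println("cancelled before the fetch:", ctx.Err())
    	default:
    		fmt.Println("context still live, proceeding to fetch")
    	}
    }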
@@ -146,7 +155,7 @@ func distributedFetch[V, T any](c *Client[T], key string, fetchFn FetchFn[V]) FetchFn[V] {
 
 		if hasStale {
 			c.reportDistributedStaleFallback()
-			return stale, nil
+			return stale, errors.Join(errOnlyDistributedRecords, fetchErr)
 		}
 		return response, fetchErr
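Returning errors.Join(errOnlyDistributedRecords, fetchErr) instead of nil tells the caller both that the value is a stale fallback and why the refresh failed, and errors.Is still matches each half of the joined error. A sketch, assuming the sentinel is an ordinary errors.New value (the message below is invented):

    package main

    import (
    	"errors"
    	"fmt"
    )

    // Assumption: the sentinel is a plain errors.New value.
    var errOnlyDistributedRecords = errors.New("only distributed records")

    func main() {
    	fetchErr := errors.New("upstream timeout")
    	err := errors.Join(errOnlyDistributedRecords, fetchErr)

    	// errors.Is traverses joined errors, so both remain detectable.
    	fmt.Println(errors.Is(err, errOnlyDistributedRecords)) // true
    	fmt.Println(errors.Is(err, fetchErr))                  // true
    }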
@@ -177,14 +186,14 @@ func distributedBatchFetch[V, T any](c *Client[T], keyFn KeyFn, fetchFn BatchFetchFn[V]) BatchFetchFn[V] {
 		idsToRefresh := make([]string, 0, len(ids))
 		for _, id := range ids {
 			key := keyFn(id)
-			bytes, ok := distributedRecords[key]
-			if !ok {
-				c.reportDistributedCacheHit(false)
+			bytes, existsInDistributedStorage := distributedRecords[key]
+			c.reportDistributedCacheHit(existsInDistributedStorage)
+
+			if !existsInDistributedStorage {
 				idsToRefresh = append(idsToRefresh, id)
 				continue
 			}
 
-			c.reportDistributedCacheHit(true)
 			record, unmarshalErr := unmarshalRecord[V](bytes, key, c.log)
 			if unmarshalErr != nil {
 				idsToRefresh = append(idsToRefresh, id)
@@ -194,29 +203,46 @@ func distributedBatchFetch[V, T any](c *Client[T], keyFn KeyFn, fetchFn BatchFetchFn[V]) BatchFetchFn[V] {
 			// If early refreshes aren't enabled, it means all records are fresh; otherwise we'll check the CreatedAt time.
 			if !c.distributedEarlyRefreshes || c.clock.Since(record.CreatedAt) < c.distributedRefreshAfterDuration {
 				// We never want to return missing records.
-				if !record.IsMissingRecord {
-					fresh[id] = record.Value
-				} else {
+				if record.IsMissingRecord {
 					c.reportDistributedMissingRecord()
+					continue
 				}
+
+				fresh[id] = record.Value
 				continue
 			}
 
 			idsToRefresh = append(idsToRefresh, id)
 			c.reportDistributedRefresh()
 
 			// We never want to return missing records.
-			if !record.IsMissingRecord {
-				stale[id] = record.Value
-			} else {
+			if record.IsMissingRecord {
 				c.reportDistributedMissingRecord()
+				continue
 			}
+			stale[id] = record.Value
 		}
 
 		if len(idsToRefresh) == 0 {
 			return fresh, nil
 		}
 
+		// Before we call the fetchFn, we'll do a non-blocking read to see if the
+		// context has been cancelled. If it has, we'll return any potential
+		// records we got from the distributed storage.
+		select {
+		case <-ctx.Done():
+			maps.Copy(stale, fresh)
+
+			// If we didn't get any records from the distributed storage,
+			// we'll return the context error as-is.
+			if len(stale) < 1 {
+				return stale, ctx.Err()
+			}
+			return stale, errors.Join(errOnlyDistributedRecords, ctx.Err())
+		default:
+		}
+
 		dataSourceResponses, err := fetchFn(ctx, idsToRefresh)
 		// In case of an error, we'll proceed with the ones we got from the distributed storage.
 		// NOTE: It's important that we return a specific error here, otherwise we'll potentially
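In the cancellation branch above, maps.Copy(stale, fresh) merges the fresh records into the stale map; on overlapping keys the source argument wins, so fresh values take precedence in the merged result. A quick standalone demonstration:

    package main

    import (
    	"fmt"
    	"maps"
    )

    func main() {
    	stale := map[string]int{"a": 1, "b": 2}
    	fresh := map[string]int{"b": 20, "c": 30}

    	// maps.Copy(dst, src): entries from src overwrite dst on key
    	// collisions, so the fresher values win.
    	maps.Copy(stale, fresh)
    	fmt.Println(stale) // map[a:1 b:20 c:30]
    }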
@@ -227,17 +253,22 @@ func distributedBatchFetch[V, T any](c *Client[T], keyFn KeyFn, fetchFn BatchFetchFn[V]) BatchFetchFn[V] {
 				c.reportDistributedStaleFallback()
 			}
 			maps.Copy(stale, fresh)
-			return stale, errOnlyDistributedRecords
+
+			// If we didn't get any records from the distributed storage,
+			// we'll return the error from the fetch function as-is.
+			if len(stale) < 1 {
+				return dataSourceResponses, err
+			}
+
+			return stale, errors.Join(errOnlyDistributedRecords, err)
 		}
 
 		// Next, we'll want to check if we should change any of the records to be missing or perform deletions.
 		recordsToWrite := make(map[string][]byte, len(dataSourceResponses))
 		keysToDelete := make([]string, 0, max(len(idsToRefresh)-len(dataSourceResponses), 0))
 		for _, id := range idsToRefresh {
 			key := keyFn(id)
-			response, ok := dataSourceResponses[id]
-
-			if ok {
+			if response, ok := dataSourceResponses[id]; ok {
 				if recordBytes, marshalErr := marshalRecord[V](response, c); marshalErr == nil {
 					recordsToWrite[key] = recordBytes
 				}
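This hunk also folds the map lookup into the if statement, which is idiomatic Go: response and ok stay scoped to the branch rather than leaking into the rest of the loop body. A tiny illustration of the idiom:

    package main

    import "fmt"

    func main() {
    	responses := map[string]string{"1": "hello"}

    	// The if-with-initializer form keeps response and ok scoped to this
    	// branch, so they can't shadow or collide with names later in the
    	// enclosing function.
    	if response, ok := responses["1"]; ok {
    		fmt.Println("got:", response)
    	}
    }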