@@ -133,6 +133,8 @@ def persist_to(self, commit_request):
133133class WindmillStateReader (object ):
134134 """Reader of raw state from Windmill."""
135135
136+ # The size of Windmill list request responses is capped at this size (or at
137+ # least one list element, if a single such element would exceed this size).
136138 MAX_LIST_BYTES = 8 << 20 # 8MB
137139
138140 def __init__ (self , computation_id , key , work_token , windmill ):
@@ -156,11 +158,8 @@ def fetch_value(self, state_key):
156158 request .requests .extend ([computation_request ])
157159 return self .windmill .GetData (request )
158160
159- def fetch_list (self , state_key ):
161+ def fetch_list (self , state_key , request_token = None ):
160162 """Get the list at given state tag."""
161- # TODO(ccy): refactor to support continuation tokens for paginated reading.
162- # The current implementation returns up to one page of values from the
163- # list from windmill.
164163 request = windmill_pb2 .GetDataRequest ()
165164 computation_request = windmill_pb2 .ComputationGetDataRequest (
166165 computation_id = self .computation_id )
@@ -171,6 +170,7 @@ def fetch_list(self, state_key):
171170 tag = state_key ,
172171 state_family = '' ,
173172 end_timestamp = MAX_TIMESTAMP ,
173+ request_token = request_token or '' ,
174174 fetch_max_bytes = WindmillStateReader .MAX_LIST_BYTES )
175175 computation_request .requests .extend ([keyed_request ])
176176 request .requests .extend ([computation_request ])
@@ -418,21 +418,26 @@ def _get_iter(self):
418418
def _fetch(self):
    """Fetch state from Windmill, yielding decoded values page by page.

    Repeatedly calls ``self.reader.fetch_list`` for ``self.state_key``,
    passing along the continuation token found in the previous response,
    and stops once a response carries no continuation token.

    Yields:
      The result of ``decode_value`` for each list value, in the order the
      values appear in the responses.  A value that fails to decode is
      logged and yielded as ``None``.
    """
    # TODO(ccy): the Java SDK caches the first page and at the start of each
    # page of values, fires off an asynchronous read for the next page. We
    # should do this too once we have asynchronous Windmill state reading.
    next_request_token = None
    while True:
        result = self.reader.fetch_list(
            self.state_key, request_token=next_request_token)
        # Reset before scanning the page: if the response holds no list
        # entries the token stays None and the loop ends, instead of
        # re-requesting the first page forever.
        next_request_token = None
        for wrapper in result.data:
            for datum in wrapper.data:
                for item in datum.lists:
                    next_request_token = item.continuation_token
                    for value in item.values:
                        # Only the decode call is guarded; yielding outside
                        # the try keeps exceptions thrown into the generator
                        # by the consumer from being mistaken for decode
                        # failures.
                        try:
                            decoded = decode_value(value.data)
                        except Exception:  # pylint: disable=broad-except
                            logging.error(
                                'Could not decode value: %r.', value.data)
                            decoded = None
                        yield decoded
        # Both a missing token (None) and an empty one ('') mean there are
        # no further pages.
        if not next_request_token:
            break
436441
437442 def add (self , value ):
438443 # Encode the value here to ensure further mutations of the value don't
0 commit comments