@@ -36,7 +36,9 @@ def dispatch(child_functions: list[Callable], arg: Any) -> Any:
36
36
return child_functions [index ](x )
37
37
38
38
39
- STRATEGY_TYPE : TypeAlias = Literal ["loss_improvements" , "loss" , "npoints" , "cycle" ]
39
+ STRATEGY_TYPE : TypeAlias = Literal [
40
+ "loss_improvements" , "loss" , "npoints" , "cycle" , "sequential"
41
+ ]
40
42
41
43
CDIMS_TYPE : TypeAlias = Union [
42
44
Sequence [dict [str , Any ]],
@@ -77,13 +79,21 @@ class BalancingLearner(BaseLearner):
77
79
function : callable
78
80
A function that calls the functions of the underlying learners.
79
81
Its signature is ``function(learner_index, point)``.
80
- strategy : 'loss_improvements' (default), 'loss', 'npoints', or 'cycle'.
82
+ strategy : 'loss_improvements' (default), 'loss', 'npoints', 'cycle', or 'sequential'
81
83
The points that the `BalancingLearner` chooses can be either based on:
82
- the best 'loss_improvements', the smallest total 'loss' of the
83
- child learners, the number of points per learner, using 'npoints',
84
- or by cycling through the learners one by one using 'cycle'.
85
- One can dynamically change the strategy while the simulation is
86
- running by changing the ``learner.strategy`` attribute.
84
+
85
+ - 'loss_improvements': This strategy selects the points with the best
86
+ improvement in loss.
87
+ - 'loss': This strategy selects the points with the smallest total loss
88
+ from the child learners.
89
+ - 'npoints': This strategy selects points based on the number of points
90
+ per learner.
91
+ - 'cycle': This strategy cycles through all learners one by one.
92
+ - 'sequential': This strategy goes through learners in a sequential
93
+ order. Only works with learners that have a `done` method.
94
+
95
+ You can change the strategy dynamically while the simulation is
96
+ running by modifying the `learner.strategy` attribute.
87
97
88
98
Notes
89
99
-----
@@ -159,13 +169,19 @@ def nsamples(self):
159
169
160
170
@property
161
171
def strategy (self ) -> STRATEGY_TYPE :
162
- """Can be either 'loss_improvements' (default), 'loss', 'npoints', or
163
- 'cycle'. The points that the `BalancingLearner` choses can be either
164
- based on: the best 'loss_improvements', the smallest total 'loss' of
165
- the child learners, the number of points per learner, using 'npoints',
166
- or by going through all learners one by one using 'cycle'.
167
- One can dynamically change the strategy while the simulation is
168
- running by changing the ``learner.strategy`` attribute."""
172
+ """The `BalancingLearner` can choose points based on different strategies.
173
+
174
+ The strategies are:
175
+
176
+ - 'loss_improvements': This strategy selects the points with the best improvement in loss.
177
+ - 'loss': This strategy selects the points with the smallest total loss from the child learners.
178
+ - 'npoints': This strategy selects points based on the number of points per learner.
179
+ - 'cycle': This strategy cycles through all learners one by one.
180
+ - 'sequential': This strategy goes through learners in a sequential order. Only works with learners that have a `done` method.
181
+
182
+ You can change the strategy dynamically while the simulation is
183
+ running by modifying the `learner.strategy` attribute.
184
+ """
169
185
return self ._strategy
170
186
171
187
@strategy .setter
@@ -180,6 +196,9 @@ def strategy(self, strategy: STRATEGY_TYPE) -> None:
180
196
elif strategy == "cycle" :
181
197
self ._ask_and_tell = self ._ask_and_tell_based_on_cycle
182
198
self ._cycle = itertools .cycle (range (len (self .learners )))
199
+ elif strategy == "sequential" :
200
+ self ._ask_and_tell = self ._ask_and_tell_based_on_sequential
201
+ ...
183
202
else :
184
203
raise ValueError (
185
204
'Only strategy="loss_improvements", strategy="loss",'
@@ -255,7 +274,8 @@ def _ask_and_tell_based_on_npoints(
255
274
def _ask_and_tell_based_on_cycle (
256
275
self , n : int
257
276
) -> tuple [list [tuple [Int , Any ]], list [float ]]:
258
- points , loss_improvements = [], []
277
+ points : list [tuple [Int , Any ]] = []
278
+ loss_improvements : list [float ] = []
259
279
for _ in range (n ):
260
280
index = next (self ._cycle )
261
281
point , loss_improvement = self .learners [index ].ask (n = 1 )
@@ -265,6 +285,33 @@ def _ask_and_tell_based_on_cycle(
265
285
266
286
return points , loss_improvements
267
287
288
+ def _ask_and_tell_based_on_sequential (
289
+ self , n : int
290
+ ) -> tuple [list [tuple [Int , Any ]], list [float ]]:
291
+ points : list [tuple [Int , Any ]] = []
292
+ loss_improvements : list [float ] = []
293
+ learner_index = 0
294
+
295
+ while len (points ) < n :
296
+ learner = self .learners [learner_index ]
297
+ if learner .done (): # type: ignore[attr-defined]
298
+ if learner_index == len (self .learners ) - 1 :
299
+ break
300
+ learner_index += 1
301
+ continue
302
+
303
+ point , loss_improvement = learner .ask (n = 1 )
304
+ if not point : # if learner is exhausted, we don't get points
305
+ if learner_index == len (self .learners ) - 1 :
306
+ break
307
+ learner_index += 1
308
+ continue
309
+ points .append ((learner_index , point [0 ]))
310
+ loss_improvements .append (loss_improvement [0 ])
311
+ self .tell_pending ((learner_index , point [0 ]))
312
+
313
+ return points , loss_improvements
314
+
268
315
def ask (
269
316
self , n : int , tell_pending : bool = True
270
317
) -> tuple [list [tuple [Int , Any ]], list [float ]]:
0 commit comments