Skip to content

Commit 3523343

Browse files
committed
Implement "sequential" strategy for BalancingLearner
1 parent fefe366 commit 3523343

File tree

2 files changed

+71
-16
lines changed

2 files changed

+71
-16
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ def dispatch(child_functions: list[Callable], arg: Any) -> Any:
3636
return child_functions[index](x)
3737

3838

39-
STRATEGY_TYPE: TypeAlias = Literal["loss_improvements", "loss", "npoints", "cycle"]
39+
STRATEGY_TYPE: TypeAlias = Literal[
40+
"loss_improvements", "loss", "npoints", "cycle", "sequential"
41+
]
4042

4143
CDIMS_TYPE: TypeAlias = Union[
4244
Sequence[dict[str, Any]],
@@ -77,13 +79,21 @@ class BalancingLearner(BaseLearner):
7779
function : callable
7880
A function that calls the functions of the underlying learners.
7981
Its signature is ``function(learner_index, point)``.
80-
strategy : 'loss_improvements' (default), 'loss', 'npoints', or 'cycle'.
82+
strategy : 'loss_improvements' (default), 'loss', 'npoints', 'cycle', or 'sequential'
8183
The points that the `BalancingLearner` choses can be either based on:
82-
the best 'loss_improvements', the smallest total 'loss' of the
83-
child learners, the number of points per learner, using 'npoints',
84-
or by cycling through the learners one by one using 'cycle'.
85-
One can dynamically change the strategy while the simulation is
86-
running by changing the ``learner.strategy`` attribute.
84+
85+
- 'loss_improvements': This strategy selects the points with the best
86+
improvement in loss.
87+
- 'loss': This strategy selects the points with the smallest total loss
88+
from the child learners.
89+
- 'npoints': This strategy selects points based on the number of points
90+
per learner.
91+
- 'cycle': This strategy cycles through all learners one by one.
92+
- 'sequential': This strategy goes through learners in a sequential
93+
order. Only works with learners that have a `done` method.
94+
95+
You can change the strategy dynamically while the simulation is
96+
running by modifying the `learner.strategy` attribute.
8797
8898
Notes
8999
-----
@@ -159,13 +169,19 @@ def nsamples(self):
159169

160170
@property
161171
def strategy(self) -> STRATEGY_TYPE:
162-
"""Can be either 'loss_improvements' (default), 'loss', 'npoints', or
163-
'cycle'. The points that the `BalancingLearner` choses can be either
164-
based on: the best 'loss_improvements', the smallest total 'loss' of
165-
the child learners, the number of points per learner, using 'npoints',
166-
or by going through all learners one by one using 'cycle'.
167-
One can dynamically change the strategy while the simulation is
168-
running by changing the ``learner.strategy`` attribute."""
172+
"""The `BalancingLearner` can choose points based on different strategies.
173+
174+
The strategies are:
175+
176+
- 'loss_improvements': This strategy selects the points with the best improvement in loss.
177+
- 'loss': This strategy selects the points with the smallest total loss from the child learners.
178+
- 'npoints': This strategy selects points based on the number of points per learner.
179+
- 'cycle': This strategy cycles through all learners one by one.
180+
- 'sequential': This strategy goes through learners in a sequential order.
181+
182+
You can change the strategy dynamically while the simulation is
183+
running by modifying the `learner.strategy` attribute.
184+
"""
169185
return self._strategy
170186

171187
@strategy.setter
@@ -180,6 +196,9 @@ def strategy(self, strategy: STRATEGY_TYPE) -> None:
180196
elif strategy == "cycle":
181197
self._ask_and_tell = self._ask_and_tell_based_on_cycle
182198
self._cycle = itertools.cycle(range(len(self.learners)))
199+
elif strategy == "sequential":
200+
self._ask_and_tell = self._ask_and_tell_based_on_sequential
201+
...
183202
else:
184203
raise ValueError(
185204
'Only strategy="loss_improvements", strategy="loss",'
@@ -255,7 +274,8 @@ def _ask_and_tell_based_on_npoints(
255274
def _ask_and_tell_based_on_cycle(
256275
self, n: int
257276
) -> tuple[list[tuple[Int, Any]], list[float]]:
258-
points, loss_improvements = [], []
277+
points: list[tuple[Int, Any]] = []
278+
loss_improvements: list[float] = []
259279
for _ in range(n):
260280
index = next(self._cycle)
261281
point, loss_improvement = self.learners[index].ask(n=1)
@@ -265,6 +285,33 @@ def _ask_and_tell_based_on_cycle(
265285

266286
return points, loss_improvements
267287

288+
def _ask_and_tell_based_on_sequential(
289+
self, n: int
290+
) -> tuple[list[tuple[Int, Any]], list[float]]:
291+
points: list[tuple[Int, Any]] = []
292+
loss_improvements: list[float] = []
293+
learner_index = 0
294+
295+
while len(points) < n:
296+
learner = self.learners[learner_index]
297+
if learner.done(): # type: ignore[attr-defined]
298+
if learner_index == len(self.learners) - 1:
299+
break
300+
learner_index += 1
301+
continue
302+
303+
point, loss_improvement = learner.ask(n=1)
304+
if not point: # if learner is exhausted, we don't get points
305+
if learner_index == len(self.learners) - 1:
306+
break
307+
learner_index += 1
308+
continue
309+
points.append((learner_index, point[0]))
310+
loss_improvements.append(loss_improvement[0])
311+
self.tell_pending((learner_index, point[0]))
312+
313+
return points, loss_improvements
314+
268315
def ask(
269316
self, n: int, tell_pending: bool = True
270317
) -> tuple[list[tuple[Int, Any]], list[float]]:

adaptive/tests/test_balancing_learner.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from adaptive.learner import BalancingLearner, Learner1D
5+
from adaptive.learner import BalancingLearner, Learner1D, SequenceLearner
66
from adaptive.runner import simple
77

88
strategies = ["loss", "loss_improvements", "npoints", "cycle"]
@@ -64,3 +64,11 @@ def test_strategies(strategy, goal_type, goal):
6464
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
6565
learner = BalancingLearner(learners, strategy=strategy)
6666
simple(learner, **{goal_type: goal})
67+
68+
69+
def test_sequential_strategy() -> None:
70+
learners = [SequenceLearner(lambda x: x, sequence=[0, 1, 2, 3]) for i in range(10)]
71+
learner = BalancingLearner(learners, strategy="sequential") # type: ignore[arg-type]
72+
simple(learner, goal=lambda lrn: sum(x.npoints for x in lrn.learners) >= 4 * 5)
73+
assert all(lrn.done() for lrn in learners[:5])
74+
assert all(not lrn.done() for lrn in learners[5:])

0 commit comments

Comments
 (0)