1
+ from __future__ import annotations
2
+
1
3
import abc
2
4
from contextlib import suppress
5
+ from typing import Any , Callable
3
6
4
7
import cloudpickle
5
8
6
9
from adaptive .utils import _RequireAttrsABCMeta , load , save
7
10
8
11
9
- def uses_nth_neighbors (n : int ):
12
+ def uses_nth_neighbors (n : int ) -> Callable :
10
13
"""Decorator to specify how many neighboring intervals the loss function uses.
11
14
12
15
Wraps loss functions to indicate that they expect intervals together
@@ -82,10 +85,15 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
82
85
"""
83
86
84
87
data : dict
85
- npoints : int
86
88
pending_points : set
89
+ function : Callable
90
+
91
+ @property
92
+ @abc .abstractmethod
93
+ def npoints (self ) -> int :
94
+ """Number of learned points."""
87
95
88
- def tell (self , x , y ):
96
+ def tell (self , x : Any , y ) -> None :
89
97
"""Tell the learner about a single value.
90
98
91
99
Parameters
@@ -95,7 +103,7 @@ def tell(self, x, y):
95
103
"""
96
104
self .tell_many ([x ], [y ])
97
105
98
- def tell_many (self , xs , ys ) :
106
+ def tell_many (self , xs : Any , ys : Any ) -> None :
99
107
"""Tell the learner about some values.
100
108
101
109
Parameters
@@ -116,7 +124,7 @@ def remove_unfinished(self):
116
124
"""Remove uncomputed data from the learner."""
117
125
118
126
@abc .abstractmethod
119
- def loss (self , real = True ):
127
+ def loss (self , real : bool = True ) -> float :
120
128
"""Return the loss for the current state of the learner.
121
129
122
130
Parameters
@@ -128,7 +136,7 @@ def loss(self, real=True):
128
136
"""
129
137
130
138
@abc .abstractmethod
131
- def ask (self , n , tell_pending = True ):
139
+ def ask (self , n : int , tell_pending : bool = True ):
132
140
"""Choose the next 'n' points to evaluate.
133
141
134
142
Parameters
@@ -146,7 +154,7 @@ def _get_data(self):
146
154
pass
147
155
148
156
@abc .abstractmethod
149
- def _set_data (self ):
157
+ def _set_data (self , data : Any ):
150
158
pass
151
159
152
160
@abc .abstractmethod
@@ -164,7 +172,7 @@ def copy_from(self, other):
164
172
"""
165
173
self ._set_data (other ._get_data ())
166
174
167
- def save (self , fname , compress = True ):
175
+ def save (self , fname : str , compress : bool = True ) -> None :
168
176
"""Save the data of the learner into a pickle file.
169
177
170
178
Parameters
@@ -178,7 +186,7 @@ def save(self, fname, compress=True):
178
186
data = self ._get_data ()
179
187
save (fname , data , compress )
180
188
181
- def load (self , fname , compress = True ):
189
+ def load (self , fname : str , compress : bool = True ) -> None :
182
190
"""Load the data of a learner from a pickle file.
183
191
184
192
Parameters
@@ -193,8 +201,8 @@ def load(self, fname, compress=True):
193
201
data = load (fname , compress )
194
202
self ._set_data (data )
195
203
196
- def __getstate__ (self ):
204
+ def __getstate__ (self ) -> dict [ str , Any ] :
197
205
return cloudpickle .dumps (self .__dict__ )
198
206
199
- def __setstate__ (self , state ) :
207
+ def __setstate__ (self , state : dict [ str , Any ]) -> None :
200
208
self .__dict__ = cloudpickle .loads (state )
0 commit comments