Skip to content

Commit 1ace06f

Browse files
committed
Add type-hints to adaptive/learner/base_learner.py
1 parent 1b7e84d commit 1ace06f

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

adaptive/learner/base_learner.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
from __future__ import annotations
2+
13
import abc
24
from contextlib import suppress
5+
from typing import Any, Callable
36

47
import cloudpickle
58

69
from adaptive.utils import _RequireAttrsABCMeta, load, save
710

811

9-
def uses_nth_neighbors(n: int):
12+
def uses_nth_neighbors(n: int) -> Callable:
1013
"""Decorator to specify how many neighboring intervals the loss function uses.
1114
1215
Wraps loss functions to indicate that they expect intervals together
@@ -82,10 +85,15 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
8285
"""
8386

8487
data: dict
85-
npoints: int
8688
pending_points: set
89+
function: Callable
90+
91+
@property
92+
@abc.abstractmethod
93+
def npoints(self) -> int:
94+
"""Number of learned points."""
8795

88-
def tell(self, x, y):
96+
def tell(self, x: Any, y) -> None:
8997
"""Tell the learner about a single value.
9098
9199
Parameters
@@ -95,7 +103,7 @@ def tell(self, x, y):
95103
"""
96104
self.tell_many([x], [y])
97105

98-
def tell_many(self, xs, ys):
106+
def tell_many(self, xs: Any, ys: Any) -> None:
99107
"""Tell the learner about some values.
100108
101109
Parameters
@@ -116,7 +124,7 @@ def remove_unfinished(self):
116124
"""Remove uncomputed data from the learner."""
117125

118126
@abc.abstractmethod
119-
def loss(self, real=True):
127+
def loss(self, real: bool = True) -> float:
120128
"""Return the loss for the current state of the learner.
121129
122130
Parameters
@@ -128,7 +136,7 @@ def loss(self, real=True):
128136
"""
129137

130138
@abc.abstractmethod
131-
def ask(self, n, tell_pending=True):
139+
def ask(self, n: int, tell_pending: bool = True):
132140
"""Choose the next 'n' points to evaluate.
133141
134142
Parameters
@@ -146,7 +154,7 @@ def _get_data(self):
146154
pass
147155

148156
@abc.abstractmethod
149-
def _set_data(self):
157+
def _set_data(self, data: Any):
150158
pass
151159

152160
@abc.abstractmethod
@@ -164,7 +172,7 @@ def copy_from(self, other):
164172
"""
165173
self._set_data(other._get_data())
166174

167-
def save(self, fname, compress=True):
175+
def save(self, fname: str, compress: bool = True) -> None:
168176
"""Save the data of the learner into a pickle file.
169177
170178
Parameters
@@ -178,7 +186,7 @@ def save(self, fname, compress=True):
178186
data = self._get_data()
179187
save(fname, data, compress)
180188

181-
def load(self, fname, compress=True):
189+
def load(self, fname: str, compress: bool = True) -> None:
182190
"""Load the data of a learner from a pickle file.
183191
184192
Parameters
@@ -193,8 +201,8 @@ def load(self, fname, compress=True):
193201
data = load(fname, compress)
194202
self._set_data(data)
195203

196-
def __getstate__(self):
204+
def __getstate__(self) -> dict[str, Any]:
197205
return cloudpickle.dumps(self.__dict__)
198206

199-
def __setstate__(self, state):
207+
def __setstate__(self, state: dict[str, Any]) -> None:
200208
self.__dict__ = cloudpickle.loads(state)

0 commit comments

Comments
 (0)