1+ from  __future__ import  annotations 
2+ 
13import  abc 
24from  contextlib  import  suppress 
5+ from  typing  import  Any , Callable 
36
47import  cloudpickle 
58
69from  adaptive .utils  import  _RequireAttrsABCMeta , load , save 
710
811
9- def  uses_nth_neighbors (n : int ):
12+ def  uses_nth_neighbors (n : int )  ->   Callable [[ int ],  Callable [[ BaseLearner ],  float ]] :
1013    """Decorator to specify how many neighboring intervals the loss function uses. 
1114
1215    Wraps loss functions to indicate that they expect intervals together 
@@ -53,7 +56,9 @@ def uses_nth_neighbors(n: int):
5356    ...     return loss 
5457    """ 
5558
56-     def  _wrapped (loss_per_interval ):
59+     def  _wrapped (
60+         loss_per_interval : Callable [[BaseLearner ], float ]
61+     ) ->  Callable [[BaseLearner ], float ]:
5762        loss_per_interval .nth_neighbors  =  n 
5863        return  loss_per_interval 
5964
@@ -82,10 +87,15 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
8287    """ 
8388
8489    data : dict 
85-     npoints : int 
8690    pending_points : set 
91+     function : Callable 
92+ 
93+     @property  
94+     @abc .abstractmethod  
95+     def  npoints (self ) ->  int :
96+         """Number of learned points.""" 
8797
88-     def  tell (self , x , y ) :
98+     def  tell (self , x :  Any , y :  Any )  ->   None :
8999        """Tell the learner about a single value. 
90100
91101        Parameters 
@@ -95,7 +105,7 @@ def tell(self, x, y):
95105        """ 
96106        self .tell_many ([x ], [y ])
97107
98-     def  tell_many (self , xs , ys ) :
108+     def  tell_many (self , xs :  Any , ys :  Any )  ->   None :
99109        """Tell the learner about some values. 
100110
101111        Parameters 
@@ -107,16 +117,16 @@ def tell_many(self, xs, ys):
107117            self .tell (x , y )
108118
109119    @abc .abstractmethod  
110-     def  tell_pending (self , x ) :
120+     def  tell_pending (self , x :  Any )  ->   None :
111121        """Tell the learner that 'x' has been requested such 
112122        that it's not suggested again.""" 
113123
114124    @abc .abstractmethod  
115-     def  remove_unfinished (self ):
125+     def  remove_unfinished (self )  ->   None :
116126        """Remove uncomputed data from the learner.""" 
117127
118128    @abc .abstractmethod  
119-     def  loss (self , real = True ):
129+     def  loss (self , real :  bool   =   True )  ->   float :
120130        """Return the loss for the current state of the learner. 
121131
122132        Parameters 
@@ -128,7 +138,7 @@ def loss(self, real=True):
128138        """ 
129139
130140    @abc .abstractmethod  
131-     def  ask (self , n , tell_pending = True ):
141+     def  ask (self , n :  int , tell_pending :  bool   =   True ):
132142        """Choose the next 'n' points to evaluate. 
133143
134144        Parameters 
@@ -142,11 +152,11 @@ def ask(self, n, tell_pending=True):
142152        """ 
143153
144154    @abc .abstractmethod  
145-     def  _get_data (self ):
155+     def  _get_data (self )  ->   Any :
146156        pass 
147157
148158    @abc .abstractmethod  
149-     def  _set_data (self ):
159+     def  _set_data (self ,  data :  Any ):
150160        pass 
151161
152162    @abc .abstractmethod  
@@ -164,7 +174,7 @@ def copy_from(self, other):
164174        """ 
165175        self ._set_data (other ._get_data ())
166176
167-     def  save (self , fname , compress = True ):
177+     def  save (self , fname :  str , compress :  bool   =   True )  ->   None :
168178        """Save the data of the learner into a pickle file. 
169179
170180        Parameters 
@@ -178,7 +188,7 @@ def save(self, fname, compress=True):
178188        data  =  self ._get_data ()
179189        save (fname , data , compress )
180190
181-     def  load (self , fname , compress = True ):
191+     def  load (self , fname :  str , compress :  bool   =   True )  ->   None :
182192        """Load the data of a learner from a pickle file. 
183193
184194        Parameters 
@@ -193,8 +203,8 @@ def load(self, fname, compress=True):
193203            data  =  load (fname , compress )
194204            self ._set_data (data )
195205
196-     def  __getstate__ (self ):
206+     def  __getstate__ (self )  ->   bytes :
197207        return  cloudpickle .dumps (self .__dict__ )
198208
199-     def  __setstate__ (self , state ) :
209+     def  __setstate__ (self , state :  bytes )  ->   None :
200210        self .__dict__  =  cloudpickle .loads (state )
0 commit comments