 from smac.tae import StatusType
 
-from autoPyTorch.automl_common.common.utils.backend import Backend
-from autoPyTorch.constants import (
-    CLASSIFICATION_TASKS,
-    MULTICLASSMULTIOUTPUT,
+from autoPyTorch.datasets.resampling_strategy import (
+    CrossValTypes,
+    NoResamplingStrategyTypes,
+    check_resampling_strategy
 )
-from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
 from autoPyTorch.evaluation.abstract_evaluator import (
     AbstractEvaluator,
     EvaluationResults,
 from autoPyTorch.evaluation.abstract_evaluator import EvaluatorParams, FixedPipelineParams
 from autoPyTorch.utils.common import dict_repr, subsampler
 
-__all__ = ['TrainEvaluator', 'eval_train_function']
+__all__ = ['Evaluator', 'eval_fn']
+
 
 class _CrossValidationResultsManager:
     def __init__(self, num_folds: int):
@@ -83,15 +83,13 @@ def get_result_dict(self) -> Dict[str, Any]:
 )
 
 
-class TrainEvaluator(AbstractEvaluator):
+class Evaluator(AbstractEvaluator):
     """
     This class builds a pipeline using the provided configuration.
     A pipeline implementing the provided configuration is fitted
     using the datamanager object retrieved from disc, via the backend.
     After the pipeline is fitted, it is saved to disc and the performance estimate
-    is communicated to the main process via a Queue. It is only compatible
-    with `CrossValTypes`, `HoldoutValTypes`, i.e, when the training data
-    is split and the validation set is used for SMBO optimisation.
+    is communicated to the main process via a Queue.
 
     Args:
         queue (Queue):
@@ -101,54 +99,27 @@ class TrainEvaluator(AbstractEvaluator):
             Fixed parameters for a pipeline
         evaluator_params (EvaluatorParams):
             The parameters for an evaluator.
+
+    Attributes:
+        train (bool):
+            Whether the training data is split and the validation set is used for SMBO optimisation.
+        cross_validation (bool):
+            Whether we use cross-validation or not.
     """
-    def __init__(self, backend: Backend, queue: Queue,
-                 metric: autoPyTorchMetric,
-                 budget: float,
-                 configuration: Union[int, str, Configuration],
-                 budget_type: str = None,
-                 pipeline_config: Optional[Dict[str, Any]] = None,
-                 seed: int = 1,
-                 output_y_hat_optimization: bool = True,
-                 num_run: Optional[int] = None,
-                 include: Optional[Dict[str, Any]] = None,
-                 exclude: Optional[Dict[str, Any]] = None,
-                 disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-                 init_params: Optional[Dict[str, Any]] = None,
-                 logger_port: Optional[int] = None,
-                 keep_models: Optional[bool] = None,
-                 all_supported_metrics: bool = True,
-                 search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None) -> None:
-        super().__init__(
-            backend=backend,
-            queue=queue,
-            configuration=configuration,
-            metric=metric,
-            seed=seed,
-            output_y_hat_optimization=output_y_hat_optimization,
-            num_run=num_run,
-            include=include,
-            exclude=exclude,
-            disable_file_output=disable_file_output,
-            init_params=init_params,
-            budget=budget,
-            budget_type=budget_type,
-            logger_port=logger_port,
-            all_supported_metrics=all_supported_metrics,
-            pipeline_config=pipeline_config,
-            search_space_updates=search_space_updates
-        )
+    def __init__(self, queue: Queue, fixed_pipeline_params: FixedPipelineParams, evaluator_params: EvaluatorParams):
+        resampling_strategy = fixed_pipeline_params.backend.load_datamanager().resampling_strategy
+        self.train = not isinstance(resampling_strategy, NoResamplingStrategyTypes)
+        self.cross_validation = isinstance(resampling_strategy, CrossValTypes)
 
-        if not isinstance(self.datamanager.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
-            resampling_strategy = self.datamanager.resampling_strategy
-            raise ValueError(
-                f'resampling_strategy for TrainEvaluator must be in '
-                f'(CrossValTypes, HoldoutValTypes), but got {resampling_strategy}'
-            )
+        if not self.train and fixed_pipeline_params.save_y_opt:
+            # TODO: Add a test to cover this branch
+            # No resampling cannot be used for building ensembles. save_y_opt=False ensures it
+            fixed_pipeline_params = fixed_pipeline_params._replace(save_y_opt=False)
+
+        super().__init__(queue=queue, fixed_pipeline_params=fixed_pipeline_params, evaluator_params=evaluator_params)
 
-        self.splits = self.datamanager.splits
-        self.num_folds: int = len(self.splits)
-        self.logger.debug("Search space updates :{}".format(self.search_space_updates))
+        if self.train:
+            self.logger.debug("Search space updates :{}".format(self.fixed_pipeline_params.search_space_updates))
 
     def _evaluate_on_split(self, split_id: int) -> EvaluationResults:
         """
@@ -177,7 +148,7 @@ def _evaluate_on_split(self, split_id: int) -> EvaluationResults:
         return EvaluationResults(
             pipeline=pipeline,
-            opt_loss=self._loss(labels=self.y_train[opt_split], preds=opt_pred),
+            opt_loss=self._loss(labels=self.y_train[opt_split] if self.train else self.y_test, preds=opt_pred),
             train_loss=self._loss(labels=self.y_train[train_split], preds=train_pred),
             opt_pred=opt_pred,
             valid_pred=valid_pred,
@@ -203,6 +174,7 @@ def _cross_validation(self) -> EvaluationResults:
             results = self._evaluate_on_split(split_id)
 
             self.pipelines[split_id] = results.pipeline
+            assert opt_split is not None  # mypy redefinition
             cv_results.update(split_id, results, len(train_split), len(opt_split))
 
         self.y_opt = np.concatenate([y_opt for y_opt in Y_opt if y_opt is not None])
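The `assert opt_split is not None` added above is for the type checker only: a variable last assigned inside a loop is still Optional from mypy's point of view, and the assert narrows it back before `len()` is applied. The pattern in isolation, as a runnable sketch:

from typing import List, Optional

def size_of_last_split(splits: List[List[int]]) -> int:
    opt_split: Optional[List[int]] = None
    for opt_split in splits:  # mypy still sees Optional[List[int]] after the loop
        pass
    assert opt_split is not None  # narrows Optional[List[int]] to List[int]
    return len(opt_split)

print(size_of_last_split([[0, 1], [2]]))  # prints 1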
@@ -214,15 +186,16 @@ def evaluate_loss(self) -> None:
         if self.splits is None:
             raise ValueError(f"cannot fit pipeline {self.__class__.__name__} with datamanager.splits None")
 
-        if self.num_folds == 1:
+        if self.cross_validation:
+            results = self._cross_validation()
+        else:
             _, opt_split = self.splits[0]
             results = self._evaluate_on_split(split_id=0)
-            self.y_opt, self.pipelines[0] = self.y_train[opt_split], results.pipeline
-        else:
-            results = self._cross_validation()
+            self.pipelines[0] = results.pipeline
+            self.y_opt = self.y_train[opt_split] if self.train else self.y_test
 
         self.logger.debug(
-            f"In train evaluator.evaluate_loss, num_run: {self.num_run}, loss:{results.opt_loss},"
+            f"In evaluate_loss, num_run: {self.num_run}, loss:{results.opt_loss},"
             f" status: {results.status},\n additional run info:\n{dict_repr(results.additional_run_info)}"
         )
         self.record_evaluation(results=results)
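`record_evaluation` is how the losses computed above reach the parent process: the evaluator never returns results directly, it puts them on a multiprocessing queue that the TargetAlgorithmQuery drains. A minimal sketch of that hand-off (the payload keys are illustrative, not autoPyTorch's exact schema):

from multiprocessing import Process, Queue

def evaluate(queue):
    # a real evaluator would fit a pipeline and compute its losses here
    queue.put({'loss': 0.42, 'status': 'SUCCESS', 'additional_run_info': {}})

if __name__ == '__main__':
    q = Queue()
    worker = Process(target=evaluate, args=(q,))
    worker.start()
    print(q.get())  # the parent blocks until the worker reports a result
    worker.join()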
@@ -242,41 +215,23 @@ def _fit_and_evaluate_loss(
         kwargs = {'pipeline': pipeline, 'unique_train_labels': self.unique_train_labels[split_id]}
         train_pred = self.predict(subsampler(self.X_train, train_indices), **kwargs)
-        opt_pred = self.predict(subsampler(self.X_train, opt_indices), **kwargs)
-        valid_pred = self.predict(self.X_valid, **kwargs)
         test_pred = self.predict(self.X_test, **kwargs)
+        valid_pred = self.predict(self.X_valid, **kwargs)
+
+        # No resampling ===> evaluate on the test dataset
+        opt_pred = self.predict(subsampler(self.X_train, opt_indices), **kwargs) if self.train else test_pred
 
         assert train_pred is not None and opt_pred is not None  # mypy check
         return train_pred, opt_pred, valid_pred, test_pred
 
 
-# create closure for evaluating an algorithm
-def eval_train_function(
-    backend: Backend,
-    queue: Queue,
-    metric: autoPyTorchMetric,
-    budget: float,
-    config: Optional[Configuration],
-    seed: int,
-    output_y_hat_optimization: bool,
-    num_run: int,
-    include: Optional[Dict[str, Any]],
-    exclude: Optional[Dict[str, Any]],
-    disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-    pipeline_config: Optional[Dict[str, Any]] = None,
-    budget_type: str = None,
-    init_params: Optional[Dict[str, Any]] = None,
-    logger_port: Optional[int] = None,
-    all_supported_metrics: bool = True,
-    search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
-    instance: str = None,
-) -> None:
+def eval_fn(queue: Queue, fixed_pipeline_params: FixedPipelineParams, evaluator_params: EvaluatorParams) -> None:
     """
     This closure allows the communication between the TargetAlgorithmQuery and the
-    pipeline trainer (TrainEvaluator).
+    pipeline trainer (Evaluator).
 
     Fundamentally, smac calls the TargetAlgorithmQuery.run() method, which internally
-    builds a TrainEvaluator. The TrainEvaluator builds a pipeline, stores the output files
+    builds an Evaluator. The Evaluator builds a pipeline, stores the output files
     to disc via the backend, and puts the performance result of the run in the queue.
 
     Args:
@@ -288,7 +243,11 @@ def eval_train_function(
         evaluator_params (EvaluatorParams):
             The parameters for an evaluator.
     """
-    evaluator = TrainEvaluator(
+    resampling_strategy = fixed_pipeline_params.backend.load_datamanager().resampling_strategy
+    check_resampling_strategy(resampling_strategy)
+
+    # NoResamplingStrategyTypes ==> test evaluator, otherwise ==> train evaluator
+    evaluator = Evaluator(
         queue=queue,
         evaluator_params=evaluator_params,
         fixed_pipeline_params=fixed_pipeline_params
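With this last hunk, eval_fn is the single entry point for both the train- and test-evaluation paths, and `check_resampling_strategy` rejects anything else before an Evaluator is built. The diff does not show that function's body; a plausible sketch, assuming it simply whitelists the three strategy families (stand-in classes, runnable on its own):

class CrossValTypes: ...             # stand-ins for the real strategy types
class HoldoutValTypes: ...
class NoResamplingStrategyTypes: ...

def check_resampling_strategy(resampling_strategy) -> None:
    # raise early so a misconfigured run fails before any pipeline is fitted
    supported = (CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes)
    if not isinstance(resampling_strategy, supported):
        names = [cls.__name__ for cls in supported]
        raise ValueError(f'resampling_strategy must be one of {names}, got {resampling_strategy}')

check_resampling_strategy(NoResamplingStrategyTypes())  # ok: test-evaluator path
check_resampling_strategy(CrossValTypes())              # ok: train-evaluator path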