
Commit 180ff33

[rebase] Rebase to the latest version and merge test_evaluator to train_evaluator
Since test_evaluator can be merged, I merged it.

* [rebase] Rebase and merge the changes in non-test files without issues
* [refactor] Merge test- and train-evaluator
* [fix] Fix the import error due to the change xxx_evaluator --> evaluator
* [test] Fix errors in tests
* [fix] Fix the handling of test pred in no resampling
* [refactor] Move save_y_opt=False for no resampling deeper for simplicity
* [test] Increase the budget size for the no-resampling tests
* [test] [fix] Rebase, modify tests, and increase the coverage
1 parent 797ce34 commit 180ff33

16 files changed: +465 −1012 lines

autoPyTorch/api/base_task.py (+5 −5)

@@ -315,7 +315,7 @@ def _get_dataset_input_validator(
                 Testing feature set
             y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
                 Testing target set
-            resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
+            resampling_strategy (Optional[ResamplingStrategies]):
                 Strategy to split the training data. if None, uses
                 HoldoutValTypes.holdout_validation.
             resampling_strategy_args (Optional[Dict[str, Any]]):
@@ -355,7 +355,7 @@ def get_dataset(
                 Testing feature set
             y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
                 Testing target set
-            resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
+            resampling_strategy (Optional[ResamplingStrategies]):
                 Strategy to split the training data. if None, uses
                 HoldoutValTypes.holdout_validation.
             resampling_strategy_args (Optional[Dict[str, Any]]):
@@ -973,7 +973,7 @@ def _search(
                 `SMAC <https://automl.github.io/SMAC3/master/index.html>`_.
             tae_func (Optional[Callable]):
                 TargetAlgorithm to be optimised. If None, `eval_function`
-                available in autoPyTorch/evaluation/train_evaluator is used.
+                available in autoPyTorch/evaluation/evaluator is used.
                 Must be child class of AbstractEvaluator.
             all_supported_metrics (bool: default=True):
                 If True, all metrics supporting current task will be calculated
@@ -1380,7 +1380,7 @@ def fit_pipeline(
         X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         dataset_name: Optional[str] = None,
-        resampling_strategy: Optional[Union[HoldoutValTypes, CrossValTypes, NoResamplingStrategyTypes]] = None,
+        resampling_strategy: Optional[ResamplingStrategies] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         run_time_limit_secs: int = 60,
         memory_limit: Optional[int] = None,
@@ -1415,7 +1415,7 @@ def fit_pipeline(
                 be provided to track the generalization performance of each stage.
             dataset_name (Optional[str]):
                 Name of the dataset, if None, random value is used.
-            resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
+            resampling_strategy (Optional[ResamplingStrategies]):
                 Strategy to split the training data. if None, uses
                 HoldoutValTypes.holdout_validation.
             resampling_strategy_args (Optional[Dict[str, Any]]):
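
For context, `ResamplingStrategies` is the Union alias introduced in this commit's autoPyTorch/datasets/resampling_strategy.py diff (below), so the signatures above name all three strategy enums through a single type. A minimal sketch of what the alias admits; enum member names other than holdout_validation are assumed from the enum definitions, not shown in this commit:

from typing import Optional

from autoPyTorch.datasets.resampling_strategy import (
    CrossValTypes,
    HoldoutValTypes,
    NoResamplingStrategyTypes,
    ResamplingStrategies,  # Union of the three enum types, defined in this commit
)

def describe(strategy: Optional[ResamplingStrategies]) -> str:
    # None falls back to HoldoutValTypes.holdout_validation, per the docstrings above.
    return 'default holdout' if strategy is None else str(strategy)

print(describe(None))
print(describe(HoldoutValTypes.holdout_validation))
print(describe(NoResamplingStrategyTypes.no_resampling))  # assumed member name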

autoPyTorch/api/tabular_classification.py (+1 −1)

@@ -336,7 +336,7 @@ def search(
                 `SMAC <https://automl.github.io/SMAC3/master/index.html>`_.
             tae_func (Optional[Callable]):
                 TargetAlgorithm to be optimised. If None, `eval_function`
-                available in autoPyTorch/evaluation/train_evaluator is used.
+                available in autoPyTorch/evaluation/evaluator is used.
                 Must be child class of AbstractEvaluator.
             all_supported_metrics (bool: default=True):
                 If True, all metrics supporting current task will be calculated

autoPyTorch/api/tabular_regression.py (+1 −1)

@@ -337,7 +337,7 @@ def search(
                 `SMAC <https://automl.github.io/SMAC3/master/index.html>`_.
             tae_func (Optional[Callable]):
                 TargetAlgorithm to be optimised. If None, `eval_function`
-                available in autoPyTorch/evaluation/train_evaluator is used.
+                available in autoPyTorch/evaluation/evaluator is used.
                 Must be child class of AbstractEvaluator.
             all_supported_metrics (bool: default=True):
                 If True, all metrics supporting current task will be calculated

autoPyTorch/datasets/resampling_strategy.py (+8 −0)

@@ -93,6 +93,14 @@ def is_stratified(self) -> bool:
 # TODO: replace it with another way
 ResamplingStrategies = Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]
 
+
+def check_resampling_strategy(resampling_strategy: Optional[ResamplingStrategies]) -> None:
+    choices = (CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes)
+    if not isinstance(resampling_strategy, choices):
+        rs_names = (rs.__mro__[0].__name__ for rs in choices)
+        raise ValueError(f'resampling_strategy must be in {rs_names}, but got {resampling_strategy}')
+
+
 DEFAULT_RESAMPLING_PARAMETERS: Dict[
     ResamplingStrategies,
     Dict[str, Any]
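
A short usage sketch for the new validator, assuming the imports of this commit. One nuance worth noting: `rs_names` is a generator expression, so the f-string renders the generator's repr rather than the three class names (wrapping it in tuple(...) would render them literally):

from autoPyTorch.datasets.resampling_strategy import (
    HoldoutValTypes,
    check_resampling_strategy,
)

check_resampling_strategy(HoldoutValTypes.holdout_validation)  # valid: returns None

try:
    check_resampling_strategy(None)  # None is not an instance of any of the three enums
except ValueError as err:
    print(err)  # message interpolates the generator's repr, not the class names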

autoPyTorch/evaluation/abstract_evaluator.py (+108 −138)

@@ -167,47 +167,87 @@ class FixedPipelineParams(NamedTuple):
         search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
             An object used to fine tune the hyperparameter search space of the pipeline
     """
-    def __init__(self, backend: Backend,
-                 queue: Queue,
-                 metric: autoPyTorchMetric,
-                 budget: float,
-                 configuration: Union[int, str, Configuration],
-                 budget_type: str = None,
-                 pipeline_config: Optional[Dict[str, Any]] = None,
-                 seed: int = 1,
-                 output_y_hat_optimization: bool = True,
-                 num_run: Optional[int] = None,
-                 include: Optional[Dict[str, Any]] = None,
-                 exclude: Optional[Dict[str, Any]] = None,
-                 disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-                 init_params: Optional[Dict[str, Any]] = None,
-                 logger_port: Optional[int] = None,
-                 all_supported_metrics: bool = True,
-                 search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
-                 ) -> None:
-
-        self.starttime = time.time()
-
-        self.configuration = configuration
-        self.backend: Backend = backend
-        self.queue = queue
-
-        self.include = include
-        self.exclude = exclude
-        self.search_space_updates = search_space_updates
-
-        self.metric = metric
-
-
-        self._init_datamanager_info()
-
-        # Flag to save target for ensemble
-        self.output_y_hat_optimization = output_y_hat_optimization
+    backend: Backend
+    seed: int
+    metric: autoPyTorchMetric
+    budget_type: str  # Literal['epochs', 'runtime']
+    pipeline_config: Dict[str, Any]
+    save_y_opt: bool = True
+    include: Optional[Dict[str, Any]] = None
+    exclude: Optional[Dict[str, Any]] = None
+    disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None
+    logger_port: Optional[int] = None
+    all_supported_metrics: bool = True
+    search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
+
+    @classmethod
+    def with_default_pipeline_config(
+        cls,
+        pipeline_config: Optional[Dict[str, Any]] = None,
+        choice: str = 'default',
+        **kwargs: Any
+    ) -> 'FixedPipelineParams':
+
+        if 'budget_type' in kwargs:
+            raise TypeError(
+                f'{cls.__name__}.with_default_pipeline_config() got multiple values for argument `budget_type`'
+            )
+
+        budget_type_choices = ('epochs', 'runtime')
+        if pipeline_config is None:
+            pipeline_config = get_default_pipeline_config(choice=choice)
+        if 'budget_type' not in pipeline_config:
+            raise ValueError('pipeline_config must have `budget_type`')
+
+        budget_type = pipeline_config['budget_type']
+        if pipeline_config['budget_type'] not in budget_type_choices:
+            raise ValueError(f"budget_type must be in {budget_type_choices}, but got {budget_type}")
+
+        kwargs.update(pipeline_config=pipeline_config, budget_type=budget_type)
+        return cls(**kwargs)
+
+
+class EvaluatorParams(NamedTuple):
+    """
+    Attributes:
+        configuration (Union[int, str, Configuration]):
+            Determines the pipeline to be constructed. A dummy estimator is created for
+            integer configurations, a traditional machine learning pipeline is created
+            for string based configuration, and NAS is performed when a configuration
+            object is passed.
+        num_run (Optional[int]):
+            An identifier of the current configuration being fit. This number is unique per
+            configuration.
+        init_params (Optional[Dict[str, Any]]):
+            Optional argument that is passed to each pipeline step. It is the equivalent of
+            kwargs for the pipeline steps.
+    """
+    budget: float
+    configuration: Union[int, str, Configuration]
+    num_run: Optional[int] = None
+    init_params: Optional[Dict[str, Any]] = None
+
+    @classmethod
+    def with_default_budget(
+        cls,
+        budget: float = 0,
+        choice: str = 'default',
+        **kwargs: Any
+    ) -> 'EvaluatorParams':
+        budget = get_default_budget(choice=choice) if budget == 0 else budget
+        kwargs.update(budget=budget)
+        return cls(**kwargs)
+
+
+class AbstractEvaluator(object):
+    """
+    This method defines the interface that pipeline evaluators should follow, when
+    interacting with SMAC through TargetAlgorithmQuery.
 
     An evaluator is an object that:
     + constructs a pipeline (i.e. a classification or regression estimator) for a given
       pipeline_config and run settings (budget, seed)
-    + Fits and trains this pipeline (TrainEvaluator) or tests a given
+    + Fits and trains this pipeline (Evaluator) or tests a given
       configuration (TestEvaluator)
 
     The provided configuration determines the type of pipeline created. For more
@@ -244,21 +284,33 @@ def _init_miscellaneous(self) -> None:
             DisableFileOutputParameters.check_compatibility(disable_file_output)
             self.disable_file_output = disable_file_output
         else:
-            if isinstance(self.configuration, int):
-                self.pipeline_class = DummyClassificationPipeline
-            elif isinstance(self.configuration, str):
-                if self.task_type in TABULAR_TASKS:
-                    self.pipeline_class = MyTraditionalTabularClassificationPipeline
-                else:
-                    raise ValueError("Only tabular tasks are currently supported with traditional methods")
-            elif isinstance(self.configuration, Configuration):
-                if self.task_type in TABULAR_TASKS:
-                    self.pipeline_class = autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline
-                elif self.task_type in IMAGE_TASKS:
-                    self.pipeline_class = autoPyTorch.pipeline.image_classification.ImageClassificationPipeline
-                else:
-                    raise ValueError('task {} not available'.format(self.task_type))
-            self.predict_function = self._predict_proba
+            self.disable_file_output = []
+
+        if self.num_folds == 1:  # not save cv model when we perform holdout
+            self.disable_file_output.append('cv_model')
+
+    def _init_dataset_properties(self) -> None:
+        datamanager: BaseDataset = self.fixed_pipeline_params.backend.load_datamanager()
+        if datamanager.task_type is None:
+            raise ValueError(f"Expected dataset {datamanager.__class__.__name__} to have task_type got None")
+        if datamanager.splits is None:
+            raise ValueError(f"cannot fit pipeline {self.__class__.__name__} with datamanager.splits None")
+
+        self.splits = datamanager.splits
+        self.num_folds: int = len(self.splits)
+        # Since cv might not finish in time, we take self.pipelines as None by default
+        self.pipelines: List[Optional[BaseEstimator]] = [None] * self.num_folds
+        self.task_type = STRING_TO_TASK_TYPES[datamanager.task_type]
+        self.num_classes = getattr(datamanager, 'num_classes', 1)
+        self.output_type = datamanager.output_type
+
+        search_space_updates = self.fixed_pipeline_params.search_space_updates
+        self.dataset_properties = datamanager.get_dataset_properties(
+            get_dataset_requirements(info=datamanager.get_required_dataset_info(),
+                                     include=self.fixed_pipeline_params.include,
+                                     exclude=self.fixed_pipeline_params.exclude,
+                                     search_space_updates=search_space_updates
+                                     ))
 
         self.X_train, self.y_train = datamanager.train_tensors
         self.unique_train_labels = [
@@ -271,6 +323,8 @@ def _init_miscellaneous(self) -> None:
         if datamanager.test_tensors is not None:
             self.X_test, self.y_test = datamanager.test_tensors
 
+        del datamanager  # Delete datamanager to release the memory
+
     def _init_additional_metrics(self) -> None:
         all_supported_metrics = self.fixed_pipeline_params.all_supported_metrics
         metric = self.fixed_pipeline_params.metric
@@ -282,59 +336,7 @@ def _init_additional_metrics(self) -> None:
                                                        all_supported_metrics=all_supported_metrics)
         self.metrics_dict = {'additional_metrics': [m.name for m in [metric] + self.additional_metrics]}
 
-    def _init_datamanager_info(
-        self,
-    ) -> None:
-        """
-        Initialises instance attributes that come from the datamanager.
-        For example,
-            X_train, y_train, etc.
-        """
-
-        datamanager: BaseDataset = self.backend.load_datamanager()
-
-        assert datamanager.task_type is not None, \
-            "Expected dataset {} to have task_type got None".format(datamanager.__class__.__name__)
-        self.task_type = STRING_TO_TASK_TYPES[datamanager.task_type]
-        self.output_type = STRING_TO_OUTPUT_TYPES[datamanager.output_type]
-        self.issparse = datamanager.issparse
-
-        self.X_train, self.y_train = datamanager.train_tensors
-
-        if datamanager.val_tensors is not None:
-            self.X_valid, self.y_valid = datamanager.val_tensors
-        else:
-            self.X_valid, self.y_valid = None, None
-
-        if datamanager.test_tensors is not None:
-            self.X_test, self.y_test = datamanager.test_tensors
-        else:
-            self.X_test, self.y_test = None, None
-
-        self.resampling_strategy = datamanager.resampling_strategy
-
-        self.num_classes: Optional[int] = getattr(datamanager, "num_classes", None)
-
-        self.dataset_properties = datamanager.get_dataset_properties(
-            get_dataset_requirements(info=datamanager.get_required_dataset_info(),
-                                     include=self.include,
-                                     exclude=self.exclude,
-                                     search_space_updates=self.search_space_updates
-                                     ))
-        self.splits = datamanager.splits
-        if self.splits is None:
-            raise AttributeError(f"create_splits on {datamanager.__class__.__name__} must be called "
-                                 f"before the instantiation of {self.__class__.__name__}")
-
-        # delete datamanager from memory
-        del datamanager
-
-    def _init_fit_dictionary(
-        self,
-        logger_port: int,
-        pipeline_config: Dict[str, Any],
-        metrics_dict: Optional[Dict[str, List[str]]] = None,
-    ) -> None:
+    def _init_fit_dictionary(self) -> None:
         """
         Initialises the fit dictionary
 
@@ -617,36 +619,4 @@ def _is_output_possible(
         if y is not None and not np.all(np.isfinite(y)):
             return False  # Model predictions contains NaNs
 
-        Args:
-            prediction (np.ndarray):
-                The un-formatted predictions of a pipeline
-            Y_train (np.ndarray):
-                The labels from the dataset to give an intuition of the expected
-                predictions dimensionality
-        Returns:
-            (np.ndarray):
-                The formatted prediction
-        """
-        assert self.num_classes is not None, "Called function on wrong task"
-
-        if self.output_type == MULTICLASS and \
-                prediction.shape[1] < self.num_classes:
-            if Y_train is None:
-                raise ValueError('Y_train must not be None!')
-            classes = list(np.unique(Y_train))
-
-            mapping = dict()
-            for class_number in range(self.num_classes):
-                if class_number in classes:
-                    index = classes.index(class_number)
-                    mapping[index] = class_number
-            new_predictions = np.zeros((prediction.shape[0], self.num_classes),
-                                       dtype=np.float32)
-
-            for index in mapping:
-                class_index = mapping[index]
-                new_predictions[:, class_index] = prediction[:, index]
-
-            return new_predictions
-
-        return prediction
+        return True
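
The refactor above replaces the old 17-argument `__init__` with two NamedTuples: `FixedPipelineParams` carries the settings shared by every run, while `EvaluatorParams` carries the per-configuration ones. A hedged sketch of how an evaluator's inputs would now be assembled with the new factory methods; `backend`, `accuracy`, and `config` are placeholders for real objects, not values from this commit:

fixed_params = FixedPipelineParams.with_default_pipeline_config(
    backend=backend,   # a Backend that can load_datamanager()
    seed=1,
    metric=accuracy,   # an autoPyTorchMetric
    save_y_opt=True,   # replaces the old output_y_hat_optimization flag
)
# The factory fills pipeline_config via get_default_pipeline_config(choice='default')
# and copies its 'budget_type' ('epochs' or 'runtime') into the tuple, raising a
# TypeError if budget_type was also passed explicitly.

evaluator_params = EvaluatorParams.with_default_budget(
    budget=0,              # 0 means: substitute get_default_budget(choice='default')
    configuration=config,  # int -> dummy, str -> traditional ML, Configuration -> NAS
)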
