diff --git a/autoPyTorch/evaluation/train_evaluator.py b/autoPyTorch/evaluation/train_evaluator.py index e88c8eaca..e20cc6833 100644 --- a/autoPyTorch/evaluation/train_evaluator.py +++ b/autoPyTorch/evaluation/train_evaluator.py @@ -14,7 +14,7 @@ CLASSIFICATION_TASKS, MULTICLASSMULTIOUTPUT ) -from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes +from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes from autoPyTorch.evaluation.abstract_evaluator import ( AbstractEvaluator, fit_and_suppress_warnings @@ -153,10 +153,10 @@ def __init__(self, backend: Backend, queue: Queue, search_space_updates=search_space_updates ) - if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)): + if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes)): raise ValueError( f'resampling_strategy for TrainEvaluator must be in ' - f'(CrossValTypes, HoldoutValTypes), but got {self.resampling_strategy}' + f'(CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes), but got {self.resampling_strategy}' ) self.num_folds: int = len(self.splits)