Skip to content

Commit 8f8dee1

Browse files
committed
fix flake and issue #299
1 parent fd71cb7 commit 8f8dee1

File tree

3 files changed

+63
-69
lines changed

3 files changed

+63
-69
lines changed

autoPyTorch/api/base_task.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def __init__(
211211
self._scoring_functions: Optional[List[autoPyTorchMetric]] = None
212212
self._logger: Optional[PicklableClientLogger] = None
213213
self.dataset_name: Optional[str] = None
214+
self.dataset = Optional[BaseDataset]
214215
self.cv_models_: Dict = {}
215216

216217
self._results_manager = ResultsManager()
@@ -684,20 +685,7 @@ def _load_best_individual_model(self) -> SingleBest:
684685
run_history=self.run_history,
685686
backend=self._backend,
686687
)
687-
if self._logger is None:
688-
warnings.warn(
689-
"No valid ensemble was created. Please check the log"
690-
"file for errors. Default to the best individual estimator:{}".format(
691-
ensemble.identifiers_
692-
)
693-
)
694-
else:
695-
self._logger.exception(
696-
"No valid ensemble was created. Please check the log"
697-
"file for errors. Default to the best individual estimator:{}".format(
698-
ensemble.identifiers_
699-
)
700-
)
688+
701689

702690
return ensemble
703691

@@ -1340,7 +1328,6 @@ def _search(
13401328
if proc_ensemble is not None:
13411329
self._collect_results_ensemble(proc_ensemble)
13421330

1343-
13441331
self._logger.info("Closing the dask infrastructure")
13451332
self._close_dask_client()
13461333
self._logger.info("Finished closing the dask infrastructure")
@@ -1350,6 +1337,14 @@ def _search(
13501337
self._load_models()
13511338
self._logger.info("Finished loading models...")
13521339

1340+
if isinstance(self.ensemble_, SingleBest) and ensemble_size > 0:
1341+
self._logger.exception(
1342+
"No valid ensemble was created. Please check the log"
1343+
"file for errors. Default to the best individual estimator:{}".format(
1344+
self.ensemble_.identifiers_
1345+
)
1346+
)
1347+
13531348
self._cleanup()
13541349

13551350
return self

examples/40_advanced/example_posthoc_ensemble_fit.py

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -24,59 +24,59 @@
2424
from autoPyTorch.api.tabular_classification import TabularClassificationTask
2525

2626

27-
if __name__ == '__main__':
27+
############################################################################
28+
# Data Loading
29+
# ============
30+
X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
31+
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
32+
X,
33+
y,
34+
random_state=42,
35+
)
2836

29-
############################################################################
30-
# Data Loading
31-
# ============
32-
X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
33-
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
34-
X,
35-
y,
36-
random_state=42,
37-
)
37+
############################################################################
38+
# Build and fit a classifier
39+
# ==========================
40+
api = TabularClassificationTask(
41+
seed=42,
42+
)
3843

39-
############################################################################
40-
# Build and fit a classifier
41-
# ==========================
42-
api = TabularClassificationTask(
43-
ensemble_size=0,
44-
seed=42,
45-
)
44+
############################################################################
45+
# Search for the best neural network
46+
# ==================================
47+
api.search(
48+
X_train=X_train,
49+
y_train=y_train,
50+
X_test=X_test.copy(),
51+
y_test=y_test.copy(),
52+
optimize_metric='accuracy',
53+
total_walltime_limit=100,
54+
func_eval_time_limit_secs=50,
55+
ensemble_size=0,
56+
)
4657

47-
############################################################################
48-
# Search for the best neural network
49-
# ==================================
50-
api.search(
51-
X_train=X_train,
52-
y_train=y_train,
53-
X_test=X_test.copy(),
54-
y_test=y_test.copy(),
55-
optimize_metric='accuracy',
56-
total_walltime_limit=100,
57-
func_eval_time_limit_secs=50
58-
)
58+
############################################################################
59+
# Print the final performance of the incumbent neural network
60+
# ===========================================================
61+
print(api.run_history, api.trajectory)
62+
y_pred = api.predict(X_test)
63+
score = api.score(y_pred, y_test)
64+
print(score)
5965

60-
############################################################################
61-
# Print the final performance of the incumbent neural network
62-
# ===========================================================
63-
print(api.run_history, api.trajectory)
64-
y_pred = api.predict(X_test)
65-
score = api.score(y_pred, y_test)
66-
print(score)
66+
############################################################################
67+
# Fit an ensemble with the neural networks fitted during the search
68+
# =================================================================
6769

68-
############################################################################
69-
# Fit an ensemble with the neural networks fitted during the search
70-
# =================================================================
70+
api.fit_ensemble(ensemble_size=5,
71+
# Set the enable_traditional_pipeline=True
72+
# to also include traditional models
73+
# in the ensemble
74+
enable_traditional_pipeline=False)
75+
# Print the final ensemble built by AutoPyTorch
76+
y_pred = api.predict(X_test)
77+
score = api.score(y_pred, y_test)
78+
print(score)
79+
print(api.show_models())
7180

72-
api.fit_ensemble(ensemble_size=5,
73-
# Set the enable_traditional_pipeline=True
74-
# to also include traditional models
75-
# in the ensemble
76-
enable_traditional_pipeline=False)
77-
# Print the final ensemble built by AutoPyTorch
78-
y_pred = api.predict(X_test)
79-
score = api.score(y_pred, y_test)
80-
print(score)
81-
print(api.show_models())
82-
api._cleanup()
81+
# Print statistics from search
82+
print(api.sprint_statistics())

test/test_api/test_base_api.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from autoPyTorch.datasets.base_dataset import BaseDataset
1616
from autoPyTorch.datasets.resampling_strategy import NoResamplingStrategyTypes
1717
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
18-
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
1918
from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy
19+
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
2020

2121

2222
# ====
@@ -225,13 +225,12 @@ def test_init_ensemble_builder(backend):
225225
time_left_for_ensembles=60,
226226
optimize_metric='accuracy',
227227
ensemble_nbest=10,
228-
ensemble_size=5
229-
)
228+
ensemble_size=5)
230229

231230
assert isinstance(proc_ensemble, EnsembleBuilderManager)
232231
assert proc_ensemble.opt_metric == 'accuracy'
233232
assert proc_ensemble.metrics[0] == accuracy
234233

235234
estimator._cleanup()
236235

237-
del estimator
236+
del estimator

0 commit comments

Comments
 (0)