Skip to content

Commit 888a7db

Browse files
committed
autoPyTorch/api/
1 parent 707b275 commit 888a7db

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

test/test_api/test_api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,6 @@ def test_tabular_input_support(openml_id, backend):
429429
estimator = TabularClassificationTask(
430430
backend=backend,
431431
resampling_strategy=HoldoutValTypes.holdout_validation,
432-
ensemble_size=0,
433432
)
434433

435434
estimator._do_dummy_prediction = unittest.mock.MagicMock()
@@ -444,6 +443,7 @@ def test_tabular_input_support(openml_id, backend):
444443
func_eval_time_limit_secs=50,
445444
enable_traditional_pipeline=False,
446445
load_models=False,
446+
ensemble_size=0,
447447
)
448448

449449

@@ -453,7 +453,6 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
453453
estimator = TabularClassificationTask(
454454
backend=backend,
455455
resampling_strategy=HoldoutValTypes.holdout_validation,
456-
ensemble_size=0,
457456
)
458457

459458
# Setup pre-requisites normally set by search()

test/test_api/test_base_api.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def test_set_pipeline_config():
118118
])
119119
def test_pipeline_get_budget(fit_dictionary_tabular, min_budget, max_budget, budget_type, expected):
120120
BaseTask.__abstractmethods__ = set()
121-
estimator = BaseTask(task_type='tabular_classification', ensemble_size=0)
121+
estimator = BaseTask(task_type='tabular_classification')
122122

123123
# Fixture pipeline config
124124
default_pipeline_config = {
@@ -141,7 +141,7 @@ def test_pipeline_get_budget(fit_dictionary_tabular, min_budget, max_budget, bud
141141
smac_mock.return_value = smac
142142
estimator._search(optimize_metric='accuracy', dataset=dataset, tae_func=pipeline_fit,
143143
min_budget=min_budget, max_budget=max_budget, budget_type=budget_type,
144-
enable_traditional_pipeline=False,
144+
ensemble_size=0, enable_traditional_pipeline=False,
145145
total_walltime_limit=20, func_eval_time_limit_secs=10,
146146
load_models=False)
147147
assert list(smac_mock.call_args)[1]['ta_kwargs']['pipeline_config'] == default_pipeline_config
@@ -167,7 +167,6 @@ def test_init_ensemble_builder(backend):
167167
BaseTask.__abstractmethods__ = set()
168168
estimator = BaseTask(
169169
backend=backend,
170-
ensemble_size=0,
171170
)
172171

173172
# Setup pre-requisites normally set by search()

0 commit comments

Comments
 (0)