Skip to content

Commit d6bb8c8

Browse files
committed
add preprocessed_dtype to determine whether to use double or float precision
1 parent 396ff54 commit d6bb8c8

File tree

7 files changed

+14
-12
lines changed

7 files changed

+14
-12
lines changed

autoPyTorch/api/base_task.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def build_pipeline(
270270
include_components: Optional[Dict[str, Any]] = None,
271271
exclude_components: Optional[Dict[str, Any]] = None,
272272
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
273+
) -> BasePipeline:
273274
"""
274275
Build pipeline according to current task
275276
Characteristics of the dataset to guide the pipeline

autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
4242
# We need to also save the preprocess transforms for inference
4343
X.update({
4444
'preprocess_transforms': transforms,
45-
'shape_after_preprocessing': X['X_train'].shape[1:]
45+
'shape_after_preprocessing': X['X_train'].shape[1:],
46+
'preprocessed_dtype': X['X_train'].dtype.name
4647
})
4748
return X
4849

autoPyTorch/pipeline/components/setup/early_preprocessor/TimeSeriesEarlyPreProcessing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
6565
# We need to also save the preprocess transforms for inference
6666
X.update({
6767
'preprocess_transforms': transforms,
68-
'shape_after_preprocessing': X['X_train'].shape[1:]
68+
'shape_after_preprocessing': X['X_train'].shape[1:],
69+
'preprocessed_dtype': X['X_train'].dtype.name
6970
})
7071
return X
7172

autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -447,15 +447,18 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
447447
raise RuntimeError("Budget exhausted without finishing an epoch.")
448448

449449
if self.choice.use_stochastic_weight_averaging and self.choice.swa_updated:
450+
use_double = 'float64' in X['preprocessed_dtype']
450451

451452
# update batch norm statistics
452-
swa_utils.update_bn(loader=X['train_data_loader'], model=self.choice.swa_model.double())
453-
453+
swa_model = self.choice.swa_model.double() if use_double else self.choice.swa_model
454+
swa_utils.update_bn(loader=X['train_data_loader'], model=swa_model)
454455
# change model
455456
update_model_state_dict_from_swa(X['network'], self.choice.swa_model.state_dict())
456-
if self.choice.use_snapshot_ensemble and len(self.choice.model_snapshots) > 0:
457+
if self.choice.use_snapshot_ensemble:
457458
# we update only the last network which pertains to the stochastic weight averaging model
458-
swa_utils.update_bn(X['train_data_loader'], self.choice.model_snapshots[-1].double())
459+
snapshot_model = self.choice.model_snapshots[-1].double() if use_double else self.choice.model_snapshots[-1]
460+
swa_utils.update_bn(X['train_data_loader'], snapshot_model)
461+
update_model_state_dict_from_swa(X['network_snapshots'][-1], self.choice.swa_model.state_dict())
459462

460463
# wrap up -- add score if not evaluating every epoch
461464
if not self.eval_valid_each_epoch(X):

test/test_pipeline/components/training/test_image_data_loader.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def test_imageloader_build_transform():
1616

1717
fit_dictionary = dict()
1818
fit_dictionary['dataset_properties'] = dict()
19-
fit_dictionary['dataset_properties']['is_small_preprocess'] = unittest.mock.Mock(())
2019
fit_dictionary['image_augmenter'] = unittest.mock.Mock()
2120
fit_dictionary['preprocess_transforms'] = unittest.mock.Mock()
2221

test/test_pipeline/components/training/test_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_fit_transform(self):
101101
'y_train': np.array([0, 1, 0]),
102102
'train_indices': [0, 1],
103103
'val_indices': [2],
104-
'dataset_properties': {'is_small_preprocess': True},
104+
'dataset_properties': {},
105105
'working_dir': '/tmp',
106106
'split_id': 0,
107107
'backend': backend,

test/test_pipeline/test_tabular_classification.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,15 +205,12 @@ def test_pipeline_transform(self, fit_dictionary_tabular, exclude):
205205
# We expect the transformations to be in the pipeline at anytime for inference
206206
assert 'preprocess_transforms' in transformed_fit_dictionary_tabular.keys()
207207

208-
@pytest.mark.parametrize("is_small_preprocess", [True, False])
209-
def test_default_configuration(self, fit_dictionary_tabular, is_small_preprocess, exclude):
208+
def test_default_configuration(self, fit_dictionary_tabular, exclude):
210209
"""Makes sure that when no config is set, we can trust the
211210
default configuration from the space"""
212211

213212
fit_dictionary_tabular['epochs'] = 5
214213

215-
fit_dictionary_tabular['is_small_preprocess'] = is_small_preprocess
216-
217214
pipeline = TabularClassificationPipeline(
218215
dataset_properties=fit_dictionary_tabular['dataset_properties'],
219216
exclude=exclude)

0 commit comments

Comments (0)