Skip to content

Commit 68e334d

Browse files
sdaultonfacebook-github-bot
authored andcommitted
be explicit about task feature in TL (#2918)
Summary: Pull Request resolved: #2918 X-link: facebook/Ax#4021 Previously, the TL adapter added a task feature to `SearchSpaceDigest.task_features`, but not to `SearchSpaceDigest.bounds` or `SearchSpaceDigest.feature_names`. This avoided needing to pass the task feature as a fixed_feature when making predictions or generating. However, there was a bug when using `Normalize` that dropped one non-task feature here: https://github.com/facebook/Ax/blob/main/ax/generators/torch/botorch_modular/surrogate.py#L247-L250 This diff adds the task feature to via a new Transform. This significantly simplifies task feature handling across Adapter methods. Reviewed By: saitcakmak Differential Revision: D77905821
1 parent b9ffefd commit 68e334d

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

botorch/utils/datasets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def __init__(
7676
self._X = X
7777
self._Y = Y
7878
self._Yvar = Yvar
79-
self.feature_names = feature_names
80-
self.outcome_names = outcome_names
79+
self.feature_names = feature_names.copy()
80+
self.outcome_names = outcome_names.copy()
8181
self.group_indices = group_indices
8282
self.validate_init = validate_init
8383
if validate_init:
@@ -351,7 +351,7 @@ def __init__(
351351
self.target_outcome_name = target_outcome_name
352352
self.task_feature_index = task_feature_index
353353
self._validate_datasets(datasets=datasets)
354-
self.feature_names = self.datasets[target_outcome_name].feature_names
354+
self.feature_names = self.datasets[target_outcome_name].feature_names.copy()
355355
self.outcome_names = [target_outcome_name]
356356

357357
# Check if the datasets have identical feature sets.

test/utils/test_datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ def test_clone(self, supervised: bool = True) -> None:
235235
self.assertIs(dataset._X, dataset2._X)
236236
self.assertIs(dataset._Y, dataset2._Y)
237237
self.assertIs(dataset._Yvar, dataset2._Yvar)
238-
self.assertIs(dataset.feature_names, dataset2.feature_names)
239-
self.assertIs(dataset.outcome_names, dataset2.outcome_names)
238+
self.assertEqual(dataset.feature_names, dataset2.feature_names)
239+
self.assertEqual(dataset.outcome_names, dataset2.outcome_names)
240240
# test with mask
241241
mask = torch.tensor([0, 1, 1], dtype=torch.bool)
242242
if supervised:

0 commit comments

Comments
 (0)