Description
What happened?
condition_on_observations fails on the current main branch when an input transform that changes the shape of the training data, such as NumericToCategoricalEncoding, is involved. The problem occurs when the new observations are appended to the fantasized model: the current code appends the transformed X_condition to _original_train_inputs, which holds the untransformed training data. This does not look correct to me.
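A quick width check, assuming NumericToCategoricalEncoding replaces each categorical column with one column per category. Under that assumption, the raw inputs in the example below have 3 columns and the encoded ones have 2 + 3 = 5, so the two cannot be concatenated row-wise. The helper `encoded_width` is a hypothetical illustration, not a BoTorch function.

```python
def encoded_width(dim: int, categorical_features: dict[int, int]) -> int:
    """Number of input columns after one-hot encoding.

    Hypothetical helper. `categorical_features` maps a column index to its
    number of categories, like the dict passed to NumericToCategoricalEncoding.
    Assumes each categorical column is replaced by one column per category.
    """
    width = dim
    for n_categories in categorical_features.values():
        # One raw column is swapped for `n_categories` one-hot columns.
        width += n_categories - 1
    return width


raw_width = 3  # two numeric columns plus one integer-coded categorical column
transformed_width = encoded_width(raw_width, {2: 3})
print(raw_width, transformed_width)  # 3 5
```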
@hvarfner, could you take a quick look at this? You made some changes to this functionality in #2990.
Best,
Johannes
Please provide a minimal, reproducible example of the unexpected behavior.
from botorch.models.transforms.input import Normalize, NumericToCategoricalEncoding
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood
from functools import partial
from torch.nn.functional import one_hot
import torch
X = torch.rand(12, 2, dtype=torch.double) * 2
Y = 1 - (X - 0.5).norm(dim=-1, keepdim=True) # explicit output dimension
Y += 0.1 * torch.rand_like(Y)
# Append a column with random integers 0, 1, or 2 (the categorical feature)
new_col = torch.randint(0, 3, (X.shape[0], 1), dtype=torch.double)
X = torch.cat([X, new_col], dim=1)
train_X = X[:10]
train_Y = Y[:10]
condition_X = X[10:]
condition_Y = Y[10:]
tf = NumericToCategoricalEncoding(
    dim=3,
    categorical_features={2: 3},
    encoders={2: partial(one_hot, num_classes=3)},
)
gp = SingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    input_transform=tf,
)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)
gp.posterior(train_X)  # predict once so the caches needed for fantasizing exist
gp.condition_on_observations(condition_X, condition_Y)  # fails on current main
Please paste any relevant traceback/logs produced by the example provided.
BoTorch Version
current main
Python Version
No response
Operating System
No response
(Optional) Describe any potential fixes you've considered to the issue outlined above.
From my perspective, the untransformed X_condition needs to be appended here:
botorch/botorch/models/gpytorch.py, line 270 (commit 48e4ab5):
if hasattr(fantasy_model, "input_transform"):
What do you think?
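A library-free sketch of the change proposed above, assuming the fantasy model keeps two stores: raw inputs (the role played by `_original_train_inputs`) and transformed inputs used by the GP. `ToyModel`, `condition`, `one_hot_last_column`, and `append_rows` are hypothetical names for illustration only; the actual fix would live in BoTorch's `condition_on_observations` logic and may look different.

```python
"""Library-free sketch of the proposed bookkeeping fix (hypothetical, not BoTorch code)."""


def one_hot_last_column(rows, num_classes):
    """Replace the last (integer-coded) column of each row with a one-hot vector."""
    encoded = []
    for row in rows:
        *numeric, category = row
        one_hot = [0.0] * num_classes
        one_hot[int(category)] = 1.0
        encoded.append(list(numeric) + one_hot)
    return encoded


def append_rows(existing, new):
    """Row-wise concatenation that, like torch.cat, rejects mismatched widths."""
    if existing and new and len(existing[0]) != len(new[0]):
        raise ValueError(f"width mismatch: {len(existing[0])} vs {len(new[0])}")
    return existing + new


class ToyModel:
    """Keeps raw inputs (role of `_original_train_inputs`) and transformed inputs."""

    def __init__(self, raw_X, num_classes):
        self.num_classes = num_classes
        self.original_train_inputs = [list(row) for row in raw_X]
        self.train_inputs = one_hot_last_column(raw_X, num_classes)

    def condition(self, raw_X_new, fixed=True):
        X_tf = one_hot_last_column(raw_X_new, self.num_classes)
        new_train = append_rows(self.train_inputs, X_tf)
        # Buggy path appends transformed rows to the raw store; the fix appends raw rows.
        to_original = [list(row) for row in raw_X_new] if fixed else X_tf
        new_original = append_rows(self.original_train_inputs, to_original)
        # Only update state once both appends have succeeded.
        self.train_inputs, self.original_train_inputs = new_train, new_original


train_X = [[0.1, 0.2, 0.0], [0.5, 0.9, 2.0]]
condition_X = [[0.3, 0.4, 1.0]]

model = ToyModel(train_X, num_classes=3)
try:
    model.condition(condition_X, fixed=False)
except ValueError as err:
    print("buggy:", err)  # buggy: width mismatch: 3 vs 5

model.condition(condition_X, fixed=True)
print(len(model.original_train_inputs[0]), len(model.train_inputs[0]))  # 3 5
```

Computing both appended stores before assigning them keeps the toy model unchanged when the buggy path raises.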
Pull Request
Yes
Code of Conduct
- I agree to follow BoTorch's Code of Conduct