
[Bug]: condition_on_observations with input transforms #3033

@jduerholt

Description


What happened?

condition_on_observations fails on the current main branch when an input transform that changes the shape of the training data, such as NumericToCategoricalEncoding, is involved. The problem occurs when the conditioning data is appended to the training data of the fantasized model: the current code tries to append the transformed X_condition to _original_train_inputs, which holds the untransformed training data. This does not look correct to me.
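The following is a minimal, stdlib-only sketch of the resulting width mismatch. It is not BoTorch code. The helpers one_hot_encode_row and append_rows are hypothetical. The first mimics one-hot encoding column 2 with 3 categories, as in the reproducible example below; placing the indicator columns at the end is an arbitrary choice for this sketch. The second mimics a row-wise concatenation that, like torch.cat, requires all rows to have the same width.

```python
from typing import List

Row = List[float]


def one_hot_encode_row(row: Row, cat_dim: int = 2, num_classes: int = 3) -> Row:
    """Hypothetical stand-in for a one-hot input transform: drop the integer-valued
    column `cat_dim` and append `num_classes` indicator columns at the end."""
    numeric = [v for i, v in enumerate(row) if i != cat_dim]
    category = int(row[cat_dim])
    return numeric + [1.0 if k == category else 0.0 for k in range(num_classes)]


def append_rows(existing: List[Row], new: List[Row]) -> List[Row]:
    """Concatenate along the row dimension; like torch.cat, all rows must share a width."""
    widths = {len(r) for r in existing + new}
    if len(widths) != 1:
        raise ValueError(f"width mismatch: {sorted(widths)}")
    return existing + new


# Untransformed training inputs: 2 numeric columns + 1 categorical column (width 3).
original_train_inputs = [[0.1, 0.4, 2.0], [0.7, 1.2, 0.0]]
condition_X = [[0.5, 0.9, 1.0]]

# Pattern described in the issue: transform X_condition, then append it to the
# *untransformed* training inputs. Width 5 vs. width 3 raises an error.
transformed_condition_X = [one_hot_encode_row(r) for r in condition_X]
try:
    append_rows(original_train_inputs, transformed_condition_X)
except ValueError as err:
    print(err)  # width mismatch: [3, 5]

# Consistent pattern: append the untransformed X_condition instead.
combined = append_rows(original_train_inputs, condition_X)
print(len(combined), len(combined[0]))  # 3 3
```

Any transform that preserves the input width, such as Normalize, would hide this mismatch. That may explain why the issue only surfaces with shape-changing transforms.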

@hvarfner, could you take a quick look at this? You made some changes to this functionality in #2990.

Best,

Johannes

Please provide a minimal, reproducible example of the unexpected behavior.

from botorch.models.transforms.input import Normalize, NumericToCategoricalEncoding
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood
from functools import partial
from torch.nn.functional import one_hot
import torch


X = torch.rand(12, 2, dtype=torch.double) * 2
Y = 1 - (X - 0.5).norm(dim=-1, keepdim=True)  # explicit output dimension
Y += 0.1 * torch.rand_like(Y)

# Append a column with random integers 0, 1, or 2 (a categorical feature) to X
new_col = torch.randint(0, 3, (X.shape[0], 1), dtype=torch.double)
X = torch.cat([X, new_col], dim=1)

train_X = X[:10]
train_Y = Y[:10]

condition_X = X[10:]
condition_Y = Y[10:]

# One-hot encode column 2 (3 categories); this changes the input width from 3 to 5
tf = NumericToCategoricalEncoding(
    dim=3,
    categorical_features={2: 3},
    encoders={2: partial(one_hot, num_classes=3)},
)
gp = SingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    input_transform=tf,
)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)

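# Evaluate the posterior once so the prediction caches needed for conditioning exist
gp.posterior(train_X)

# Fails on current main because of the shape-changing input transform
gp.condition_on_observations(condition_X, condition_Y)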

Please paste any relevant traceback/logs produced by the example provided.

BoTorch Version

current main

Python Version

No response

Operating System

No response

(Optional) Describe any potential fixes you've considered to the issue outlined above.

From my perspective, the untransformed X_condition needs to be appended here:

if hasattr(fantasy_model, "input_transform"):

What do you think?
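Here is a minimal sketch of the fix proposed above. It is again a hypothetical, stdlib-only toy, not BoTorch's actual condition_on_observations. The class ToyFantasyModel and its attributes are assumptions for illustration. It keeps the untransformed inputs, mirroring _original_train_inputs, and appends the untransformed X_condition to them. Only then does it apply the input transform, producing the transformed training inputs.

```python
from typing import Callable, List

Row = List[float]


def one_hot_encode_row(row: Row, cat_dim: int = 2, num_classes: int = 3) -> Row:
    """Hypothetical one-hot transform of column `cat_dim` (indicators appended at the end)."""
    numeric = [v for i, v in enumerate(row) if i != cat_dim]
    category = int(row[cat_dim])
    return numeric + [1.0 if k == category else 0.0 for k in range(num_classes)]


class ToyFantasyModel:
    """Hypothetical model that stores raw inputs and derives transformed ones from them."""

    def __init__(
        self, raw_inputs: List[Row], targets: List[float], transform: Callable[[Row], Row]
    ):
        self.transform = transform
        self.original_train_inputs = [list(r) for r in raw_inputs]  # untransformed
        self.train_inputs = [transform(r) for r in raw_inputs]  # transformed
        self.train_targets = list(targets)

    def condition_on_observations(
        self, X_condition: List[Row], Y_condition: List[float]
    ) -> "ToyFantasyModel":
        # Append the *untransformed* conditioning inputs to the untransformed
        # training inputs, then transform the combined set in one pass.
        combined_raw = self.original_train_inputs + [list(r) for r in X_condition]
        combined_targets = self.train_targets + list(Y_condition)
        return ToyFantasyModel(combined_raw, combined_targets, self.transform)


model = ToyFantasyModel([[0.1, 0.4, 2.0], [0.7, 1.2, 0.0]], [0.3, 0.6], one_hot_encode_row)
fantasy = model.condition_on_observations([[0.5, 0.9, 1.0]], [0.8])
print({len(r) for r in fantasy.original_train_inputs})  # {3}
print({len(r) for r in fantasy.train_inputs})  # {5}
```

Storing the raw inputs and deriving the transformed ones from them keeps both views consistent in width. Whether the real fix should re-transform like this, or instead append the transformed X_condition only to the already-transformed training inputs, is the design question raised above.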

Pull Request

Yes

Code of Conduct

  • I agree to follow BoTorch's Code of Conduct

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
