-
Notifications
You must be signed in to change notification settings - Fork 21
Open
Description
After doing some research, it looks like we can re-use a bit of what is already available through BoTorch to simplify a lot of the math in the LatentKernel implementation. I'm writing up unit tests for this module in this PR so we can safely refactor, if possible: #262
Something like the following should work:
import math
import torch
import gpytorch
from gpytorch.kernels import Kernel
from gpytorch.constraints import Interval
from botorch.models.utils.gpytorch_modules import get_matern_kernel_with_gamma_prior
class RotatedInputsKernel(Kernel):
    """Apply a learned orthogonal rotation to the inputs, then delegate to a base kernel.

    The rotation is block-structured: each group of input dimensions gets its
    own unconstrained square matrix ``A``. ``A`` is squashed into
    ``(-2*pi, 2*pi)`` by an ``Interval`` constraint, skew-symmetrized as
    ``S = A - A^T`` and mapped through the matrix exponential, which yields an
    orthogonal block. Dimensions that belong to no group are left untouched
    (identity). With all raw matrices at zero the rotation is the identity.

    Args:
        d: Number of input dimensions.
        batch_shape: Batch shape shared by the rotation parameters and the
            base kernel.
        skew_dims: ``True`` to rotate all ``d`` dimensions together, ``False``
            for no rotation, or a list of index tuples, each tuple being one
            group of dimensions that may rotate among themselves.

    Raises:
        ValueError: If a group index is outside ``[0, d)`` or an index appears
            more than once (overlapping blocks would overwrite each other and
            the assembled matrix would no longer be orthogonal).
    """

    is_stationary = True
    has_lengthscale = False  # lengthscale lives in the base kernel

    def __init__(
        self,
        d: int,
        batch_shape: torch.Size = torch.Size(),
        skew_dims: bool | list[tuple[int, ...]] = True,
    ):
        super().__init__(batch_shape=batch_shape)
        self.d = d

        # Choose groups of dims that can rotate among themselves
        if isinstance(skew_dims, bool):
            self.groups = [tuple(range(d))] if skew_dims else []
        else:
            self.groups = [tuple(g) for g in skew_dims]
        self._validate_groups(self.groups, d)

        # Base kernel = ScaleKernel(MaternKernel(...)) with Gamma priors
        self.base = get_matern_kernel_with_gamma_prior(ard_num_dims=d, batch_shape=batch_shape)

        # One unconstrained matrix per group; skew-symmetrize then expm -> orthogonal.
        # Each matrix is registered as a *direct* parameter of this module:
        # constraints can only be attached to direct parameters, and the
        # derived "<name>_constraint" submodule name must not contain a ".".
        for i, g in enumerate(self.groups):
            k = len(g)
            name = self._param_name(i)
            raw = torch.zeros(*batch_shape, k, k, dtype=torch.double)
            self.register_parameter(name, torch.nn.Parameter(raw))
            # Bound the pre-skew entries to (-2*pi, 2*pi)
            self.register_constraint(name, Interval(-2 * math.pi, 2 * math.pi))

    @staticmethod
    def _param_name(i: int) -> str:
        """Return the attribute name of the raw matrix for group ``i``."""
        return f"raw_group_mat_{i}"

    @staticmethod
    def _validate_groups(groups: list[tuple[int, ...]], d: int) -> None:
        """Check that group indices lie in ``[0, d)`` and are pairwise distinct."""
        seen: set[int] = set()
        for g in groups:
            for j in g:
                if not 0 <= j < d:
                    raise ValueError(f"skew_dims index {j} is out of range for d={d}")
                if j in seen:
                    raise ValueError(f"skew_dims index {j} appears in more than one position")
                seen.add(j)

    @property
    def raw_group_mats(self) -> list[torch.nn.Parameter]:
        """Raw (unconstrained) per-group matrices, in group order."""
        return [getattr(self, self._param_name(i)) for i in range(len(self.groups))]

    def _rotation(self) -> torch.Tensor:
        """Assemble the ``(*batch_shape, d, d)`` rotation matrix from the group blocks."""
        if self.groups:
            # Follow the parameters so module-level .to()/.float() calls keep working
            ref = getattr(self, self._param_name(0))
            dtype, device = ref.dtype, ref.device
        else:
            dtype, device = torch.double, None

        # Start with identity, then fill block rotations
        R = torch.eye(self.d, dtype=dtype, device=device)
        if self.batch_shape:
            # clone() so the in-place block writes below hit real memory, not an expanded view
            R = R.expand(*self.batch_shape, self.d, self.d).clone()

        for i, g in enumerate(self.groups):
            name = self._param_name(i)
            A = getattr(self, f"{name}_constraint").transform(getattr(self, name))
            S = A - A.transpose(-1, -2)  # skew-symmetric
            Q = torch.linalg.matrix_exp(S)  # orthogonal
            idx = torch.tensor(g, device=Q.device)
            R[..., idx[:, None], idx[None, :]] = Q
        return R

    def _transform(self, X: torch.Tensor) -> torch.Tensor:
        """Rotate features: ``X`` is ``(..., n, d)`` and is right-multiplied by ``R^T``."""
        # Cast to X's dtype/device so float inputs do not hit a dtype mismatch in matmul
        R = self._rotation().to(X)
        return torch.matmul(X, R.transpose(-1, -2))

    def forward(self, x1, x2, diag=False, **params):
        """Evaluate the base kernel on the rotated inputs."""
        return self.base(self._transform(x1), self._transform(x2), diag=diag, **params)


# And then you can use it like:
import gpytorch
from gpytorch.priors import NormalPrior
from gpytorch.means import ConstantMean
from botorch.models.gp_regression import SingleTaskGP
def make_latent_gp(train_X, train_Y, skew_dims=True, **kwargs):
    """Build a ``SingleTaskGP`` that uses a ``RotatedInputsKernel`` covariance.

    Args:
        train_X: Training inputs; the last dimension is taken as the input
            dimension ``d`` and everything before the last two dimensions as
            the batch shape.
        train_Y: Training targets; the last dimension is the number of
            outputs ``m``.
        skew_dims: Forwarded to ``RotatedInputsKernel``.
        **kwargs: Extra keyword arguments forwarded to ``SingleTaskGP``. An
            ``outcome_transform`` entry overrides the default of ``None``.

    Returns:
        The constructed ``SingleTaskGP``.
    """
    m = train_Y.shape[-1]
    # Multi-output models carry one extra batch dimension of size m
    aug_batch_shape = train_X.shape[:-2] + (torch.Size([m]) if m > 1 else torch.Size())
    mean = ConstantMean(batch_shape=aug_batch_shape, constant_prior=NormalPrior(0.0, 1.0))
    covar = RotatedInputsKernel(d=train_X.shape[-1], batch_shape=aug_batch_shape, skew_dims=skew_dims)
    # Default to no outcome transform, but let the caller supply one via kwargs
    # without triggering a "multiple values for keyword argument" TypeError.
    kwargs.setdefault("outcome_transform", None)
    return SingleTaskGP(
        train_X=train_X,
        train_Y=train_Y,
        mean_module=mean,
        covar_module=covar,
        **kwargs,
    )
# Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels