Skip to content

Simplify LatentKernel #263

@thopkins32

Description

@thopkins32

@thomaswmorris

After doing some research, it looks like we can re-use a bit of what is already available through BoTorch to simplify a lot of the math in the LatentKernel implementation. I'm writing unit tests for this module in #262 so that we can refactor it safely in this PR.

Something like the following should work:

import math
import torch
import gpytorch
from gpytorch.kernels import Kernel
from gpytorch.constraints import Interval
from botorch.models.utils.gpytorch_modules import get_matern_kernel_with_gamma_prior


class RotatedInputsKernel(Kernel):
    """Applies a learned orthogonal rotation (blockwise if desired), then delegates to a base kernel.

    Each group of input dimensions is rotated by an orthogonal matrix
    parameterized as ``expm(A - A^T)`` of a learned raw matrix ``A`` (the
    matrix exponential of a skew-symmetric matrix is orthogonal); the rotated
    inputs are then passed to a BoTorch-default ScaleKernel(MaternKernel).
    """
    is_stationary = True
    has_lengthscale = False  # lengthscale lives in the base kernel

    def __init__(
        self,
        d: int,
        batch_shape: torch.Size = torch.Size(),
        skew_dims: bool | list[tuple[int, ...]] = True,
    ):
        """
        Args:
            d: Number of input dimensions.
            batch_shape: Batch shape (e.g. for batched multi-output models).
            skew_dims: ``True`` rotates all dims together, ``False`` disables
                rotation entirely, or pass an explicit list of dim-index groups.
        """
        super().__init__(batch_shape=batch_shape)

        self.d = d
        # Choose groups of dims that can rotate among themselves
        if isinstance(skew_dims, bool):
            self.groups = [tuple(range(d))] if skew_dims else []
        else:
            self.groups = [tuple(g) for g in skew_dims]

        # Base kernel = ScaleKernel(MaternKernel(...)) with the usual Gamma priors
        self.base = get_matern_kernel_with_gamma_prior(ard_num_dims=d, batch_shape=batch_shape)

        # One unconstrained matrix per group; we'll skew-symmetrize then expm -> orthogonal.
        # BUG FIX: gpytorch's ``register_constraint`` requires the parameter to
        # be registered on this module under a plain (dot-free) name — a
        # constraint cannot be attached to a ParameterList entry via
        # "raw_group_mats.{i}".  Register each raw matrix individually so the
        # constraint (and its ``..._constraint`` attribute) resolve correctly.
        for i, g in enumerate(self.groups):
            k = len(g)
            raw = torch.zeros(*batch_shape, k, k, dtype=torch.double)
            self.register_parameter(f"raw_group_mat_{i}", torch.nn.Parameter(raw))
            # Optional: mimic the original [-2*pi, 2*pi] bounding (can be omitted)
            self.register_constraint(f"raw_group_mat_{i}", Interval(-2 * math.pi, 2 * math.pi))

    def _rotation(self) -> torch.Tensor:
        """Assemble the full (possibly batched) d x d rotation from per-group blocks."""
        device = self.raw_group_mat_0.device if self.groups else None
        R = torch.eye(self.d, dtype=torch.double, device=device)
        if self.batch_shape:
            # clone() because expand() returns a non-writable view
            R = R.expand(*self.batch_shape, self.d, self.d).clone()

        for i, g in enumerate(self.groups):
            raw = getattr(self, f"raw_group_mat_{i}")
            constraint = getattr(self, f"raw_group_mat_{i}_constraint")
            A = constraint.transform(raw)
            S = A - A.transpose(-1, -2)      # skew-symmetric part
            Q = torch.linalg.matrix_exp(S)   # expm of skew-symmetric -> orthogonal
            idx = torch.tensor(g, device=Q.device)
            # Scatter the block rotation into this group's rows/columns.
            R[..., idx[:, None], idx[None, :]] = Q
        return R

    def _transform(self, X: torch.Tensor) -> torch.Tensor:
        # X is (..., n, d). Right-multiply by R^T to rotate features.
        R = self._rotation()
        return torch.matmul(X, R.transpose(-1, -2))

    def forward(self, x1, x2, diag=False, **params):
        """Evaluate the base kernel on rotated copies of both inputs."""
        return self.base(self._transform(x1), self._transform(x2), diag=diag, **params)

And then you can use it like:

import gpytorch
from gpytorch.priors import NormalPrior
from gpytorch.means import ConstantMean
from botorch.models.gp_regression import SingleTaskGP

def make_latent_gp(train_X, train_Y, skew_dims=True, **kwargs):
    """Construct a SingleTaskGP whose covariance rotates inputs before a Matern kernel."""
    num_outputs = train_Y.shape[-1]
    # Multi-output models get the output dimension appended to the batch shape.
    aug_batch_shape = train_X.shape[:-2]
    if num_outputs > 1:
        aug_batch_shape = aug_batch_shape + torch.Size([num_outputs])

    mean_module = ConstantMean(
        batch_shape=aug_batch_shape,
        constant_prior=NormalPrior(0.0, 1.0),
    )
    covar_module = RotatedInputsKernel(
        d=train_X.shape[-1],
        batch_shape=aug_batch_shape,
        skew_dims=skew_dims,
    )

    return SingleTaskGP(
        train_X=train_X,
        train_Y=train_Y,
        mean_module=mean_module,
        covar_module=covar_module,
        outcome_transform=None,  # keep if you truly need it
        **kwargs,
    )

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions