1818import torch
1919from botorch .acquisition .objective import PosteriorTransform
2020from botorch .exceptions .errors import UnsupportedError
21-
2221from botorch .logging import logger
2322from botorch .models .model import Model
2423from botorch .models .transforms .input import InputTransform
24+ from botorch .utils .transforms import match_batch_shape
2525from botorch_community .models .utils .prior_fitted_network import (
2626 download_model ,
2727 ModelPaths ,
2828)
2929from botorch_community .posteriors .riemann import BoundedRiemannPosterior
30+ from gpytorch .likelihoods .gaussian_likelihood import FixedNoiseGaussianLikelihood
3031from pfns .train import MainConfig # @manual=//pytorch/PFNs:PFNs
3132from torch import Tensor
3233from torch .nn import Module
@@ -58,7 +59,7 @@ def __init__(
5859
5960 Args:
6061 train_X: A `n x d` tensor of training features.
61- train_Y: A `n x m ` tensor of training observations.
62+ train_Y: A `n x 1 ` tensor of training observations.
6263 model: A pre-trained PFN model with the following
6364 forward(train_X, train_Y, X) -> logit predictions of shape
6465 `n x b x c` where c is the number of discrete buckets
@@ -95,40 +96,35 @@ def __init__(
9596 if train_Yvar is not None :
9697 logger .debug ("train_Yvar provided but ignored for PFNModel." )
9798
98- if not ( 1 <= train_Y .dim () <= 3 ) :
99- raise UnsupportedError ("train_Y must be 1- to 3 -dimensional." )
99+ if train_Y .dim () != 2 :
100+ raise UnsupportedError ("train_Y must be 2 -dimensional." )
100101
101- if not ( 2 <= train_X .dim () <= 3 ) :
102- raise UnsupportedError ("train_X must be 2- to 3- dimensional." )
102+ if train_X .dim () != 2 :
103+ raise UnsupportedError ("train_X must be 2-dimensional." )
103104
104- if train_Y .dim () == train_X .dim ():
105- if train_Y .shape [- 1 ] > 1 :
106- raise UnsupportedError ("Only 1 target allowed for PFNModel." )
107- train_Y = train_Y .squeeze (- 1 )
105+ if train_Y .shape [- 1 ] > 1 :
106+ raise UnsupportedError ("Only 1 target allowed for PFNModel." )
108107
109- if (len (train_X .shape ) != len (train_Y .shape ) + 1 ) or (
110- train_Y .shape != train_X .shape [:- 1 ]
111- ):
108+ if train_X .shape [0 ] != train_Y .shape [0 ]:
112109 raise UnsupportedError (
113- "train_X and train_Y must have the same shape except "
114- "for the last dimension."
110+ "train_X and train_Y must have the same number of rows."
115111 )
116112
117- if len (train_X .shape ) == 2 :
118- # adding batch dimension
119- train_X = train_X .unsqueeze (0 )
120- train_Y = train_Y .unsqueeze (0 )
121-
122113 with torch .no_grad ():
123114 self .transformed_X = self .transform_inputs (
124115 X = train_X , input_transform = input_transform
125116 )
126117
127- self .train_X = train_X # shape: `b x n x d`
128- self .train_Y = train_Y # shape: `b x n`
129- self .pfn = model .to (train_X .device )
118+ self .train_X = train_X # shape: (n, d)
119+ self .train_Y = train_Y # shape: (n, 1)
120+ # Downstream botorch tooling expects a likelihood to be specified,
121+ # so here we use a FixedNoiseGaussianLikelihood that is unused.
122+ if train_Yvar is None :
123+ train_Yvar = torch .zeros_like (train_Y )
124+ self .likelihood = FixedNoiseGaussianLikelihood (noise = train_Yvar )
125+ self .pfn = model .to (device = train_X .device )
130126 self .batch_first = batch_first
131- self .constant_model_kwargs = constant_model_kwargs
127+ self .constant_model_kwargs = constant_model_kwargs or {}
132128 if input_transform is not None :
133129 self .input_transform = input_transform
134130
@@ -146,23 +142,19 @@ def posterior(
146142 any `model.forward` or `model.likelihood` calls.
147143
148144 Args:
149- X: A `b'? x b? x q x d`-dim Tensor, where `d` is the dimension of the
150- feature space, `q` is the number of points considered jointly,
151- and `b` is the batch dimension.
152- We only allow `q=1` for PFNModel, so q can also be omitted, i.e.
153- `b x d`-dim Tensor.
154- **Currently not supported for PFNModel**.
145+ X: A `b? x q? x d`-dim Tensor, where `d` is the dimension of the
146+ feature space.
155147 output_indices: **Currently not supported for PFNModel.**
156148 observation_noise: **Currently not supported for PFNModel**.
157149 posterior_transform: **Currently not supported for PFNModel**.
158150
159151 Returns:
160- A `BoundedRiemannPosterior` object , representing a batch of `b` joint
161- distributions over `q` points and `m` outputs each .
152+ A `BoundedRiemannPosterior`, representing a batch of `b? x q?`
153+ distributions.
162154 """
163155 self .pfn .eval ()
164156 if output_indices is not None :
165- raise RuntimeError (
157+ raise UnsupportedError (
166158 "output_indices is not None. PFNModel should not "
167159 "be a multi-output model."
168160 )
@@ -173,60 +165,54 @@ def posterior(
173165 if posterior_transform is not None :
174166 raise UnsupportedError ("posterior_transform is not supported for PFNModel." )
175167
176- if not (1 <= len (X .shape ) <= 4 ):
177- raise UnsupportedError ("X must be 1- to 4-dimensional." )
178-
179- # X has shape b'? x b? x q? x d
180-
181- orig_X_shape = X .shape
182- q_in_orig_X_shape = len (X .shape ) > 2
183-
184- if len (X .shape ) == 1 :
185- X = X .unsqueeze (0 ).unsqueeze (0 ).unsqueeze (0 ) # shape `b'=1 x b=1 x q=1 x d`
186- elif len (X .shape ) == 2 :
187- X = X .unsqueeze (1 ).unsqueeze (1 ) # shape `b' x b=1 x q=1 x d`
188- elif len (X .shape ) == 3 :
189- if self .train_X .shape [0 ] == 1 :
190- X = X .unsqueeze (1 ) # shape `b' x b=1 x q x d`
191- else :
192- X = X .unsqueeze (0 ) # shape `b'=1 x b x q x d`
193-
194- # X has shape `b' x b x q x d`
195-
196- if X .shape [2 ] != 1 :
197- raise UnsupportedError ("Only q=1 is supported for PFNModel." )
198-
199- # X has shape `b' x b x q=1 x d`
200- X = self .transform_inputs (X )
201- train_X = self .transformed_X # shape `b x n x d`
202- train_Y = self .train_Y # shape `b x n`
203- folded_X = X .transpose (0 , 2 ).squeeze (0 ) # shape `b x b' x d
204-
205- constant_model_kwargs = self .constant_model_kwargs or {}
206-
207- if self .batch_first :
208- logits = self .pfn (
209- train_X .float (),
210- train_X .float (),
211- folded_X .float (),
212- ** constant_model_kwargs ,
213- ).transpose (0 , 1 )
214- else :
215- logits = self .pfn (
216- train_X .float ().transpose (0 , 1 ),
217- train_Y .float ().transpose (0 , 1 ),
218- folded_X .float ().transpose (0 , 1 ),
219- ** constant_model_kwargs ,
220- )
221-
222- # logits shape `b' x b x logits_dim`
168+ orig_X_shape = X .shape # X has shape b? x q? x d
169+ X = self .prepare_X (X ) # shape (b, q, d)
170+ train_X = match_batch_shape (self .transformed_X , X ) # shape (b, n, d)
171+ train_Y = match_batch_shape (self .train_Y , X ) # shape (b, n, 1)
223172
224- logits = logits .view (
173+ probabilities = self .pfn_predict (
174+ X = X , train_X = train_X , train_Y = train_Y
175+ ) # (b, q, num_buckets)
176+ probabilities = probabilities .view (
225177 * orig_X_shape [:- 1 ], - 1
226- ) # orig shape w/o q but logits_dim at end: `b'? x b? x q? x logits_dim`
227- if q_in_orig_X_shape :
228- logits = logits .squeeze (- 2 ) # shape `b'? x b? x logits_dim`
178+ ) # (b?, q?, num_buckets)
229179
230- probabilities = logits .softmax (dim = - 1 )
180+ # Get posterior with the right dtype
181+ borders = self .pfn .criterion .borders .to (X .dtype )
182+ return BoundedRiemannPosterior (
183+ borders = borders ,
184+ probabilities = probabilities ,
185+ )
231186
232- return BoundedRiemannPosterior (self .pfn .criterion .borders , probabilities )
187+ def prepare_X (self , X : Tensor ) -> Tensor :
188+ if len (X .shape ) > 3 :
189+ raise UnsupportedError (f"X must be at most 3-d, got { X .shape } ." )
190+ while len (X .shape ) < 3 :
191+ X = X .unsqueeze (0 )
192+
193+ X = self .transform_inputs (X ) # shape (b , q, d)
194+ return X
195+
196+ def pfn_predict (self , X : Tensor , train_X : Tensor , train_Y : Tensor ) -> Tensor :
197+ """
198+ X has shape (b, q, d)
199+ train_X has shape (b, n, d)
200+ train_Y has shape (b, n, 1)
201+ """
202+ if not self .batch_first :
203+ X = X .transpose (0 , 1 ) # shape (q, b, d)
204+ train_X = train_X .transpose (0 , 1 ) # shape (n, b, d)
205+ train_Y = train_Y .transpose (0 , 1 ) # shape (n, b, 1)
206+
207+ logits = self .pfn (
208+ train_X .float (),
209+ train_Y .float (),
210+ X .float (),
211+ ** self .constant_model_kwargs ,
212+ )
213+ if not self .batch_first :
214+ logits = logits .transpose (0 , 1 ) # shape (b, q, num_buckets)
215+ logits = logits .to (X .dtype )
216+
217+ probabilities = logits .softmax (dim = - 1 ) # shape (b, q, num_buckets)
218+ return probabilities
0 commit comments