@@ -55,6 +55,7 @@ def __init__(
         self,
         model: Model,
         posterior_transform: PosteriorTransform | None = None,
+        allow_multi_output: bool = False,
     ) -> None:
         r"""Base constructor for analytic acquisition functions.
 
@@ -63,10 +64,12 @@ def __init__(
             posterior_transform: A PosteriorTransform. If using a multi-output model,
                 a PosteriorTransform that transforms the multi-output posterior into a
                 single-output posterior is required.
+            allow_multi_output: If False, a `posterior_transform` is required
+                when using a multi-output model.
         """
         super().__init__(model=model)
         if posterior_transform is None:
-            if model.num_outputs != 1:
+            if not allow_multi_output and model.num_outputs != 1:
                 raise UnsupportedError(
                     "Must specify a posterior transform when using a "
                     "multi-output model."
@@ -89,21 +92,21 @@ def _mean_and_sigma(
8992 """Computes the first and second moments of the model posterior.
9093
9194 Args:
92- X: `batch_shape x q x d`-dim Tensor of model inputs .
95+ X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points .
9396 compute_sigma: Boolean indicating whether or not to compute the second
9497 moment (default: True).
9598 min_var: The minimum value the variance is clamped too. Should be positive.
9699
97100 Returns:
98- A tuple of tensors containing the first and second moments of the model
99- posterior. Removes the last two dimensions if they have size one. Only
100- returns a single tensor of means if compute_sigma is True .
101+ A tuple of tensors of shape `(b1 x ... x bk) x m` containing the first and
102+ second moments of the model posterior, where `m` is the number of outputs.
103+ Returns `None` instead of the second tensor if ` compute_sigma` is False .
101104 """
102105 self .to (X ) # ensures buffers / parameters are on the same device and dtype
103106 posterior = self .model .posterior (
104107 X = X , posterior_transform = self .posterior_transform
105108 )
106- mean = posterior .mean .squeeze (- 2 ). squeeze ( - 1 ) # removing redundant dimensions
109+ mean = posterior .mean .squeeze (- 2 ) # remove q-batch dimension
107110 if not compute_sigma :
108111 return mean , None
109112 sigma = posterior .variance .clamp_min (min_var ).sqrt ().view (mean .shape )
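Since only the q-batch dimension is squeezed now, `_mean_and_sigma` preserves the output dimension `m`. A quick shape check with illustrative data (calling the private helper directly just to demonstrate the documented shapes):

```python
import torch

from botorch.acquisition.analytic import PosteriorMean
from botorch.models import SingleTaskGP

train_X = torch.rand(8, 3, dtype=torch.double)
train_Y = torch.rand(8, 1, dtype=torch.double)
acqf = PosteriorMean(SingleTaskGP(train_X, train_Y))

X = torch.rand(5, 1, 3, dtype=torch.double)  # (b1=5) x q=1 x d=3
mean, sigma = acqf._mean_and_sigma(X)
print(mean.shape, sigma.shape)  # torch.Size([5, 1]) torch.Size([5, 1]) -- m = 1 kept
```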
@@ -168,9 +171,9 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b1 x ... bk)`-dim tensor of Log Probability of Improvement values at
             the given design points `X`.
         """
-        mean, sigma = self._mean_and_sigma(X)
+        mean, sigma = self._mean_and_sigma(X)  # `(b1 x ... bk) x 1`
         u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
-        return log_Phi(u)
+        return log_Phi(u.squeeze(-1))
 
 
 class ProbabilityOfImprovement(AnalyticAcquisitionFunction):
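Because the moments now carry a trailing output dimension, each analytic `forward` squeezes it so callers still receive a `(b1 x ... bk)`-dim tensor; the same pattern recurs in the `ProbabilityOfImprovement`, `ExpectedImprovement`, and `LogExpectedImprovement` hunks below. A sanity check with illustrative data:

```python
import torch

from botorch.acquisition.analytic import LogProbabilityOfImprovement
from botorch.models import SingleTaskGP

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.rand(8, 1, dtype=torch.double)
log_pi = LogProbabilityOfImprovement(
    model=SingleTaskGP(train_X, train_Y), best_f=train_Y.max()
)

X = torch.rand(4, 1, 2, dtype=torch.double)  # (b1=4) x q=1 x d=2
print(log_pi(X).shape)  # torch.Size([4]) -- trailing output dim squeezed away
```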
@@ -223,9 +226,9 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b1 x ... bk)`-dim tensor of Probability of Improvement values at the
             given design points `X`.
         """
-        mean, sigma = self._mean_and_sigma(X)
+        mean, sigma = self._mean_and_sigma(X)  # `(b1 x ... bk) x 1`
         u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
-        return Phi(u)
+        return Phi(u.squeeze(-1))
 
 
 class qAnalyticProbabilityOfImprovement(AnalyticAcquisitionFunction):
@@ -354,9 +357,9 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b1 x ... bk)`-dim tensor of Expected Improvement values at the
             given design points `X`.
         """
-        mean, sigma = self._mean_and_sigma(X)
+        mean, sigma = self._mean_and_sigma(X)  # `(b1 x ... bk) x 1`
         u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
-        return sigma * _ei_helper(u)
+        return (sigma * _ei_helper(u)).squeeze(-1)
 
 
 class LogExpectedImprovement(AnalyticAcquisitionFunction):
@@ -418,9 +421,9 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b1 x ... bk)`-dim tensor of the logarithm of the Expected Improvement
             values at the given design points `X`.
         """
-        mean, sigma = self._mean_and_sigma(X)
+        mean, sigma = self._mean_and_sigma(X)  # `(b1 x ... bk) x 1`
         u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
-        return _log_ei_helper(u) + sigma.log()
+        return (_log_ei_helper(u) + sigma.log()).squeeze(-1)
 
 
 class ConstrainedAnalyticAcquisitionFunctionMixin(ABC):
@@ -433,7 +436,7 @@ def __init__(
433436 r"""Analytic Log Probability of Feasibility.
434437
435438 Args:
436- model: A fitted multi-output model.
439+ model: A fitted single- or multi-output model.
437440 constraints: A dictionary of the form `{i: [lower, upper]}`, where
438441 `i` is the output index, and `lower` and `upper` are lower and upper
439442 bounds on that output (resp. interpreted as -Inf / Inf if None).
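For reference, the `constraints` dictionary maps output indices to `[lower, upper]` bounds, with `None` standing in for an unbounded side. A hypothetical specification under this convention:

```python
# Hypothetical constraints for a model with at least three outputs:
constraints = {
    1: [0.0, None],  # output 1 must be >= 0.0; upper side treated as +Inf
    2: [-1.0, 1.0],  # output 2 must lie in [-1.0, 1.0]
}
```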
@@ -501,13 +504,11 @@ def _compute_log_prob_feas(
501504 r"""Compute logarithm of the feasibility probability for each batch of X.
502505
503506 Args:
504- X: A `(b) x 1 x d`-dim Tensor of `(b)` t-batches of `d`-dim design
505- points each.
506507 means: A `(b) x m`-dim Tensor of means.
507508 sigmas: A `(b) x m`-dim Tensor of standard deviations.
508509
509510 Returns:
510- A `b `-dim tensor of log feasibility probabilities
511+ A `(b) `-dim tensor of log feasibility probabilities
511512
512513 Note: This function does case-work for upper bound, lower bound, and both-sided
513514 bounds. Another way to do it would be to use 'inf' and -'inf' for the
@@ -567,7 +568,7 @@ def __init__(
567568 r"""Analytic Log Constrained Expected Improvement.
568569
569570 Args:
570- model: A fitted multi-output model.
571+ model: A fitted single- or multi-output model.
571572 best_f: Either a scalar or a `b`-dim Tensor (batch mode) representing
572573 the best feasible function value observed so far (assumed noiseless).
573574 objective_index: The index of the objective.
@@ -576,8 +577,7 @@ def __init__(
                 bounds on that output (resp. interpreted as -Inf / Inf if None)
             maximize: If True, consider the problem a maximization problem.
         """
-        # Use AcquisitionFunction constructor to avoid check for posterior transform.
-        AcquisitionFunction.__init__(self, model=model)
+        super().__init__(model=model, allow_multi_output=True)
         self.posterior_transform = None
         self.maximize = maximize
         self.objective_index = objective_index
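Routing through the base constructor with `allow_multi_output=True` keeps the posterior-transform check disabled without bypassing `AnalyticAcquisitionFunction.__init__`. A usage sketch with illustrative data, assuming a two-output model with the objective at index 0:

```python
import torch

from botorch.acquisition.analytic import LogConstrainedExpectedImprovement
from botorch.models import SingleTaskGP

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.rand(10, 2, dtype=torch.double)  # [objective, constraint]
model = SingleTaskGP(train_X, train_Y)

log_cei = LogConstrainedExpectedImprovement(
    model=model,
    best_f=train_Y[:, 0].max(),
    objective_index=0,
    constraints={1: [None, 0.5]},  # output 1 must be <= 0.5
)
print(log_cei(torch.rand(4, 1, 2, dtype=torch.double)).shape)  # torch.Size([4])
```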
@@ -641,13 +641,12 @@ def __init__(
641641 r"""Analytic Log Probability of Feasibility.
642642
643643 Args:
644- model: A fitted multi-output model.
644+ model: A fitted single- or multi-output model.
645645 constraints: A dictionary of the form `{i: [lower, upper]}`, where
646646 `i` is the output index, and `lower` and `upper` are lower and upper
647647 bounds on that output (resp. interpreted as -Inf / Inf if None)
648648 """
649- # Use AcquisitionFunction constructor to avoid check for posterior transform.
650- AcquisitionFunction .__init__ (self , model = model )
649+ super ().__init__ (model = model , allow_multi_output = True )
651650 self .posterior_transform = None
652651 ConstrainedAnalyticAcquisitionFunctionMixin .__init__ (self , constraints )
653652
@@ -708,7 +707,7 @@ def __init__(
708707 r"""Analytic Constrained Expected Improvement.
709708
710709 Args:
711- model: A fitted multi-output model.
710+ model: A fitted single- or multi-output model.
712711 best_f: Either a scalar or a `b`-dim Tensor (batch mode) representing
713712 the best feasible function value observed so far (assumed noiseless).
714713 objective_index: The index of the objective.
@@ -718,8 +717,7 @@ def __init__(
             maximize: If True, consider the problem a maximization problem.
         """
         legacy_ei_numerics_warning(legacy_name=type(self).__name__)
-        # Use AcquisitionFunction constructor to avoid check for posterior transform.
-        AcquisitionFunction.__init__(self, model=model)
+        super().__init__(model=model, allow_multi_output=True)
         self.posterior_transform = None
         self.maximize = maximize
         self.objective_index = objective_index
@@ -828,7 +826,9 @@ def forward(self, X: Tensor) -> Tensor:
             the given design points `X`.
         """
         # add batch dimension for broadcasting to fantasy models
+        # (b1 x ... x bk) x num_fantasies x 1
         mean, sigma = self._mean_and_sigma(X.unsqueeze(-3))
+        mean, sigma = mean.squeeze(-1), sigma.squeeze(-1)
         u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
         log_ei = _log_ei_helper(u) + sigma.log()
         # this is mathematically - though not numerically - equivalent to log(mean(ei))
@@ -906,14 +906,16 @@ def forward(self, X: Tensor) -> Tensor:
906906 r"""Evaluate Expected Improvement on the candidate set X.
907907
908908 Args:
909- X: A `b1 x ... bk x 1 x d`-dim batched tensor of `d`-dim design points.
909+ X: A `( b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.
910910
911911 Returns:
912- A `b1 x ... bk`-dim tensor of Noisy Expected Improvement values at
912+ A `( b1 x ... bk) `-dim tensor of Noisy Expected Improvement values at
913913 the given design points `X`.
914914 """
915915 # add batch dimension for broadcasting to fantasy models
916- mean , sigma = self ._mean_and_sigma (X .unsqueeze (- 3 ))
916+ # (b1 x ... x bk) x num_fantasies x 1
917+ mean , sigma = self ._mean_and_sigma (X .unsqueeze (- 3 )) # (b1 x ... x bk) x m1 x 1
918+ mean , sigma = mean .squeeze (- 1 ), sigma .squeeze (- 1 )
917919 u = _scaled_improvement (mean , sigma , self .best_f , self .maximize )
918920 return (sigma * _ei_helper (u )).mean (dim = - 1 )
919921
@@ -970,8 +972,9 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b1 x ... bk)`-dim tensor of Upper Confidence Bound values at the
             given design points `X`.
         """
-        mean, sigma = self._mean_and_sigma(X)
-        return (mean if self.maximize else -mean) + self.beta.sqrt() * sigma
+        mean, sigma = self._mean_and_sigma(X)  # (b1 x ... x bk) x 1
+        ucb = (mean if self.maximize else -mean) + self.beta.sqrt() * sigma
+        return ucb.squeeze(-1)
 
 
 class PosteriorMean(AnalyticAcquisitionFunction):
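The UCB change above is purely a shape fix: the value is still `mean + sqrt(beta) * sigma` (with the mean sign-flipped for minimization), now with the trailing output dimension squeezed at the end. A numerical spot check with illustrative data, assuming `maximize=True`:

```python
import torch

from botorch.acquisition.analytic import UpperConfidenceBound
from botorch.models import SingleTaskGP

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.rand(8, 1, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)
ucb = UpperConfidenceBound(model=model, beta=2.0)

X = torch.rand(3, 1, 2, dtype=torch.double)
posterior = model.posterior(X)
expected = (posterior.mean + 2.0**0.5 * posterior.variance.sqrt()).view(3)
print(torch.allclose(ucb(X), expected))  # True (up to the min-variance clamp)
```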
@@ -1020,8 +1023,10 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b1 x ... bk)`-dim tensor of Posterior Mean values at the
             given design points `X`.
         """
-        mean, _ = self._mean_and_sigma(X, compute_sigma=False)
-        return mean if self.maximize else -mean
+        mean, _ = self._mean_and_sigma(X, compute_sigma=False)  # (b1 x ... x bk) x 1
+        if not self.maximize:
+            mean = -mean
+        return mean.squeeze(-1)
 
 
 class ScalarizedPosteriorMean(AnalyticAcquisitionFunction):
@@ -1056,14 +1061,16 @@ def forward(self, X: Tensor) -> Tensor:
10561061 r"""Evaluate the scalarized posterior mean on the candidate set X.
10571062
10581063 Args:
1059- X: A `(b ) x q x d`-dim Tensor of `(b)` t-batches of `d`-dim design
1060- points each.
1064+ X: A `(b1 x ... x bk ) x q x d`-dim Tensor of `(b1 x ... x bk)`
1065+ t-batches of `d`-dim design points each.
10611066
10621067 Returns:
1063- A `(b )`-dim Tensor of Posterior Mean values at the given design
1064- points `X`.
1068+ A `(b1 x ... x bk )`-dim Tensor of Posterior Mean values at the given
1069+ design points `X`.
10651070 """
1066- return self ._mean_and_sigma (X , compute_sigma = False )[0 ] @ self .weights
1071+ # (b1 x ... x bk) x q x 1
1072+ mean , _ = self ._mean_and_sigma (X , compute_sigma = False )
1073+ return mean .squeeze (- 1 ) @ self .weights
10671074
10681075
10691076class PosteriorStandardDeviation (AnalyticAcquisitionFunction ):
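In `ScalarizedPosteriorMean`, the weights contract against the q-batch dimension, so after the squeeze the matmul maps `(b1 x ... x bk) x q` means to a `(b1 x ... x bk)`-dim output. A sketch with illustrative data and `q = 3` equal weights:

```python
import torch

from botorch.acquisition.analytic import ScalarizedPosteriorMean
from botorch.models import SingleTaskGP

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.rand(8, 1, dtype=torch.double)
acqf = ScalarizedPosteriorMean(
    model=SingleTaskGP(train_X, train_Y),
    weights=torch.ones(3, dtype=torch.double),  # one weight per q-batch point
)

X = torch.rand(5, 3, 2, dtype=torch.double)  # (b1=5) x q=3 x d=2
print(acqf(X).shape)  # torch.Size([5])
```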
@@ -1131,7 +1138,9 @@ def forward(self, X: Tensor) -> Tensor:
             given design points `X`.
         """
         _, std = self._mean_and_sigma(X)
-        return std if self.maximize else -std
+        if not self.maximize:
+            std = -std
+        return std.view(X.shape[:-2])
 
 
 # --------------- Helper functions for analytic acquisition functions. ---------------