 from botorch.utils.transforms import normalize, unnormalize
 from scipy.spatial import Delaunay, HalfspaceIntersection
 from torch import LongTensor, Tensor
-<<<<<<< HEAD
 from torch.distributions import Normal
-=======
-from torch.distributions import Multinomial, Normal
->>>>>>> Boltzmann sampling function added in utils/sampling to remove duplicate
 from torch.quasirandom import SobolEngine


@@ -1003,10 +999,12 @@ def sparse_to_dense_constraints(
 def optimize_posterior_samples(
     paths: GenericDeterministicModel,
     bounds: Tensor,
-    raw_samples: int = 1024,
-    num_restarts: int = 20,
+    raw_samples: int = 2048,
+    num_restarts: int = 4,
     sample_transform: Callable[[Tensor], Tensor] | None = None,
     return_transformed: bool = False,
+    suggested_points: Tensor | None = None,
+    options: dict | None = None,
 ) -> tuple[Tensor, Tensor]:
     r"""Cheaply maximizes posterior samples by random querying followed by
     gradient-based optimization using SciPy's L-BFGS-B routine.
@@ -1015,19 +1013,27 @@ def optimize_posterior_samples(
         paths: Random Fourier Feature-based sample paths from the GP
         bounds: The bounds on the search space.
         raw_samples: The number of samples with which to query the samples initially.
+            Raw samples are cheap to evaluate, so this should ideally be set much
+            higher than num_restarts.
         num_restarts: The number of points selected for gradient-based optimization.
+            Should be set low relative to the number of raw samples.
         sample_transform: A callable transform of the sample outputs (e.g.
             MCAcquisitionObjective or ScalarizedPosteriorTransform.evaluate) used to
             negate the objective or otherwise transform the output.
         return_transformed: A boolean indicating whether to return the transformed
             or non-transformed samples.
+        suggested_points: Tensor of suggested input locations that are high-valued.
+            These are more densely evaluated during the sampling phase of optimization.
+        options: Options for generation of initial candidates, passed to
+            gen_batch_initial_conditions.

     Returns:
         A two-element tuple containing:
         - X_opt: A `num_optima x [batch_size] x d`-dim tensor of optimal inputs x*.
         - f_opt: A `num_optima x [batch_size] x m`-dim, optionally
             `num_optima x [batch_size] x 1`-dim, tensor of optimal outputs f*.
     """
+    options = {} if options is None else options

     def path_func(x) -> Tensor:
         res = paths(x)
@@ -1036,21 +1042,35 @@ def path_func(x) -> Tensor:

         return res.squeeze(-1)

-    candidate_set = unnormalize(
-        SobolEngine(dimension=bounds.shape[1], scramble=True).draw(n=raw_samples),
-        bounds=bounds,
-    )
     # queries all samples on all candidates - output shape
     # raw_samples * num_optima * num_models
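+    # candidates mix quasi-random Sobol points with local perturbations of any
+    # suggested_points; frac_random controls the share drawn quasi-randomly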
+    frac_random = 1 if suggested_points is None else options.get("frac_random", 0.9)
+    candidate_set = draw_sobol_samples(
+        bounds=bounds, n=round(raw_samples * frac_random), q=1
+    ).squeeze(-2)
+    if frac_random < 1:
+        perturbed_suggestions = sample_truncated_normal_perturbations(
+            X=suggested_points,
+            n_discrete_points=round(raw_samples * (1 - frac_random)),
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+        )
+        candidate_set = torch.cat((candidate_set, perturbed_suggestions))
+
     candidate_queries = path_func(candidate_set)
-    argtop_k = torch.topk(candidate_queries, num_restarts, dim=-1).indices
-    X_top_k = candidate_set[argtop_k, :]
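+    # restart points are drawn by Boltzmann (softmax-weighted) sampling of the
+    # candidate values rather than a plain top-k selection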
+    idx = boltzmann_sample(
+        function_values=candidate_queries,
+        num_samples=num_restarts,
+        eta=options.get("eta", 5.0),
+        replacement=False,
+    )
+    ics = candidate_set[idx, :]

     # to avoid circular import, the import occurs here
     from botorch.generation.gen import gen_candidates_scipy

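+    # the sampled initial conditions are then refined with gradient-based
+    # optimization (SciPy's L-BFGS-B)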
     X_top_k, f_top_k = gen_candidates_scipy(
-        X_top_k,
+        ics,
         path_func,
         lower_bounds=bounds[0],
         upper_bounds=bounds[1],
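Note: the selection rule introduced above can be illustrated in isolation with a
minimal standalone sketch. This is an illustrative re-implementation, not
BoTorch's boltzmann_sample helper; the standardization of the values, the eta
default, and the name boltzmann_sample_sketch are assumptions made for the
example. Compared with a plain top-k selection, softmax-weighted sampling
without replacement still favors high-valued candidates while adding diversity
to the restart set.

import torch


def boltzmann_sample_sketch(
    function_values: torch.Tensor, num_samples: int, eta: float = 5.0
) -> torch.Tensor:
    # Hypothetical stand-in for the boltzmann_sample helper used above.
    # Standardize the values so eta acts on a consistent scale, then draw
    # indices without replacement using softmax (Boltzmann) weights.
    vals = function_values - function_values.mean(dim=-1, keepdim=True)
    vals = vals / function_values.std(dim=-1, keepdim=True).clamp_min(1e-9)
    probs = torch.softmax(eta * vals, dim=-1)
    return torch.multinomial(probs, num_samples, replacement=False)


# Example: pick num_restarts=4 restart indices out of 2048 raw candidate values.
values = torch.randn(2048)
restart_idx = boltzmann_sample_sketch(values, num_samples=4)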