Skip to content

Commit dc987f5

Browse files
committed
Boltzmann sampling function added in utils/sampling to remove duplicate
code, reshuffling of other sampling methods (that don't take an acqf)
1 parent 9a7c517 commit dc987f5

File tree

4 files changed

+353
-261
lines changed

4 files changed

+353
-261
lines changed

botorch/optim/initializers.py

Lines changed: 24 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,16 @@
4444
from botorch.utils.multi_objective.pareto import is_non_dominated
4545
from botorch.utils.sampling import (
4646
batched_multinomial,
47+
boltzmann_sample,
4748
draw_sobol_samples,
4849
get_polytope_samples,
4950
manual_seed,
51+
sample_perturbed_subset_dims,
52+
sample_truncated_normal_perturbations,
5053
)
5154
from botorch.utils.transforms import normalize, standardize, unnormalize
5255
from torch import Tensor
53-
from torch.distributions import Normal
56+
from torch.distributions import Multinomial, Normal
5457
from torch.quasirandom import SobolEngine
5558

5659
TGenInitialConditions = Callable[
@@ -578,10 +581,12 @@ def gen_one_shot_kg_initial_conditions(
578581

579582
# sampling from the optimizers
580583
n_value = int((1 - frac_random) * (q_aug - q)) # number of non-random ICs
581-
eta = options.get("eta", 2.0)
582-
weights = torch.exp(eta * standardize(fantasy_vals))
583-
idx = torch.multinomial(weights, num_restarts * n_value, replacement=True)
584-
584+
idx = boltzmann_sample(
585+
function_values=fantasy_vals,
586+
num_samples=num_restarts * n_value,
587+
eta=options.get("eta", 2.0),
588+
replacement=True,
589+
)
585590
# set the respective initial conditions to the sampled optimizers
586591
ics[..., -n_value:, :] = fantasy_cands[idx, 0].view(num_restarts, n_value, -1)
587592
return ics
@@ -699,14 +704,14 @@ def gen_one_shot_hvkg_initial_conditions(
699704
sequential=False,
700705
)
701706
# sampling from the optimizers
702-
eta = options.get("eta", 2.0)
703707
if num_optim_restarts > 0:
704-
probs = torch.nn.functional.softmax(eta * standardize(fantasy_vals), dim=0)
705-
idx = torch.multinomial(
706-
probs,
707-
num_optim_restarts * acq_function.num_fantasies,
708+
idx = boltzmann_sample(
709+
function_values=fantasy_vals,
710+
num_samples=num_optim_restarts * acq_function.num_fantasies,
711+
eta=options.get("eta", 2.0),
708712
replacement=True,
709713
)
714+
710715
optim_ics = fantasy_cands[idx]
711716
if is_mf_hvkg:
712717
# add fixed features
@@ -885,11 +890,10 @@ def gen_value_function_initial_conditions(
885890
# sampling from the optimizers
886891
n_value = int((1 - frac_random) * raw_samples) # number of non-random ICs
887892
if n_value > 0:
888-
eta = options.get("eta", 2.0)
889-
weights = torch.exp(eta * standardize(fantasy_vals))
890-
idx = batched_multinomial(
891-
weights=weights.expand(*batch_shape, -1),
893+
idx = boltzmann_sample(
894+
function_values=fantasy_vals.expand(*batch_shape, -1),
892895
num_samples=n_value,
896+
eta=options.get("eta", 2.0),
893897
replacement=True,
894898
).permute(-1, *range(len(batch_shape)))
895899
resampled = fantasy_cands[idx]
@@ -979,18 +983,12 @@ def initialize_q_batch(
979983
return X[idcs], acq_vals[idcs]
980984

981985
max_val, max_idx = torch.max(acq_vals, dim=0)
982-
Z = (acq_vals - acq_vals.mean(dim=0)) / Ystd
983-
etaZ = eta * Z
984-
weights = torch.exp(etaZ)
985-
while torch.isinf(weights).any():
986-
etaZ *= 0.5
987-
weights = torch.exp(etaZ)
988-
if batch_shape == torch.Size():
989-
idcs = torch.multinomial(weights, n)
990-
else:
991-
idcs = batched_multinomial(
992-
weights=weights.permute(*range(1, len(batch_shape) + 1), 0), num_samples=n
993-
).permute(-1, *range(len(batch_shape)))
986+
idcs = boltzmann_sample(
987+
acq_vals.permute(*range(1, len(batch_shape) + 1), 0),
988+
num_samples=n,
989+
eta=eta,
990+
).permute(-1, *range(len(batch_shape)))
991+
994992
# make sure we get the maximum
995993
if max_idx not in idcs:
996994
idcs[-1] = max_idx
@@ -1239,133 +1237,6 @@ def sample_points_around_best(
12391237
return perturbed_X
12401238

12411239

1242-
def sample_truncated_normal_perturbations(
1243-
X: Tensor,
1244-
n_discrete_points: int,
1245-
sigma: float,
1246-
bounds: Tensor,
1247-
qmc: bool = True,
1248-
) -> Tensor:
1249-
r"""Sample points around `X`.
1250-
1251-
Sample perturbed points around `X` such that the added perturbations
1252-
are sampled from N(0, sigma^2 I) and truncated to be within [0,1]^d.
1253-
1254-
Args:
1255-
X: A `n x d`-dim tensor starting points.
1256-
n_discrete_points: The number of points to sample.
1257-
sigma: The standard deviation of the additive gaussian noise for
1258-
perturbing the points.
1259-
bounds: A `2 x d`-dim tensor containing the bounds.
1260-
qmc: A boolean indicating whether to use qmc.
1261-
1262-
Returns:
1263-
A `n_discrete_points x d`-dim tensor containing the sampled points.
1264-
"""
1265-
X = normalize(X, bounds=bounds)
1266-
d = X.shape[1]
1267-
# sample points from N(X_center, sigma^2 I), truncated to be within
1268-
# [0, 1]^d.
1269-
if X.shape[0] > 1:
1270-
rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device)
1271-
X = X[rand_indices]
1272-
if qmc:
1273-
std_bounds = torch.zeros(2, d, dtype=X.dtype, device=X.device)
1274-
std_bounds[1] = 1
1275-
u = draw_sobol_samples(bounds=std_bounds, n=n_discrete_points, q=1).squeeze(1)
1276-
else:
1277-
u = torch.rand((n_discrete_points, d), dtype=X.dtype, device=X.device)
1278-
# compute bounds to sample from
1279-
a = -X
1280-
b = 1 - X
1281-
# compute z-score of bounds
1282-
alpha = a / sigma
1283-
beta = b / sigma
1284-
normal = Normal(0, 1)
1285-
cdf_alpha = normal.cdf(alpha)
1286-
# use inverse transform
1287-
perturbation = normal.icdf(cdf_alpha + u * (normal.cdf(beta) - cdf_alpha)) * sigma
1288-
# add perturbation and clip points that are still outside
1289-
perturbed_X = (X + perturbation).clamp(0.0, 1.0)
1290-
return unnormalize(perturbed_X, bounds=bounds)
1291-
1292-
1293-
def sample_perturbed_subset_dims(
1294-
X: Tensor,
1295-
bounds: Tensor,
1296-
n_discrete_points: int,
1297-
sigma: float = 1e-1,
1298-
qmc: bool = True,
1299-
prob_perturb: float | None = None,
1300-
) -> Tensor:
1301-
r"""Sample around `X` by perturbing a subset of the dimensions.
1302-
1303-
By default, dimensions are perturbed with probability equal to
1304-
`min(20 / d, 1)`. As shown in [Regis]_, perturbing a small number
1305-
of dimensions can be beneificial. The perturbations are sampled
1306-
from N(0, sigma^2 I) and truncated to be within [0,1]^d.
1307-
1308-
Args:
1309-
X: A `n x d`-dim tensor starting points. `X`
1310-
must be normalized to be within `[0, 1]^d`.
1311-
bounds: The bounds to sample perturbed values from
1312-
n_discrete_points: The number of points to sample.
1313-
sigma: The standard deviation of the additive gaussian noise for
1314-
perturbing the points.
1315-
qmc: A boolean indicating whether to use qmc.
1316-
prob_perturb: The probability of perturbing each dimension. If omitted,
1317-
defaults to `min(20 / d, 1)`.
1318-
1319-
Returns:
1320-
A `n_discrete_points x d`-dim tensor containing the sampled points.
1321-
1322-
"""
1323-
if bounds.ndim != 2:
1324-
raise BotorchTensorDimensionError("bounds must be a `2 x d`-dim tensor.")
1325-
elif X.ndim != 2:
1326-
raise BotorchTensorDimensionError("X must be a `n x d`-dim tensor.")
1327-
d = bounds.shape[-1]
1328-
if prob_perturb is None:
1329-
# Only perturb a subset of the features
1330-
prob_perturb = min(20.0 / d, 1.0)
1331-
1332-
if X.shape[0] == 1:
1333-
X_cand = X.repeat(n_discrete_points, 1)
1334-
else:
1335-
rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device)
1336-
X_cand = X[rand_indices]
1337-
pert = sample_truncated_normal_perturbations(
1338-
X=X_cand,
1339-
n_discrete_points=n_discrete_points,
1340-
sigma=sigma,
1341-
bounds=bounds,
1342-
qmc=qmc,
1343-
)
1344-
1345-
# find cases where we are not perturbing any dimensions
1346-
mask = (
1347-
torch.rand(
1348-
n_discrete_points,
1349-
d,
1350-
dtype=bounds.dtype,
1351-
device=bounds.device,
1352-
)
1353-
<= prob_perturb
1354-
)
1355-
ind = (~mask).all(dim=-1).nonzero()
1356-
# perturb `n_perturb` of the dimensions
1357-
n_perturb = ceil(d * prob_perturb)
1358-
perturb_mask = torch.zeros(d, dtype=mask.dtype, device=mask.device)
1359-
perturb_mask[:n_perturb].fill_(1)
1360-
# TODO: use batched `torch.randperm` when available:
1361-
# https://github.com/pytorch/pytorch/issues/42502
1362-
for idx in ind:
1363-
mask[idx] = perturb_mask[torch.randperm(d, device=bounds.device)]
1364-
# Create candidate points
1365-
X_cand[mask] = pert[mask]
1366-
return X_cand
1367-
1368-
13691240
def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
13701241
r"""Determine whether a given acquisition function is non-negative.
13711242

0 commit comments

Comments
 (0)