44 | 44 | from botorch.utils.multi_objective.pareto import is_non_dominated
45 | 45 | from botorch.utils.sampling import (
46 | 46 |     batched_multinomial,
| 47 | +    boltzmann_sample,
47 | 48 |     draw_sobol_samples,
48 | 49 |     get_polytope_samples,
49 | 50 |     manual_seed,
| 51 | +    sample_perturbed_subset_dims,
| 52 | +    sample_truncated_normal_perturbations,
50 | 53 | )
51 | 54 | from botorch.utils.transforms import normalize, standardize, unnormalize
52 | 55 | from torch import Tensor
53 | | -from torch.distributions import Normal
| 56 | +from torch.distributions import Multinomial, Normal
54 | 57 | from torch.quasirandom import SobolEngine
55 | 58 |
56 | 59 | TGenInitialConditions = Callable[
@@ -578,10 +581,12 @@ def gen_one_shot_kg_initial_conditions(
578 | 581 |
579 | 582 |     # sampling from the optimizers
580 | 583 |     n_value = int((1 - frac_random) * (q_aug - q))  # number of non-random ICs
581 | | -    eta = options.get("eta", 2.0)
582 | | -    weights = torch.exp(eta * standardize(fantasy_vals))
583 | | -    idx = torch.multinomial(weights, num_restarts * n_value, replacement=True)
584 | | -
| 584 | +    idx = boltzmann_sample(
| 585 | +        function_values=fantasy_vals,
| 586 | +        num_samples=num_restarts * n_value,
| 587 | +        eta=options.get("eta", 2.0),
| 588 | +        replacement=True,
| 589 | +    )
585 | 590 |     # set the respective initial conditions to the sampled optimizers
586 | 591 |     ics[..., -n_value:, :] = fantasy_cands[idx, 0].view(num_restarts, n_value, -1)
587 | 592 |     return ics
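This hunk (and the two below it) swaps a hand-rolled exp/standardize/multinomial pattern for the `boltzmann_sample` helper now imported from `botorch.utils.sampling`. The diff only shows the call sites, so here is a rough sketch of the equivalence, reconstructed from the lines deleted here and in the `initialize_q_batch` hunk further down; the name `boltzmann_sample_sketch` and the internals beyond those deleted lines are assumptions, not the library implementation:

```python
import torch
from botorch.utils.sampling import batched_multinomial


def boltzmann_sample_sketch(function_values, num_samples, eta, replacement=False):
    # Standardize so that `eta` acts on a scale-free quantity, as the deleted
    # code did via `standardize(fantasy_vals)`.
    mean = function_values.mean(dim=-1, keepdim=True)
    std = function_values.std(dim=-1, keepdim=True)
    logits = eta * (function_values - mean) / std
    weights = torch.exp(logits)
    # Halve the temperature until the weights are finite, mirroring the
    # overflow guard deleted from `initialize_q_batch` below.
    while torch.isinf(weights).any():
        logits = 0.5 * logits
        weights = torch.exp(logits)
    if function_values.ndim == 1:
        return torch.multinomial(weights, num_samples, replacement=replacement)
    # Batched inputs sample per row, like the deleted `batched_multinomial` call.
    return batched_multinomial(
        weights=weights, num_samples=num_samples, replacement=replacement
    )
```

If the helper keeps that guard, every call site in this commit now gets the inf-safe temperature annealing that previously only `initialize_q_batch` had.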
@@ -699,14 +704,14 @@ def gen_one_shot_hvkg_initial_conditions(
699 | 704 |         sequential=False,
700 | 705 |     )
701 | 706 |     # sampling from the optimizers
702 | | -    eta = options.get("eta", 2.0)
703 | 707 |     if num_optim_restarts > 0:
704 | | -        probs = torch.nn.functional.softmax(eta * standardize(fantasy_vals), dim=0)
705 | | -        idx = torch.multinomial(
706 | | -            probs,
707 | | -            num_optim_restarts * acq_function.num_fantasies,
| 708 | +        idx = boltzmann_sample(
| 709 | +            function_values=fantasy_vals,
| 710 | +            num_samples=num_optim_restarts * acq_function.num_fantasies,
| 711 | +            eta=options.get("eta", 2.0),
708 | 712 |             replacement=True,
709 | 713 |         )
| 714 | +
710 | 715 |         optim_ics = fantasy_cands[idx]
711 | 716 |         if is_mf_hvkg:
712 | 717 |             # add fixed features
@@ -885,11 +890,10 @@ def gen_value_function_initial_conditions(
885 | 890 |     # sampling from the optimizers
886 | 891 |     n_value = int((1 - frac_random) * raw_samples)  # number of non-random ICs
887 | 892 |     if n_value > 0:
888 | | -        eta = options.get("eta", 2.0)
889 | | -        weights = torch.exp(eta * standardize(fantasy_vals))
890 | | -        idx = batched_multinomial(
891 | | -            weights=weights.expand(*batch_shape, -1),
| 893 | +        idx = boltzmann_sample(
| 894 | +            function_values=fantasy_vals.expand(*batch_shape, -1),
892 | 895 |             num_samples=n_value,
| 896 | +            eta=options.get("eta", 2.0),
893 | 897 |             replacement=True,
894 | 898 |         ).permute(-1, *range(len(batch_shape)))
895 | 899 |         resampled = fantasy_cands[idx]
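The trailing `.permute(-1, *range(len(batch_shape)))` is the subtle part of this hunk: `boltzmann_sample` draws `n_value` indices per batch element, and the permute moves the sample dimension to the front before indexing `fantasy_cands`. A standalone toy shape check (not from this commit):

```python
import torch

batch_shape = torch.Size([3, 2])
n_value = 4
# One row of n_value sampled indices per batch element: batch_shape x n_value.
idx = torch.randint(10, (*batch_shape, n_value))  # shape 3 x 2 x 4
# Move the sample dimension to the front: n_value x batch_shape.
idx = idx.permute(-1, *range(len(batch_shape)))   # shape 4 x 3 x 2
assert idx.shape == torch.Size([n_value, *batch_shape])
```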
@@ -979,18 +983,12 @@ def initialize_q_batch(
979 | 983 |         return X[idcs], acq_vals[idcs]
980 | 984 |
981 | 985 |     max_val, max_idx = torch.max(acq_vals, dim=0)
982 | | -    Z = (acq_vals - acq_vals.mean(dim=0)) / Ystd
983 | | -    etaZ = eta * Z
984 | | -    weights = torch.exp(etaZ)
985 | | -    while torch.isinf(weights).any():
986 | | -        etaZ *= 0.5
987 | | -        weights = torch.exp(etaZ)
988 | | -    if batch_shape == torch.Size():
989 | | -        idcs = torch.multinomial(weights, n)
990 | | -    else:
991 | | -        idcs = batched_multinomial(
992 | | -            weights=weights.permute(*range(1, len(batch_shape) + 1), 0), num_samples=n
993 | | -        ).permute(-1, *range(len(batch_shape)))
| 986 | +    idcs = boltzmann_sample(
| 987 | +        acq_vals.permute(*range(1, len(batch_shape) + 1), 0),
| 988 | +        num_samples=n,
| 989 | +        eta=eta,
| 990 | +    ).permute(-1, *range(len(batch_shape)))
| 991 | +
994 | 992 |     # make sure we get the maximum
995 | 993 |     if max_idx not in idcs:
996 | 994 |         idcs[-1] = max_idx
@@ -1239,133 +1237,6 @@ def sample_points_around_best(
1239 | 1237 |     return perturbed_X
1240 | 1238 |
1241 | 1239 |
1242 | | -def sample_truncated_normal_perturbations(
1243 | | -    X: Tensor,
1244 | | -    n_discrete_points: int,
1245 | | -    sigma: float,
1246 | | -    bounds: Tensor,
1247 | | -    qmc: bool = True,
1248 | | -) -> Tensor:
1249 | | -    r"""Sample points around `X`.
1250 | | -
1251 | | -    Sample perturbed points around `X` such that the added perturbations
1252 | | -    are sampled from N(0, sigma^2 I) and truncated to be within [0,1]^d.
1253 | | -
1254 | | -    Args:
1255 | | -        X: A `n x d`-dim tensor of starting points.
1256 | | -        n_discrete_points: The number of points to sample.
1257 | | -        sigma: The standard deviation of the additive Gaussian noise for
1258 | | -            perturbing the points.
1259 | | -        bounds: A `2 x d`-dim tensor containing the bounds.
1260 | | -        qmc: A boolean indicating whether to use qmc.
1261 | | -
1262 | | -    Returns:
1263 | | -        A `n_discrete_points x d`-dim tensor containing the sampled points.
1264 | | -    """
1265 | | -    X = normalize(X, bounds=bounds)
1266 | | -    d = X.shape[1]
1267 | | -    # sample points from N(X_center, sigma^2 I), truncated to be within
1268 | | -    # [0, 1]^d.
1269 | | -    if X.shape[0] > 1:
1270 | | -        rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device)
1271 | | -        X = X[rand_indices]
1272 | | -    if qmc:
1273 | | -        std_bounds = torch.zeros(2, d, dtype=X.dtype, device=X.device)
1274 | | -        std_bounds[1] = 1
1275 | | -        u = draw_sobol_samples(bounds=std_bounds, n=n_discrete_points, q=1).squeeze(1)
1276 | | -    else:
1277 | | -        u = torch.rand((n_discrete_points, d), dtype=X.dtype, device=X.device)
1278 | | -    # compute bounds to sample from
1279 | | -    a = -X
1280 | | -    b = 1 - X
1281 | | -    # compute z-score of bounds
1282 | | -    alpha = a / sigma
1283 | | -    beta = b / sigma
1284 | | -    normal = Normal(0, 1)
1285 | | -    cdf_alpha = normal.cdf(alpha)
1286 | | -    # use inverse transform
1287 | | -    perturbation = normal.icdf(cdf_alpha + u * (normal.cdf(beta) - cdf_alpha)) * sigma
1288 | | -    # add perturbation and clip points that are still outside
1289 | | -    perturbed_X = (X + perturbation).clamp(0.0, 1.0)
1290 | | -    return unnormalize(perturbed_X, bounds=bounds)
1291 | | -
1292 | | -
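For reference, the deleted helper above draws its perturbation by inverse-transform sampling from a normal truncated so that the perturbed point stays inside the unit cube: with $x$ a coordinate of the (normalized) starting point, $\alpha = -x/\sigma$, $\beta = (1 - x)/\sigma$, and $u$ uniform (or Sobol) on $[0, 1]$,

$$\text{perturbation} = \sigma\,\Phi^{-1}\!\big(\Phi(\alpha) + u\,(\Phi(\beta) - \Phi(\alpha))\big),$$

so $x + \text{perturbation} \in [0, 1]$ up to floating-point error, which the final `clamp` absorbs.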
1293 | | -def sample_perturbed_subset_dims(
1294 | | -    X: Tensor,
1295 | | -    bounds: Tensor,
1296 | | -    n_discrete_points: int,
1297 | | -    sigma: float = 1e-1,
1298 | | -    qmc: bool = True,
1299 | | -    prob_perturb: float | None = None,
1300 | | -) -> Tensor:
1301 | | -    r"""Sample around `X` by perturbing a subset of the dimensions.
1302 | | -
1303 | | -    By default, dimensions are perturbed with probability equal to
1304 | | -    `min(20 / d, 1)`. As shown in [Regis]_, perturbing a small number
1305 | | -    of dimensions can be beneficial. The perturbations are sampled
1306 | | -    from N(0, sigma^2 I) and truncated to be within [0,1]^d.
1307 | | -
1308 | | -    Args:
1309 | | -        X: A `n x d`-dim tensor of starting points. `X`
1310 | | -            must be normalized to be within `[0, 1]^d`.
1311 | | -        bounds: The bounds to sample perturbed values from.
1312 | | -        n_discrete_points: The number of points to sample.
1313 | | -        sigma: The standard deviation of the additive Gaussian noise for
1314 | | -            perturbing the points.
1315 | | -        qmc: A boolean indicating whether to use qmc.
1316 | | -        prob_perturb: The probability of perturbing each dimension. If omitted,
1317 | | -            defaults to `min(20 / d, 1)`.
1318 | | -
1319 | | -    Returns:
1320 | | -        A `n_discrete_points x d`-dim tensor containing the sampled points.
1321 | | -
1322 | | -    """
1323 | | -    if bounds.ndim != 2:
1324 | | -        raise BotorchTensorDimensionError("bounds must be a `2 x d`-dim tensor.")
1325 | | -    elif X.ndim != 2:
1326 | | -        raise BotorchTensorDimensionError("X must be a `n x d`-dim tensor.")
1327 | | -    d = bounds.shape[-1]
1328 | | -    if prob_perturb is None:
1329 | | -        # Only perturb a subset of the features
1330 | | -        prob_perturb = min(20.0 / d, 1.0)
1331 | | -
1332 | | -    if X.shape[0] == 1:
1333 | | -        X_cand = X.repeat(n_discrete_points, 1)
1334 | | -    else:
1335 | | -        rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device)
1336 | | -        X_cand = X[rand_indices]
1337 | | -    pert = sample_truncated_normal_perturbations(
1338 | | -        X=X_cand,
1339 | | -        n_discrete_points=n_discrete_points,
1340 | | -        sigma=sigma,
1341 | | -        bounds=bounds,
1342 | | -        qmc=qmc,
1343 | | -    )
1344 | | -
1345 | | -    # find cases where we are not perturbing any dimensions
1346 | | -    mask = (
1347 | | -        torch.rand(
1348 | | -            n_discrete_points,
1349 | | -            d,
1350 | | -            dtype=bounds.dtype,
1351 | | -            device=bounds.device,
1352 | | -        )
1353 | | -        <= prob_perturb
1354 | | -    )
1355 | | -    ind = (~mask).all(dim=-1).nonzero()
1356 | | -    # perturb `n_perturb` of the dimensions
1357 | | -    n_perturb = ceil(d * prob_perturb)
1358 | | -    perturb_mask = torch.zeros(d, dtype=mask.dtype, device=mask.device)
1359 | | -    perturb_mask[:n_perturb].fill_(1)
1360 | | -    # TODO: use batched `torch.randperm` when available:
1361 | | -    # https://github.com/pytorch/pytorch/issues/42502
1362 | | -    for idx in ind:
1363 | | -        mask[idx] = perturb_mask[torch.randperm(d, device=bounds.device)]
1364 | | -    # Create candidate points
1365 | | -    X_cand[mask] = pert[mask]
1366 | | -    return X_cand
1367 | | -
1368 | | -
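The one subtle step in the deleted helper above is the repair of rows whose Bernoulli(`prob_perturb`) mask selected no dimensions: those rows instead get exactly `ceil(d * prob_perturb)` randomly chosen dimensions. A minimal standalone illustration of that step (variable names are mine, not from the commit):

```python
import torch
from math import ceil

d, prob_perturb, n = 10, 0.2, 5
mask = torch.rand(n, d) <= prob_perturb  # perturb each dim w.p. prob_perturb
empty = (~mask).all(dim=-1).nonzero()    # rows that would perturb nothing
n_perturb = ceil(d * prob_perturb)
template = torch.zeros(d, dtype=mask.dtype)
template[:n_perturb] = True
for idx in empty:
    # Give each empty row a random subset of exactly n_perturb dimensions,
    # as in the deleted randperm loop.
    mask[idx] = template[torch.randperm(d)]
assert bool(mask.any(dim=-1).all())      # every row now perturbs >= 1 dim
```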
1369 | 1240 | def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
1370 | 1241 |     r"""Determine whether a given acquisition function is non-negative.
1371 | 1242 |