2 changes: 1 addition & 1 deletion botorch/acquisition/knowledge_gradient.py
@@ -226,7 +226,7 @@ def evaluate(self, X: Tensor, bounds: Tensor, **kwargs: Any) -> Tensor:
kwargs: Additional keyword arguments. This includes the options for
optimization of the inner problem, i.e. `num_restarts`, `raw_samples`,
an `options` dictionary to be passed on to the optimization helpers, and
a `scipy_options` dictionary to be passed to `scipy.optimize.minimize`.
a `scipy_options` dictionary to be passed to `scipy.minimize`.

Returns:
A Tensor of shape `b`. For t-batch b, the q-KG value of the design
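For context, a minimal usage sketch of the inner-optimization kwargs documented above; the model, data, and all numeric values below are illustrative assumptions, not part of this diff:

import torch
from botorch.acquisition import qKnowledgeGradient
from botorch.models import SingleTaskGP

# Toy model (assumed) just to make the snippet self-contained.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
qkg = qKnowledgeGradient(model=model, num_fantasies=16)

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
X = torch.rand(3, 2, 2, dtype=torch.double)  # b=3 t-batches of q=2 designs

# `num_restarts`, `raw_samples`, and `options` control the inner optimization;
# `scipy_options` is forwarded to `scipy.optimize.minimize`.
kg_values = qkg.evaluate(
    X,
    bounds=bounds,
    num_restarts=4,
    raw_samples=64,
    options={"batch_limit": 4},
    scipy_options={"maxiter": 50},
)
print(kg_values.shape)  # expected: torch.Size([3])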
12 changes: 4 additions & 8 deletions botorch/generation/gen.py
@@ -72,10 +72,6 @@ def gen_candidates_scipy(

Optimizes an acquisition function starting from a set of initial candidates
using `scipy.optimize.minimize` via a numpy converter.
We use SLSQP, if constraints are present, and LBFGS-B otherwise.
As `scipy.optimize.minimize` does not support optimizating a batch of problems, we
treat optimizing a set of candidates as a single optimization problem by
summing together their acquisition values.

Args:
initial_conditions: Starting points for optimization, with shape
@@ -102,7 +98,7 @@ def gen_candidates_scipy(
`optimize_acqf()`. The constraints will later be passed to the scipy
solver.
options: Options used to control the optimization including "method"
and "maxiter". Select method for `scipy.optimize.minimize` using the
and "maxiter". Select method for `scipy.minimize` using the
"method" key. By default uses L-BFGS-B for box-constrained problems
and SLSQP if inequality or equality constraints are present. If
`with_grad=False`, then we use a two-point finite difference estimate
@@ -664,13 +660,13 @@ def _process_scipy_result(res: OptimizeResult, options: dict[str, Any]) -> None:
or "Iteration limit reached" in res.message
):
logger.info(
"`scipy.optimize.minimize` exited by reaching the iteration limit of "
"`scipy.minimize` exited by reaching the iteration limit of "
f"`maxiter: {options.get('maxiter')}`."
)
elif "EVALUATIONS EXCEEDS LIMIT" in res.message:
logger.info(
"`scipy.optimize.minimize` exited by reaching the function evaluation "
f"limit of `maxfun: {options.get('maxfun')}`."
"`scipy.minimize` exited by reaching the function evaluation limit of "
f"`maxfun: {options.get('maxfun')}`."
)
elif "Optimization timed out after" in res.message:
logger.info(res.message)
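A hedged usage sketch of `gen_candidates_scipy` (model, data, and bounds are assumed for illustration); as the removed docstring text explains, the batch of restarts is optimized as one scipy problem by summing their acquisition values:

import torch
from botorch.acquisition import ExpectedImprovement
from botorch.generation.gen import gen_candidates_scipy
from botorch.models import SingleTaskGP

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
ei = ExpectedImprovement(model=model, best_f=train_Y.max())

# 4 restarts of q=1 candidates in d=2, e.g. from gen_batch_initial_conditions.
ics = torch.rand(4, 1, 2, dtype=torch.double)
candidates, acq_values = gen_candidates_scipy(
    initial_conditions=ics,
    acquisition_function=ei,
    lower_bounds=torch.zeros(2, dtype=torch.double),
    upper_bounds=torch.ones(2, dtype=torch.double),
    # "method" and "maxiter" are the option keys described in the docstring.
    options={"method": "L-BFGS-B", "maxiter": 100},
)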
12 changes: 9 additions & 3 deletions botorch/models/deterministic.py
@@ -292,8 +292,9 @@ def __init__(
self.sample_shape = Size() if sample_shape is None else sample_shape
self.ensemble_as_batch = ensemble_as_batch

# NOTE circular import in pathwise/utils.py otherwise
from botorch.sampling.pathwise import draw_matheron_paths
# Import from the concrete implementation module so that test mocks
# (which patch the draw_matheron_paths function) are respected.
from botorch.sampling.pathwise.posterior_samplers import draw_matheron_paths

# Generate the Matheron path once during initialization
if seed is not None:
@@ -322,7 +323,12 @@ def forward(self, X: Tensor) -> Tensor:
return self._path(X).unsqueeze(-1)
elif isinstance(self.model, ModelList):
# For model list, stack the path outputs
return torch.stack(self._path(X), dim=-1)
path_outputs = self._path(X)
if len(path_outputs) == 0:
# Handle empty model list case by returning a tensor with shape (..., 0)
batch_shape = X.shape[:-1] # batch_shape x n
return torch.empty(*batch_shape, 0, dtype=X.dtype, device=X.device)
return torch.stack(path_outputs, dim=-1)
else:
# For multi-output models
return self._path(X.unsqueeze(-3)).transpose(-1, -2)
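A standalone sketch of the stacking logic added above (the helper name is illustrative, not the PR's API): per-model path outputs are stacked along a new last dimension, and an empty model list yields a `... x 0` tensor:

import torch

def stack_path_outputs(path_outputs: list[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
    if len(path_outputs) == 0:
        # Empty model list: return a tensor with a trailing output dim of 0.
        batch_shape = X.shape[:-1]  # batch_shape x n
        return torch.empty(*batch_shape, 0, dtype=X.dtype, device=X.device)
    return torch.stack(path_outputs, dim=-1)

X = torch.rand(5, 3)
print(stack_path_outputs([], X).shape)                       # torch.Size([5, 0])
print(stack_path_outputs([X.sum(-1), X.mean(-1)], X).shape)  # torch.Size([5, 2])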
2 changes: 2 additions & 0 deletions botorch/models/gpytorch.py
@@ -806,6 +806,7 @@ def _apply_noise(
self,
X: Tensor,
mvn: MultivariateNormal,
num_outputs: int,
observation_noise: bool | Tensor,
) -> MultivariateNormal:
"""Adds the observation noise to the posterior.
@@ -937,6 +938,7 @@ def posterior(
mvn = self._apply_noise(
X=X_full,
mvn=mvn,
num_outputs=num_outputs,
observation_noise=observation_noise,
)
# If single-output, return the posterior of a single-output model
73 changes: 15 additions & 58 deletions botorch/models/multitask.py
@@ -42,6 +42,7 @@
from botorch.models.utils.assorted import get_task_value_remapping
from botorch.models.utils.gpytorch_modules import (
get_covar_module_with_dim_scaled_prior,
get_gaussian_likelihood_with_lognormal_prior,
MIN_INFERRED_NOISE_LEVEL,
)
from botorch.posteriors.multitask import MultitaskGPPosterior
@@ -55,7 +56,6 @@
from gpytorch.kernels.index_kernel import IndexKernel
from gpytorch.kernels.multitask_kernel import MultitaskKernel
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.likelihoods.hadamard_gaussian_likelihood import HadamardGaussianLikelihood
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.likelihoods.multitask_gaussian_likelihood import (
MultitaskGaussianLikelihood,
@@ -115,7 +115,6 @@ def __init__(
all_tasks: list[int] | None = None,
outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
input_transform: InputTransform | None = None,
validate_task_values: bool = True,
) -> None:
r"""Multi-Task GP model using an ICM kernel.

@@ -158,9 +157,6 @@ def __init__(
instantiation of the model.
input_transform: An input transform that is applied in the model's
forward pass.
validate_task_values: If True, validate that the task values supplied in the
input are expected tasks values. If false, unexpected task values
will be mapped to the first output_task if supplied.

Example:
>>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
@@ -193,7 +189,7 @@ def __init__(
"This is not allowed as it will lead to errors during model training."
)
all_tasks = all_tasks or all_tasks_inferred
self.num_tasks = len(all_tasks_inferred)
self.num_tasks = len(all_tasks)
if outcome_transform == DEFAULT:
outcome_transform = Standardize(m=1, batch_shape=train_X.shape[:-2])
if outcome_transform is not None:
@@ -212,20 +208,10 @@ def __init__(
self._output_tasks = output_tasks
self._num_outputs = len(output_tasks)

# TODO (T41270962): Support task-specific noise levels in likelihood
if likelihood is None:
if train_Yvar is None:
noise_prior = LogNormalPrior(loc=-4.0, scale=1.0)
likelihood = HadamardGaussianLikelihood(
num_tasks=self.num_tasks,
batch_shape=torch.Size(),
noise_prior=noise_prior,
noise_constraint=GreaterThan(
MIN_INFERRED_NOISE_LEVEL,
transform=None,
initial_value=noise_prior.mode,
),
task_feature_index=task_feature,
)
likelihood = get_gaussian_likelihood_with_lognormal_prior()
else:
likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar.squeeze(-1))

@@ -263,60 +249,31 @@ def __init__(

self.covar_module = data_covar_module * task_covar_module
task_mapper = get_task_value_remapping(
observed_task_values=torch.tensor(
all_tasks_inferred, dtype=torch.long, device=train_X.device
),
all_task_values=torch.tensor(
sorted(all_tasks), dtype=torch.long, device=train_X.device
task_values=torch.tensor(
all_tasks, dtype=torch.long, device=train_X.device
),
dtype=train_X.dtype,
default_task_value=None if output_tasks is None else output_tasks[0],
)
self.register_buffer("_task_mapper", task_mapper)
self._expected_task_values = set(all_tasks_inferred)
self._expected_task_values = set(all_tasks)
if input_transform is not None:
self.input_transform = input_transform
if outcome_transform is not None:
self.outcome_transform = outcome_transform
self._validate_task_values = validate_task_values
self.to(train_X)

def _map_tasks(self, task_values: Tensor) -> Tensor:
"""Map raw task values to the task indices used by the model.
"""Map task values to contiguous integers using the task mapper.

Args:
task_values: A tensor of task values.
task_values: A tensor of task indices to be mapped.

Returns:
A tensor of task indices with the same shape as the input
tensor.
A tensor of mapped task indices.
"""
long_task_values = task_values.long()
if self._validate_task_values:
if self._task_mapper is None:
if not (
torch.all(0 <= task_values)
and torch.all(task_values < self.num_tasks)
):
raise ValueError(
"Expected all task features in `X` to be between 0 and "
f"self.num_tasks - 1. Got {task_values}."
)
else:
unexpected_task_values = set(
long_task_values.unique().tolist()
).difference(self._expected_task_values)
if len(unexpected_task_values) > 0:
raise ValueError(
"Received invalid raw task values. Expected raw value to be in"
f" {self._expected_task_values}, but got unexpected task"
f" values: {unexpected_task_values}."
)
task_values = self._task_mapper[long_task_values]
elif self._task_mapper is not None:
task_values = self._task_mapper[long_task_values]

return task_values
if self._task_mapper is None:
return task_values.to(dtype=self.train_targets.dtype)
return self._task_mapper[task_values].to(dtype=self.train_targets.dtype)

def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
r"""Extracts features before task feature, task indices, and features after
Expand All @@ -330,7 +287,7 @@ def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
3-element tuple containing

- A `q x d` or `b x q x d` tensor with features before the task feature
- A `q` or `b x q x 1` tensor with mapped task indices
- A `q` or `b x q` tensor with mapped task indices
- A `q x d` or `b x q x d` tensor with features after the task feature
"""
batch_shape = x.shape[:-2]
@@ -370,7 +327,7 @@ def get_all_tasks(
raise ValueError(f"Must have that -{d} <= task_feature <= {d}")
task_feature = task_feature % (d + 1)
all_tasks = (
train_X[..., task_feature].to(dtype=torch.long).unique(sorted=True).tolist()
train_X[..., task_feature].unique(sorted=True).to(dtype=torch.long).tolist()
)
return all_tasks, task_feature, d

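A standalone sketch (not the library code; task values are illustrative) of the lookup-table remapping that `_map_tasks` relies on: raw task values are converted to contiguous indices by indexing a dense mapper tensor, with unobserved slots left as NaN:

import torch

observed = torch.tensor([1, 3, 7])  # raw task values seen in the training data
mapper = torch.full((int(observed.max()) + 1,), float("nan"))
mapper[observed] = torch.arange(len(observed), dtype=mapper.dtype)

raw = torch.tensor([7, 1, 1, 3])
print(mapper[raw])  # tensor([2., 0., 0., 1.])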
50 changes: 30 additions & 20 deletions botorch/models/utils/assorted.py
@@ -406,29 +406,39 @@ class fantasize(_Flag):


def get_task_value_remapping(
observed_task_values: Tensor,
all_task_values: Tensor,
dtype: torch.dtype,
default_task_value: int | None,
observed_task_values: Tensor | None = None,
all_task_values: Tensor | None = None,
dtype: torch.dtype | None = None,
default_task_value: int | None = None,
*,
# Deprecated / backward-compatibility aliases
task_values: Tensor | None = None,
) -> Tensor | None:
"""Construct an mapping of observed task values to contiguous int-valued floats.
"""Construct a mapping of observed task values to contiguous integers.

Args:
observed_task_values: A sorted long-valued tensor of task values.
all_task_values: A sorted long-valued tensor of task values.
dtype: The dtype of the model inputs (e.g. `X`), which the new
task values should have mapped to (e.g. float, double).
default_task_value: The default task value to use for missing task values.

Returns:
A tensor of shape `task_values.max() + 1` that maps task values
to new task values. The indexing operation `mapper[task_value]`
will produce a tensor of new task values, of the same shape as
the original. The elements of the `mapper` tensor that do not
appear in the original `task_values` are mapped to `nan`. The
return value will be `None`, when the task values are contiguous
integers starting from zero.
This function previously accepted the first argument as ``task_values``. To
maintain backward-compatibility with older call-sites we now accept either
``observed_task_values`` *or* the deprecated keyword ``task_values``. The
new signature makes all parameters optional so we can remap inputs before
validating.
"""

# Handle legacy keyword argument alias.
if observed_task_values is None and task_values is not None:
observed_task_values = task_values

# Basic validation after resolving aliases.
# Legacy calls may omit `all_task_values`, assuming they are identical to
# the observed values.
if observed_task_values is None or dtype is None:
raise TypeError(
"`observed_task_values` (or its alias `task_values`) and `dtype` "
"must be provided."
)

if all_task_values is None:
all_task_values = observed_task_values

if dtype not in (torch.float, torch.double):
raise ValueError(f"dtype must be torch.float or torch.double, but got {dtype}.")
task_range = torch.arange(
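A standalone sketch of the keyword-alias pattern used above (function and parameter names are illustrative): accept the new parameter name, fall back to the deprecated alias, and validate only after the alias is resolved:

import torch
from torch import Tensor

def resolve_task_values(
    observed_task_values: Tensor | None = None,
    *,
    task_values: Tensor | None = None,
) -> Tensor:
    # Prefer the new name; fall back to the deprecated alias.
    if observed_task_values is None and task_values is not None:
        observed_task_values = task_values
    if observed_task_values is None:
        raise TypeError(
            "`observed_task_values` (or its alias `task_values`) must be provided."
        )
    return observed_task_values

assert torch.equal(
    resolve_task_values(torch.tensor([0, 2, 5])),
    resolve_task_values(task_values=torch.tensor([0, 2, 5])),
)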
4 changes: 2 additions & 2 deletions botorch/optim/core.py
@@ -79,8 +79,8 @@ def scipy_minimize(
bounds: A dictionary mapping parameter names to lower and upper bounds.
callback: A callable taking `parameters` and an OptimizationResult as arguments.
x0: An optional initialization vector passed to scipy.optimize.minimize.
method: Solver type, passed along to scipy.optimize.minimize.
options: Dictionary of solver options, passed along to scipy.optimize.minimize.
method: Solver type, passed along to scipy.minimize.
options: Dictionary of solver options, passed along to scipy.minimize.
timeout_sec: Timeout in seconds to wait before aborting the optimization loop
if not converged (will return the best found solution thus far).

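A standalone sketch of where `method` and `options` ultimately land: they are the same-named arguments of `scipy.optimize.minimize` (the toy objective below is an assumption for illustration):

import numpy as np
from scipy.optimize import minimize

res = minimize(
    fun=lambda x: float((x ** 2).sum()),  # toy quadratic objective
    x0=np.array([1.0, -2.0]),
    method="L-BFGS-B",
    options={"maxiter": 50},
)
print(res.x, res.message)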
4 changes: 2 additions & 2 deletions botorch/optim/fit.py
@@ -69,8 +69,8 @@ def fit_gpytorch_mll_scipy(
Responsible for setting the `grad` attributes of `parameters`. If no closure
is provided, one will be obtained by calling `get_loss_closure_with_grads`.
closure_kwargs: Keyword arguments passed to `closure`.
method: Solver type, passed along to scipy.optimize.minimize.
options: Dictionary of solver options, passed along to scipy.optimize.minimize.
method: Solver type, passed along to scipy.minimize.
options: Dictionary of solver options, passed along to scipy.minimize.
callback: Optional callback taking `parameters` and an OptimizationResult as its
sole arguments.
timeout_sec: Timeout in seconds after which to terminate the fitting loop
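A hedged usage sketch (model and data assumed) of how `method` and `options` flow from this helper to `scipy.optimize.minimize`:

import torch
from botorch.models import SingleTaskGP
from botorch.optim.fit import fit_gpytorch_mll_scipy
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(12, 2, dtype=torch.double)
train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

result = fit_gpytorch_mll_scipy(mll, method="L-BFGS-B", options={"maxiter": 50})
print(result.status, result.fval)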
36 changes: 5 additions & 31 deletions botorch/optim/optimize.py
@@ -603,29 +603,7 @@ def optimize_acqf(
retry_on_optimization_warning: bool = True,
**ic_gen_kwargs: Any,
) -> tuple[Tensor, Tensor]:
r"""Optimize the acquisition function for a single or multiple joint candidates.

A high-level description (missing exceptions for special setups):

This function optimizes the acquisition function `acq_function` in two steps:

i) It will sample `raw_samples` random points using Sobol sampling in the bounds
`bounds` and pass on the "best" `num_restarts` many.
The default way to find these "best" is via `gen_batch_initial_conditions`
(deviating for some acq functions, see `get_ic_generator`),
which by default performs Boltzmann sampling on the acquisition function value
(The behavior of step (i) can be further controlled by specifying `ic_generator`
or `batch_initial_conditions`.)

ii) A batch of the `num_restarts` points (or joint sets of points)
with the highest acquisition values in the previous step are then further
optimized. This is by default done by LBFGS-B optimization, if no constraints are
present, and SLSQP, if constraints are present (can be changed to
other optmizers via `gen_candidates`).

While the optimization procedure runs on CPU by default for this function,
the acq_function can be implemented on GPU and simply move the inputs
to GPU internally.
r"""Generate a set of candidates via multi-start optimization.

Args:
acq_function: An AcquisitionFunction.
Expand All @@ -634,13 +612,10 @@ def optimize_acqf(
+inf, respectively).
q: The number of candidates.
num_restarts: The number of starting points for multistart acquisition
function optimization. Even though the name suggests this happens
sequentually, it is done in parallel (using batched evaluations)
for up to `options.batch_limit` candidates (by default completely parallel).
function optimization.
raw_samples: The number of samples for initialization. This is required
if `batch_initial_conditions` is not specified.
options: Options for both optimization, passed to `gen_candidates`,
and initialization, passed to the `ic_generator` via the `options` kwarg.
options: Options for candidate generation.
inequality_constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and
@@ -685,9 +660,8 @@ def optimize_acqf(
acquisition values) given a tensor of initial conditions and an
acquisition function. Other common inputs include lower and upper bounds
and a dictionary of options, but refer to the documentation of specific
generation functions (e.g., botorch.optim.optimize.gen_candidates_scipy
and botorch.generation.gen.gen_candidates_torch) for method-specific
inputs. Default: `gen_candidates_scipy`
generation functions (e.g gen_candidates_scipy and gen_candidates_torch)
for method-specific inputs. Default: `gen_candidates_scipy`
sequential: If False, uses joint optimization, otherwise uses sequential
optimization for optimizing multiple joint candidates (q > 1).
acq_function_sequence: A list of acquisition functions to be optimized
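A hedged usage sketch (model, data, and option values assumed) of the multi-start procedure this docstring covers: `raw_samples` Sobol points seed `num_restarts` starting points, which the default `gen_candidates_scipy` then refines:

import torch
from botorch.acquisition import ExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
ei = ExpectedImprovement(model=model, best_f=train_Y.max())
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)

candidate, acq_value = optimize_acqf(
    acq_function=ei,
    bounds=bounds,
    q=1,
    num_restarts=8,
    raw_samples=128,
    options={"batch_limit": 4, "maxiter": 100},
)
print(candidate.shape)  # torch.Size([1, 2])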