|
22 | 22 | from botorch.acquisition.knowledge_gradient import qKnowledgeGradient |
23 | 23 | from botorch.exceptions import InputDataError, UnsupportedError |
24 | 24 | from botorch.exceptions.warnings import OptimizationWarning |
25 | | -from botorch.generation.gen import gen_candidates_scipy |
| 25 | +from botorch.generation.gen import gen_candidates_scipy, TGenCandidates |
26 | 26 | from botorch.logging import logger |
27 | 27 | from botorch.optim.initializers import ( |
28 | 28 | gen_batch_initial_conditions, |
@@ -64,6 +64,7 @@ def optimize_acqf( |
64 | 64 | post_processing_func: Optional[Callable[[Tensor], Tensor]] = None, |
65 | 65 | batch_initial_conditions: Optional[Tensor] = None, |
66 | 66 | return_best_only: bool = True, |
| 67 | + gen_candidates: TGenCandidates = gen_candidates_scipy, |
67 | 68 | sequential: bool = False, |
68 | 69 | **kwargs: Any, |
69 | 70 | ) -> Tuple[Tensor, Tensor]: |
@@ -103,6 +104,8 @@ def optimize_acqf( |
103 | 104 | this if you do not want to use default initialization strategy. |
104 | 105 | return_best_only: If False, outputs the solutions corresponding to all |
105 | 106 | random restart initializations of the optimization. |
| 107 | + gen_candidates: A callable for generating candidates given initial |
| 108 | + conditions. Default: `gen_candidates_scipy` |
106 | 109 | sequential: If False, uses joint optimization, otherwise uses sequential |
107 | 110 | optimization. |
108 | 111 | kwargs: Additonal keyword arguments. |
@@ -258,23 +261,23 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]: |
258 | 261 | batched_ics = batch_initial_conditions.split(batch_limit) |
259 | 262 | opt_warnings = [] |
260 | 263 |
|
261 | | - scipy_kws = dict( |
262 | | - acquisition_function=acq_function, |
263 | | - lower_bounds=None if bounds[0].isinf().all() else bounds[0], |
264 | | - upper_bounds=None if bounds[1].isinf().all() else bounds[1], |
265 | | - options={k: v for k, v in options.items() if k not in INIT_OPTION_KEYS}, |
266 | | - inequality_constraints=inequality_constraints, |
267 | | - equality_constraints=equality_constraints, |
268 | | - nonlinear_inequality_constraints=nonlinear_inequality_constraints, |
269 | | - fixed_features=fixed_features, |
270 | | - ) |
| 264 | + gen_kwargs = { |
| 265 | + "acquisition_function": acq_function, |
| 266 | + "lower_bounds": None if bounds[0].isinf().all() else bounds[0], |
| 267 | + "upper_bounds": None if bounds[1].isinf().all() else bounds[1], |
| 268 | + "options": {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS}, |
| 269 | + "fixed_features": fixed_features, |
| 270 | + "inequality_constraints": inequality_constraints, |
| 271 | + "equality_constraints": equality_constraints, |
| 272 | + "nonlinear_inequality_constraints": nonlinear_inequality_constraints, |
| 273 | + } |
271 | 274 |
|
272 | 275 | for i, batched_ics_ in enumerate(batched_ics): |
273 | 276 | # optimize using random restart optimization |
274 | 277 | with warnings.catch_warnings(record=True) as ws: |
275 | 278 | warnings.simplefilter("always", category=OptimizationWarning) |
276 | | - batch_candidates_curr, batch_acq_values_curr = gen_candidates_scipy( |
277 | | - initial_conditions=batched_ics_, **scipy_kws |
| 279 | + batch_candidates_curr, batch_acq_values_curr = gen_candidates( |
| 280 | + initial_conditions=batched_ics_, **gen_kwargs |
278 | 281 | ) |
279 | 282 | opt_warnings += ws |
280 | 283 | batch_candidates_list.append(batch_candidates_curr) |
|
0 commit comments