
Commit 90e16b6

Carl Hvarfner authored and facebook-github-bot committed
Structuring arguments in gen_candidates_torch (#3019)
Summary: Structuring the optimizer and stopping_criterion arguments in gen_candidates_torch. Fixes #2994. Differential Revision: D82839737
1 parent fa3749c commit 90e16b6
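
For context, a minimal usage sketch of the restructured call. The toy SingleTaskGP model, data, and qExpectedImprovement setup below are illustrative assumptions, not part of this commit; the point is that optimizer settings and stopping-criterion settings now travel in separate nested dicts instead of one flat options dict.

# Minimal usage sketch (illustrative setup, not from the commit itself).
import torch
from botorch.models import SingleTaskGP
from botorch.acquisition import qExpectedImprovement
from botorch.generation.gen import gen_candidates_torch

# Toy training data in [0, 1]^2 with a noisy linear objective.
train_X = torch.rand(5, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True) + 0.05 * torch.randn(5, 1, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)
acq = qExpectedImprovement(model, best_f=train_Y.max())

candidates, acq_values = gen_candidates_torch(
    initial_conditions=torch.rand(1, 1, 2, dtype=torch.double),
    acquisition_function=acq,
    lower_bounds=0.0,
    upper_bounds=1.0,
    optimizer=torch.optim.Adam,
    options={
        # previously a single flat dict, e.g. {"lr": 0.025, "maxiter": 50}
        "optimizer_options": {"lr": 0.025, "weight_decay": 1e-5},
        "stopping_criterion_options": {"maxiter": 50},
    },
)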

2 files changed: 42 additions & 3 deletions


botorch/generation/gen.py

Lines changed: 11 additions & 3 deletions
@@ -532,7 +532,9 @@ def gen_candidates_torch(
         optimizer (Optimizer): The pytorch optimizer to use to perform
             candidate search.
         options: Options used to control the optimization. Includes
-            maxiter: Maximum number of iterations
+            optimizer_options: Dict of additional options to pass to the optimizer
+                (e.g. lr, weight_decay)
+            stopping_criterion_options: Dict of options for the stopping criterion.
         callback: A callback function accepting the current iteration, loss,
             and gradients as arguments. This function is executed after computing
             the loss and gradients, but before calling the optimizer.
@@ -580,11 +582,17 @@ def gen_candidates_torch(
             [i for i in range(clamped_candidates.shape[-1]) if i not in fixed_features],
         ]
     clamped_candidates = clamped_candidates.requires_grad_(True)
-    _optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))
+
+    # Extract optimizer-specific options from the options dict
+    optimizer_options = options.pop("optimizer_options", {})
+    stopping_criterion_options = options.pop("stopping_criterion_options", {})
+
+    optimizer_options["lr"] = optimizer_options.get("lr", 0.025)
+    _optimizer = optimizer(params=[clamped_candidates], **optimizer_options)

     i = 0
     stop = False
-    stopping_criterion = ExpMAStoppingCriterion(**options)
+    stopping_criterion = ExpMAStoppingCriterion(**stopping_criterion_options)
     while not stop:
         i += 1
         with torch.no_grad():
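
For clarity, a standalone sketch (not the module's code) of the option-splitting logic introduced above: the two nested dicts are popped from options, and the learning rate falls back to 0.025 when it is not supplied in optimizer_options.

# Standalone illustration of the splitting/defaulting behavior shown in the diff.
options = {
    "optimizer_options": {"weight_decay": 1e-5},
    "stopping_criterion_options": {"maxiter": 20},
}

optimizer_options = options.pop("optimizer_options", {})
stopping_criterion_options = options.pop("stopping_criterion_options", {})

# lr defaults to 0.025 unless set explicitly in optimizer_options.
optimizer_options["lr"] = optimizer_options.get("lr", 0.025)

print(optimizer_options)           # {'weight_decay': 1e-05, 'lr': 0.025}
print(stopping_criterion_options)  # {'maxiter': 20}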

test/generation/test_gen.py

Lines changed: 31 additions & 0 deletions
@@ -324,6 +324,37 @@ def test_gen_candidates_torch_timeout_behavior(self):
         self.assertFalse(any(issubclass(w.category, OptimizationWarning) for w in ws))
         self.assertTrue("Optimization timed out" in logs.output[-1])

+    def test_gen_candidates_torch_optimizer_with_optimizer_args(self):
+        """Test that Adam optimizer is created with the correct learning rate."""
+        self._setUp(double=False)
+        qEI = qExpectedImprovement(self.model, best_f=self.f_best)
+
+        # Create a mock optimizer class
+        mock_optimizer_class = mock.MagicMock()
+        mock_optimizer_instance = mock.MagicMock()
+        mock_optimizer_class.return_value = mock_optimizer_instance
+
+        gen_candidates_torch(
+            initial_conditions=self.initial_conditions,
+            acquisition_function=qEI,
+            lower_bounds=0,
+            upper_bounds=1,
+            optimizer=mock_optimizer_class,  # Pass the mock optimizer directly
+            options={
+                "optimizer_options": {"lr": 0.02, "weight_decay": 1e-5},
+                "stopping_criterion_options": {"maxiter": 1},
+            },
+        )
+
+        # Verify that the optimizer was called with the correct arguments
+        mock_optimizer_class.assert_called_once()
+        call_args = mock_optimizer_class.call_args
+        # Check that params argument is present
+        self.assertIn("params", call_args.kwargs)
+        # Check optimizer options
+        self.assertEqual(call_args.kwargs["lr"], 0.02)
+        self.assertEqual(call_args.kwargs["weight_decay"], 1e-5)
+
     def test_gen_candidates_scipy_warns_opt_no_res(self):
         ckwargs = {"dtype": torch.float, "device": self.device}
