14 changes: 11 additions & 3 deletions botorch/generation/gen.py
@@ -532,7 +532,9 @@ def gen_candidates_torch(
optimizer (Optimizer): The pytorch optimizer to use to perform
candidate search.
options: Options used to control the optimization. Includes
maxiter: Maximum number of iterations
optimizer_options: Dict of additional options to pass to the optimizer
(e.g. lr, weight_decay)
stopping_criterion_options: Dict of options for the stopping criterion.
callback: A callback function accepting the current iteration, loss,
and gradients as arguments. This function is executed after computing
the loss and gradients, but before calling the optimizer.
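For illustration only (the particular values below are made up), the options dict documented above might look like this, with optimizer_options forwarded to the optimizer constructor and stopping_criterion_options to ExpMAStoppingCriterion:

options = {
    "optimizer_options": {"lr": 0.01, "weight_decay": 1e-4},
    "stopping_criterion_options": {"maxiter": 100},
}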
@@ -580,11 +582,17 @@ def gen_candidates_torch(
[i for i in range(clamped_candidates.shape[-1]) if i not in fixed_features],
]
clamped_candidates = clamped_candidates.requires_grad_(True)
_optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))

# Extract optimizer-specific options from the options dict
optimizer_options = options.pop("optimizer_options", {})
stopping_criterion_options = options.pop("stopping_criterion_options", {})

optimizer_options["lr"] = optimizer_options.get("lr", 0.025)
_optimizer = optimizer(params=[clamped_candidates], **optimizer_options)

i = 0
stop = False
stopping_criterion = ExpMAStoppingCriterion(**options)
stopping_criterion = ExpMAStoppingCriterion(**stopping_criterion_options)
while not stop:
i += 1
with torch.no_grad():
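A minimal end-to-end sketch of how a caller might exercise the new option handling; the toy model and data below are assumptions for the example, not taken from this PR, and any torch optimizer class (here torch.optim.SGD) can be passed in place of the default Adam:

import torch
from botorch.acquisition import qExpectedImprovement
from botorch.generation.gen import gen_candidates_torch
from botorch.models import SingleTaskGP

# Toy model and acquisition function (illustrative only).
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
acq_func = qExpectedImprovement(model, best_f=train_Y.max())

# One starting point with q=1 and d=2.
initial_conditions = torch.rand(1, 1, 2, dtype=torch.double)

batch_candidates, batch_acq_values = gen_candidates_torch(
    initial_conditions=initial_conditions,
    acquisition_function=acq_func,
    lower_bounds=0.0,
    upper_bounds=1.0,
    optimizer=torch.optim.SGD,
    options={
        # forwarded to torch.optim.SGD(...)
        "optimizer_options": {"lr": 0.05, "momentum": 0.9},
        # forwarded to ExpMAStoppingCriterion(...)
        "stopping_criterion_options": {"maxiter": 50},
    },
)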
31 changes: 31 additions & 0 deletions test/generation/test_gen.py
@@ -324,6 +324,37 @@ def test_gen_candidates_torch_timeout_behavior(self):
self.assertFalse(any(issubclass(w.category, OptimizationWarning) for w in ws))
self.assertTrue("Optimization timed out" in logs.output[-1])

def test_gen_candidates_torch_optimizer_with_optimizer_args(self):
"""Test that Adam optimizer is created with the correct learning rate."""
self._setUp(double=False)
qEI = qExpectedImprovement(self.model, best_f=self.f_best)

# Create a mock optimizer class
mock_optimizer_class = mock.MagicMock()
mock_optimizer_instance = mock.MagicMock()
mock_optimizer_class.return_value = mock_optimizer_instance

gen_candidates_torch(
initial_conditions=self.initial_conditions,
acquisition_function=qEI,
lower_bounds=0,
upper_bounds=1,
optimizer=mock_optimizer_class, # Pass the mock optimizer directly
options={
"optimizer_options": {"lr": 0.02, "weight_decay": 1e-5},
"stopping_criterion_options": {"maxiter": 1},
},
)

# Verify that the optimizer was called with the correct arguments
mock_optimizer_class.assert_called_once()
call_args = mock_optimizer_class.call_args
# Check that params argument is present
self.assertIn("params", call_args.kwargs)
# Check optimizer options
self.assertEqual(call_args.kwargs["lr"], 0.02)
self.assertEqual(call_args.kwargs["weight_decay"], 1e-5)

def test_gen_candidates_scipy_warns_opt_no_res(self):
ckwargs = {"dtype": torch.float, "device": self.device}

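Not part of this PR, but a complementary check in the same spirit could patch ExpMAStoppingCriterion where gen.py resolves it and assert that stopping_criterion_options are forwarded to it. A hedged sketch, assuming the same test fixture (_setUp, initial_conditions, f_best) used above:

def test_gen_candidates_torch_stopping_criterion_options(self):
    """Sketch: stopping_criterion_options should reach ExpMAStoppingCriterion."""
    self._setUp(double=False)
    qEI = qExpectedImprovement(self.model, best_f=self.f_best)
    with mock.patch(
        "botorch.generation.gen.ExpMAStoppingCriterion"
    ) as mock_criterion_class:
        # The MagicMock instance is truthy wherever the loop consults it,
        # so the optimization stops after a single iteration.
        gen_candidates_torch(
            initial_conditions=self.initial_conditions,
            acquisition_function=qEI,
            lower_bounds=0,
            upper_bounds=1,
            options={"stopping_criterion_options": {"maxiter": 1}},
        )
    mock_criterion_class.assert_called_once_with(maxiter=1)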