@@ -324,6 +324,37 @@ def test_gen_candidates_torch_timeout_behavior(self):
        self.assertFalse(any(issubclass(w.category, OptimizationWarning) for w in ws))
        self.assertTrue("Optimization timed out" in logs.output[-1])
+    def test_gen_candidates_torch_optimizer_with_optimizer_args(self):
+        """Test that the optimizer is created with the given optimizer_options."""
+        self._setUp(double=False)
+        qEI = qExpectedImprovement(self.model, best_f=self.f_best)
+
+        # Create a mock optimizer class
+        mock_optimizer_class = mock.MagicMock()
+        mock_optimizer_instance = mock.MagicMock()
+        mock_optimizer_class.return_value = mock_optimizer_instance
+
+        gen_candidates_torch(
+            initial_conditions=self.initial_conditions,
+            acquisition_function=qEI,
+            lower_bounds=0,
+            upper_bounds=1,
+            optimizer=mock_optimizer_class,  # Pass the mock optimizer directly
+            options={
+                "optimizer_options": {"lr": 0.02, "weight_decay": 1e-5},
+                "stopping_criterion_options": {"maxiter": 1},
+            },
+        )
+
+        # Verify that the optimizer was called with the correct arguments
+        mock_optimizer_class.assert_called_once()
+        call_args = mock_optimizer_class.call_args
+        # Check that the params argument is present
+        self.assertIn("params", call_args.kwargs)
+        # Check the optimizer options
+        self.assertEqual(call_args.kwargs["lr"], 0.02)
+        self.assertEqual(call_args.kwargs["weight_decay"], 1e-5)
+
    def test_gen_candidates_scipy_warns_opt_no_res(self):
        ckwargs = {"dtype": torch.float, "device": self.device}
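
For context, the assertions on call_args.kwargs only hold if gen_candidates_torch instantiates the supplied optimizer class with the candidates passed as a params keyword argument and the entries of optimizer_options forwarded as keyword arguments. The standalone sketch below illustrates that assumed call pattern and how mock call_args.kwargs exposes it; clamped_candidates is a hypothetical name, and this is not the actual gen_candidates_torch implementation.

from unittest import mock

import torch

# Assumed call pattern (illustrative only): the optimizer class is called with
# the candidate tensor as `params` plus the user-supplied optimizer_options.
mock_optimizer_class = mock.MagicMock()
clamped_candidates = torch.zeros(1, 2, requires_grad=True)  # hypothetical candidate tensor
optimizer_options = {"lr": 0.02, "weight_decay": 1e-5}

mock_optimizer_class(params=[clamped_candidates], **optimizer_options)

# call_args.kwargs now contains exactly what the test above asserts on.
assert "params" in mock_optimizer_class.call_args.kwargs
assert mock_optimizer_class.call_args.kwargs["lr"] == 0.02
assert mock_optimizer_class.call_args.kwargs["weight_decay"] == 1e-5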