@@ -532,7 +532,9 @@ def gen_candidates_torch(
         optimizer (Optimizer): The pytorch optimizer to use to perform
             candidate search.
         options: Options used to control the optimization. Includes
-            maxiter: Maximum number of iterations
+            optimizer_options: Dict of additional options to pass to the optimizer
+                (e.g. lr, weight_decay)
+            stopping_criterion_options: Dict of options for the stopping criterion.
         callback: A callback function accepting the current iteration, loss,
             and gradients as arguments. This function is executed after computing
             the loss and gradients, but before calling the optimizer.
@@ -559,11 +561,11 @@ def gen_candidates_torch(
         >>>     qEI, bounds, q=3, num_restarts=25, raw_samples=500
         >>> )
         >>> batch_candidates, batch_acq_values = gen_candidates_torch(
-                  initial_conditions=Xinit,
-                  acquisition_function=qEI,
-                  lower_bounds=bounds[0],
-                  upper_bounds=bounds[1],
-              )
+               initial_conditions=Xinit,
+               acquisition_function=qEI,
+               lower_bounds=bounds[0],
+               upper_bounds=bounds[1],
+           )
     """
     start_time = time.monotonic()
     options = options or {}
@@ -580,11 +582,17 @@ def gen_candidates_torch(
             [i for i in range(clamped_candidates.shape[-1]) if i not in fixed_features],
         ]
     clamped_candidates = clamped_candidates.requires_grad_(True)
-    _optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))
+
+    # Extract optimizer-specific options from the options dict
+    optimizer_options = options.pop("optimizer_options", {})
+    stopping_criterion_options = options.pop("stopping_criterion_options", {})
+
+    optimizer_options["lr"] = optimizer_options.get("lr", 0.025)
+    _optimizer = optimizer(params=[clamped_candidates], **optimizer_options)
 
     i = 0
     stop = False
-    stopping_criterion = ExpMAStoppingCriterion(**options)
+    stopping_criterion = ExpMAStoppingCriterion(**stopping_criterion_options)
     while not stop:
         i += 1
         with torch.no_grad():
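
For reference, a minimal usage sketch of the nested options layout this patch introduces: keys under `optimizer_options` (e.g. `lr`, `weight_decay`) are forwarded to the torch optimizer, and keys under `stopping_criterion_options` (e.g. `maxiter`) to `ExpMAStoppingCriterion`. The surrogate-model setup (`SingleTaskGP` on toy data) and the specific option values are illustrative assumptions, not part of the patch.

```python
import torch
from botorch.acquisition import qExpectedImprovement
from botorch.generation.gen import gen_candidates_torch
from botorch.models import SingleTaskGP
from botorch.optim.initializers import gen_batch_initial_conditions

# Toy 1-d training data and an (unfitted) GP surrogate for illustration.
train_X = torch.rand(10, 1, dtype=torch.double)
train_Y = torch.sin(6 * train_X)
model = SingleTaskGP(train_X, train_Y)

qEI = qExpectedImprovement(model, best_f=0.2)
bounds = torch.tensor([[0.0], [1.0]], dtype=torch.double)
Xinit = gen_batch_initial_conditions(
    qEI, bounds, q=3, num_restarts=25, raw_samples=500
)

# New nested options: optimizer kwargs go to the torch optimizer (Adam by
# default), stopping-criterion kwargs go to ExpMAStoppingCriterion.
batch_candidates, batch_acq_values = gen_candidates_torch(
    initial_conditions=Xinit,
    acquisition_function=qEI,
    lower_bounds=bounds[0],
    upper_bounds=bounds[1],
    options={
        "optimizer_options": {"lr": 0.05, "weight_decay": 1e-4},
        "stopping_criterion_options": {"maxiter": 100},
    },
)
```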