@@ -96,7 +96,7 @@ def __str__(self):
96
96
return output
97
97
98
98
99
- def solve_main (objfun , x0 , args , xl , xu , npt , rhobeg , rhoend , maxfun , nruns_so_far , nf_so_far , nx_so_far , nsamples , params ,
99
+ def solve_main (objfun , x0 , args , xl , xu , projections , npt , rhobeg , rhoend , maxfun , nruns_so_far , nf_so_far , nx_so_far , nsamples , params ,
100
100
diagnostic_info , scaling_changes , f0_avg_old = None , f0_nsamples_old = None , do_logging = True , print_progress = False ):
101
101
# Evaluate at x0 (keep nf, nx correct and check for f small)
102
102
if f0_avg_old is None :
@@ -144,7 +144,7 @@ def solve_main(objfun, x0, args, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_f
144
144
nx = nx_so_far
145
145
146
146
# Initialise controller
147
- control = Controller (objfun , x0 , args , f0_avg , num_samples_run , xl , xu , npt , rhobeg , rhoend , nf , nx , maxfun , params , scaling_changes , do_logging = do_logging )
147
+ control = Controller (objfun , x0 , args , f0_avg , num_samples_run , xl , xu , projections , npt , rhobeg , rhoend , nf , nx , maxfun , params , scaling_changes , do_logging = do_logging )
148
148
149
149
# Initialise interpolation set
150
150
number_of_samples = max (nsamples (control .delta , control .rho , 0 , nruns_so_far ), 1 )
@@ -665,7 +665,7 @@ def solve_main(objfun, x0, args, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_f
665
665
return x , f , gradmin , hessmin , nsamples , control .nf , control .nx , nruns_so_far , exit_info , diagnostic_info
666
666
667
667
668
- def solve (objfun , x0 , args = (), bounds = None , npt = None , rhobeg = None , rhoend = 1e-8 , maxfun = None , nsamples = None , user_params = None ,
668
+ def solve (objfun , x0 , args = (), bounds = None , projections = None , npt = None , rhobeg = None , rhoend = 1e-8 , maxfun = None , nsamples = None , user_params = None ,
669
669
objfun_has_noise = False , seek_global_minimum = False , scaling_within_bounds = False , do_logging = True , print_progress = False ):
670
670
n = len (x0 )
671
671
if type (x0 ) == list :
@@ -694,7 +694,11 @@ def solve(objfun, x0, args=(), bounds=None, npt=None, rhobeg=None, rhoend=1e-8,
694
694
if (xl is None or xu is None ) and scaling_within_bounds :
695
695
scaling_within_bounds = False
696
696
warnings .warn ("Ignoring scaling_within_bounds=True for unconstrained problem/1-sided bounds" , RuntimeWarning )
697
-
697
+
698
+ if (projections is not None ) and scaling_within_bounds :
699
+ scaling_within_bounds = False
700
+ warnings .warn ("Ignoring scaling_within_bounds=True for problems with projections given" , RuntimeWarning )
701
+
698
702
exit_info = None
699
703
if seek_global_minimum and (xl is None or xu is None ):
700
704
exit_info = ExitInformation (EXIT_INPUT_ERROR , "If seeking global minimum, must specify upper and lower bounds" )
@@ -761,6 +765,9 @@ def solve(objfun, x0, args=(), bounds=None, npt=None, rhobeg=None, rhoend=1e-8,
761
765
if exit_info is None and np .min (xu - xl ) < 2.0 * rhobeg :
762
766
exit_info = ExitInformation (EXIT_INPUT_ERROR , "gap between lower and upper must be at least 2*rhobeg" )
763
767
768
+ if exit_info is None and projections is not None and type (projections ) != list :
769
+ exit_info = ExitInformation (EXIT_INPUT_ERROR , "projections must be a list of functions" )
770
+
764
771
if maxfun <= npt :
765
772
warnings .warn ("maxfun <= npt: Are you sure your budget is large enough?" , RuntimeWarning )
766
773
@@ -792,12 +799,12 @@ def solve(objfun, x0, args=(), bounds=None, npt=None, rhobeg=None, rhoend=1e-8,
792
799
return results
793
800
794
801
# Enforce lower & upper bounds on x0
795
- idx = (x0 <= xl )
802
+ idx = (x0 < xl )
796
803
if np .any (idx ):
797
804
warnings .warn ("x0 below lower bound, adjusting" , RuntimeWarning )
798
805
x0 [idx ] = xl [idx ]
799
806
800
- idx = (x0 >= xu )
807
+ idx = (x0 > xu )
801
808
if np .any (idx ):
802
809
warnings .warn ("x0 above upper bound, adjusting" , RuntimeWarning )
803
810
x0 [idx ] = xu [idx ]
@@ -808,7 +815,7 @@ def solve(objfun, x0, args=(), bounds=None, npt=None, rhobeg=None, rhoend=1e-8,
808
815
nf = 0
809
816
nx = 0
810
817
xmin , fmin , gradmin , hessmin , nsamples_min , nf , nx , nruns , exit_info , diagnostic_info = \
811
- solve_main (objfun , x0 , args , xl , xu , npt , rhobeg , rhoend , maxfun , nruns , nf , nx , nsamples , params ,
818
+ solve_main (objfun , x0 , args , xl , xu , projections , npt , rhobeg , rhoend , maxfun , nruns , nf , nx , nsamples , params ,
812
819
diagnostic_info , scaling_changes , do_logging = do_logging , print_progress = print_progress )
813
820
814
821
# Hard restarts loop
@@ -829,11 +836,11 @@ def solve(objfun, x0, args=(), bounds=None, npt=None, rhobeg=None, rhoend=1e-8,
829
836
% (fmin , nf , _rhobeg , _rhoend ))
830
837
if params ("restarts.hard.use_old_fk" ):
831
838
xmin2 , fmin2 , gradmin2 , hessmin2 , nsamples2 , nf , nx , nruns , exit_info , diagnostic_info = \
832
- solve_main (objfun , xmin , args , xl , xu , npt , _rhobeg , _rhoend , maxfun , nruns , nf , nx , nsamples , params ,
839
+ solve_main (objfun , xmin , args , xl , xu , projections , npt , _rhobeg , _rhoend , maxfun , nruns , nf , nx , nsamples , params ,
833
840
diagnostic_info , scaling_changes , f0_avg_old = fmin , f0_nsamples_old = nsamples_min , do_logging = do_logging , print_progress = print_progress )
834
841
else :
835
842
xmin2 , fmin2 , gradmin2 , hessmin2 , nsamples2 , nf , nx , nruns , exit_info , diagnostic_info = \
836
- solve_main (objfun , xmin , args , xl , xu , npt , _rhobeg , _rhoend , maxfun , nruns , nf , nx , nsamples , params ,
843
+ solve_main (objfun , xmin , args , xl , xu , projections , npt , _rhobeg , _rhoend , maxfun , nruns , nf , nx , nsamples , params ,
837
844
diagnostic_info , scaling_changes , do_logging = do_logging , print_progress = print_progress )
838
845
839
846
if fmin2 < fmin or np .isnan (fmin ):
0 commit comments