1616
1717
1818def run_inference (
19- model : ASEModel , test_data : Path , warmup_ratio : float
19+ model : ASEModel , test_data : Path , warmup_ratio : float , natoms_upper_limit : int
2020) -> dict [str , dict [str , float ]]:
2121 """
2222 Inference for all trajectories, return average time and success rate for each system.
@@ -26,7 +26,9 @@ def run_inference(
2626 for traj in trajs :
2727 system_name = traj .name
2828 try :
29- system_result = run_one_inference (model , traj , warmup_ratio )
29+ system_result = run_one_inference (
30+ model , traj , warmup_ratio , natoms_upper_limit
31+ )
3032 average_time = system_result ["average_time" ]
3133 std_time = system_result ["std_time" ]
3234 success_rate = system_result ["success_rate" ]
@@ -52,6 +54,7 @@ def run_one_inference(
5254 model : ASEModel ,
5355 test_traj : Path ,
5456 warmup_ratio : float ,
57+ natoms_upper_limit : int ,
5558) -> dict [str , float ]:
5659 """
5760 Infer for one trajectory, return averaged time and success rate, starting timing at warmup_ratio.
@@ -65,7 +68,7 @@ def run_one_inference(
6568 efficiency = []
6669 for i , atoms in enumerate (test_atoms ):
6770 # find maximum allowed natoms
68- max_natoms = binary_search_max_natoms (model , atoms )
71+ max_natoms = binary_search_max_natoms (model , atoms , natoms_upper_limit )
6972 # on-the-fly expand atoms
7073 scaling_factor = np .int32 (np .floor (max_natoms / len (atoms )))
7174 while 1 in find_even_factors (scaling_factor ) and scaling_factor > 1 :
0 commit comments