diff --git a/lambench/tasks/calculator/calculator_tasks.yml b/lambench/tasks/calculator/calculator_tasks.yml index 0b62b7b..718e3b3 100644 --- a/lambench/tasks/calculator/calculator_tasks.yml +++ b/lambench/tasks/calculator/calculator_tasks.yml @@ -12,6 +12,7 @@ inference_efficiency: test_data: /bohr/lambench-efficiency-rg7a/v3/efficiency calculator_params: warmup_ratio: 0.1 + natoms_upper_limit: 850 torsionnet: test_data: /bohr/lambench-torsionnet-e4sc/v2/torsionnet500_wB97m calculator_params: null diff --git a/lambench/tasks/calculator/inference_efficiency/inference_efficiency.py b/lambench/tasks/calculator/inference_efficiency/inference_efficiency.py index ff8871a..3cd52b4 100644 --- a/lambench/tasks/calculator/inference_efficiency/inference_efficiency.py +++ b/lambench/tasks/calculator/inference_efficiency/inference_efficiency.py @@ -16,7 +16,7 @@ def run_inference( - model: ASEModel, test_data: Path, warmup_ratio: float + model: ASEModel, test_data: Path, warmup_ratio: float, natoms_upper_limit: int ) -> dict[str, dict[str, float]]: """ Inference for all trajectories, return average time and success rate for each system. @@ -26,7 +26,9 @@ def run_inference( for traj in trajs: system_name = traj.name try: - system_result = run_one_inference(model, traj, warmup_ratio) + system_result = run_one_inference( + model, traj, warmup_ratio, natoms_upper_limit + ) average_time = system_result["average_time"] std_time = system_result["std_time"] success_rate = system_result["success_rate"] @@ -52,6 +54,7 @@ def run_one_inference( model: ASEModel, test_traj: Path, warmup_ratio: float, + natoms_upper_limit: int, ) -> dict[str, float]: """ Infer for one trajectory, return averaged time and success rate, starting timing at warmup_ratio. @@ -65,7 +68,7 @@ def run_one_inference( efficiency = [] for i, atoms in enumerate(test_atoms): # find maximum allowed natoms - max_natoms = binary_search_max_natoms(model, atoms) + max_natoms = binary_search_max_natoms(model, atoms, natoms_upper_limit) # on-the-fly expand atoms scaling_factor = np.int32(np.floor(max_natoms / len(atoms))) while 1 in find_even_factors(scaling_factor) and scaling_factor > 1: