Skip to content

Commit f758454

Browse files
authored
Merge pull request #264 from deepmodeling/feat/redesign-efficiency-tests
Feat: Add new efficiency test
2 parents 9228603 + c7e4b9c commit f758454

File tree

5 files changed

+172
-22
lines changed

5 files changed

+172
-22
lines changed

lambench/metrics/vishelper/metrics_calculations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def calculate_stability_results(self) -> dict[str, float]:
167167
}
168168

169169
stability_results = pd.DataFrame.from_dict(stability_results, orient="index")
170-
stability_results = stability_results.applymap(
170+
stability_results = stability_results.map(
171171
lambda cell: self._calculate_instability_error(cell)
172172
)
173173
# average over all systems

lambench/tasks/calculator/calculator_tasks.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ phonon_mdr:
99
calculator_params:
1010
distance: 0.01
1111
inference_efficiency:
12-
test_data: /bohr/lambench-efficiency-rg7a/v2/efficiency
12+
test_data: /bohr/lambench-efficiency-rg7a/v3/efficiency
1313
calculator_params:
14-
warmup_ratio: 0.2
14+
warmup_ratio: 0.1
1515
torsionnet:
1616
test_data: /bohr/lambench-torsionnet-e4sc/v2/torsionnet500_wB97m
1717
calculator_params: null
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from ase.atoms import Atoms
2+
from lambench.models.ase_models import ASEModel
3+
import numpy as np
4+
import math
5+
6+
7+
def get_efv(atoms: Atoms) -> tuple[float, np.ndarray, np.ndarray]:
8+
"""
9+
Perform force field prediction for one system, return energy, forces and stress.
10+
"""
11+
e = atoms.get_potential_energy()
12+
f = atoms.get_forces()
13+
stress = atoms.get_stress()
14+
v = (
15+
-np.array(
16+
[
17+
[stress[0], stress[5], stress[4]],
18+
[stress[5], stress[1], stress[3]],
19+
[stress[4], stress[3], stress[2]],
20+
]
21+
)
22+
* atoms.get_volume()
23+
)
24+
return e, f, v
25+
26+
27+
def catch_oom_error(atoms: Atoms) -> bool:
28+
"""
29+
Catch OOM error when running inference.
30+
"""
31+
try:
32+
get_efv(atoms)
33+
return False
34+
except Exception as e:
35+
if "out of memory" in str(e) or "OOM" in str(e):
36+
return True
37+
else:
38+
return False
39+
40+
41+
def get_divisors(num: int) -> list[int]:
42+
divisors = set()
43+
for i in range(1, int(math.isqrt(num)) + 1):
44+
if num % i == 0:
45+
divisors.add(i)
46+
divisors.add(num // i)
47+
return sorted(divisors)
48+
49+
50+
def find_even_factors(num: int) -> tuple[int, int, int]:
51+
"""
52+
Find three factors of a number that are as evenly distributed as possible.
53+
The function returns a tuple of three factors (a, b, c) such that a * b * c = num.
54+
The factors are sorted in ascending order (a <= b <= c).
55+
"""
56+
divisors = get_divisors(num)
57+
best = None
58+
min_spread = float("inf")
59+
60+
for a in divisors:
61+
num_div_a = num // a
62+
divisors_b = get_divisors(num_div_a)
63+
64+
# Since a <= b <= c, no need to consider b < a
65+
for b in divisors_b:
66+
if b < a:
67+
continue
68+
c = num_div_a // b
69+
if a * b * c == num:
70+
factors = [a, b, c]
71+
spread = max(factors) - min(factors)
72+
if spread < min_spread:
73+
min_spread = spread
74+
best = (a, b, c)
75+
if spread == 0: # Perfect distribution found
76+
return best
77+
return best
78+
79+
80+
def binary_search_max_natoms(
81+
model: ASEModel, atoms: Atoms, upper_limit: int = 1000, max_iterations: int = 15
82+
) -> int:
83+
"""
84+
Binary search for the maximum number of atoms that can be processed by the model.
85+
86+
"""
87+
low, high, iteration = 1, upper_limit, 0
88+
while low < high and iteration < max_iterations:
89+
mid = (low + high + 1) // 2
90+
scaling_factor = np.int32(np.ceil(mid / len(atoms)))
91+
scaled_atoms = atoms.copy()
92+
a, b, c = find_even_factors(scaling_factor)
93+
scaled_atoms = scaled_atoms.repeat((a, b, c))
94+
scaled_atoms.calc = model.calc
95+
if catch_oom_error(scaled_atoms):
96+
high = mid - 1
97+
else:
98+
low = mid
99+
iteration += 1
100+
return low

lambench/tasks/calculator/inference_efficiency/inference_efficiency.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from lambench.models.ase_models import ASEModel
2+
from lambench.tasks.calculator.inference_efficiency.efficiency_utils import (
3+
binary_search_max_natoms,
4+
get_efv,
5+
find_even_factors,
6+
)
27
from ase.io import read
3-
from ase.atoms import Atoms
48
import logging
59
import time
610
import numpy as np
@@ -11,23 +15,6 @@
1115
)
1216

1317

14-
def get_efv(atoms: Atoms) -> tuple[float, np.ndarray]:
15-
e = atoms.get_potential_energy()
16-
f = atoms.get_forces()
17-
stress = atoms.get_stress()
18-
v = (
19-
-np.array(
20-
[
21-
[stress[0], stress[5], stress[4]],
22-
[stress[5], stress[1], stress[3]],
23-
[stress[4], stress[3], stress[2]],
24-
]
25-
)
26-
* atoms.get_volume()
27-
)
28-
return e, f, v
29-
30-
3118
def run_inference(
3219
model: ASEModel, test_data: Path, warmup_ratio: float
3320
) -> dict[str, dict[str, float]]:
@@ -62,7 +49,9 @@ def run_inference(
6249

6350

6451
def run_one_inference(
65-
model: ASEModel, test_traj: Path, warmup_ratio: float
52+
model: ASEModel,
53+
test_traj: Path,
54+
warmup_ratio: float,
6655
) -> dict[str, float]:
6756
"""
6857
Infer for one trajectory, return averaged time and success rate, starting timing at warmup_ratio.
@@ -75,6 +64,14 @@ def run_one_inference(
7564

7665
efficiency = []
7766
for i, atoms in enumerate(test_atoms):
67+
# find maximum allowed natoms
68+
max_natoms = binary_search_max_natoms(model, atoms)
69+
# on-the-fly expand atoms
70+
scaling_factor = np.int32(np.floor(max_natoms / len(atoms)))
71+
while 1 in find_even_factors(scaling_factor) and scaling_factor > 1:
72+
scaling_factor -= 1
73+
a, b, c = find_even_factors(scaling_factor)
74+
atoms = atoms.repeat((a, b, c))
7875
atoms.calc = model.calc
7976
n_atoms = len(atoms)
8077
start = time.time()
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from lambench.tasks.calculator.inference_efficiency.efficiency_utils import (
2+
find_even_factors,
3+
binary_search_max_natoms,
4+
)
5+
import pytest
6+
import numpy as np
7+
from ase.atoms import Atoms
8+
from unittest.mock import MagicMock
9+
10+
OOM_TEST_ATOM = Atoms(
11+
symbols="Mg",
12+
pbc=True,
13+
cell=[
14+
[-2.244256, -2.244256, 0.0],
15+
[-2.244256, 0.0, -2.244256],
16+
[0.0, -2.244256, -2.244256],
17+
],
18+
positions=[
19+
[0, 0, 0],
20+
],
21+
) # mp-1056702
22+
23+
24+
@pytest.mark.parametrize(
25+
"num, expected",
26+
[
27+
(27, (3, 3, 3)), # Perfect cube
28+
(13, (1, 1, 13)), # Prime number
29+
(16, (2, 2, 4)), # Even number
30+
(728, (7, 8, 13)), # Large number
31+
],
32+
)
33+
def test_find_even_factors(num, expected):
34+
result = find_even_factors(num)
35+
assert result == expected, f"Expected {expected}, got {result}"
36+
37+
38+
@pytest.mark.parametrize(
39+
"threshold, max_natoms",
40+
[(1999, 1000), (247, 247), (121, 121), (100, 100), (38, 38), (31, 31)],
41+
)
42+
def test_binary_search_max_natoms(threshold, max_natoms):
43+
def mock_get_potential_energy(atoms=None):
44+
if len(atoms) > threshold:
45+
raise MemoryError("OOM: Too many atoms!")
46+
return np.random.rand()
47+
48+
mock_model = MagicMock()
49+
mock_model.calc = MagicMock()
50+
mock_model.calc.get_potential_energy.side_effect = mock_get_potential_energy
51+
52+
result = binary_search_max_natoms(mock_model, OOM_TEST_ATOM)
53+
assert result == max_natoms, f"Expected {max_natoms}, got {result}"

0 commit comments

Comments
 (0)