-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add benchmark scripts #173
base: main
Are you sure you want to change the base?
Conversation
Okay, here we go again. Everyone happy? Should we mention this in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks solid.
Yes I would link in the CONTRIBUTING and also in PR checklist
https://github.com/lab-cosmo/torch-pme/blob/main/.github/pull_request_template.md
We should also maybe add a section in the docs for people who are interested. I can do this.
We could even run this as a gallery example - to verify if we break something and also to have this nice looking code somewhere on the web page.
``benchmark``: performance sanity checking | ||
========================================== | ||
|
||
This provides a basic benchmarking script intended to catch major performance regressions (or improvements!). It emulates a "training"-style workload of computing energy and forces for a number of similar systems with pre-computed neighborlists and settings. In particular, we run PME forward and backward calculations for supercells of cubic CsCl crystals, in different sizes. Results are stored as a time-stamped ``.yaml`` file, together with some basic system and version information (which is just the output of ``git`` for now). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We run PME or P3M? I think the latter is preferred.
You can take the version of torch-pme. Thanks to setuptools-scm the version will contain the version and the commit!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could just for consistency in the project use a .xyz
. But up to you.
@@ -0,0 +1,394 @@ | |||
#!/usr/bin/env python |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#!/usr/bin/env python | |
#!/usr/bin/env python3 |
devices = [] | ||
if torch.cuda.is_available(): | ||
devices.append("cuda") | ||
|
||
# run CUDA first! | ||
devices.append("cpu") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
devices = [] | |
if torch.cuda.is_available(): | |
devices.append("cuda") | |
# run CUDA first! | |
devices.append("cpu") | |
devices = ["cpu"] | |
# run CUDA first! | |
if torch.cuda.is_available(): | |
devices.insert(0, "cuda") |
# tight settings | ||
cutoff_tight = half_cell | ||
mesh_spacing_tight = atomic_smearing / 8.0 | ||
lr_wavelength_tight = atomic_smearing / 2.0 | ||
|
||
# light settings, roughly 4 times less computational costs | ||
cutoff_light = cutoff_tight / 2.0 | ||
mesh_spacing_light = mesh_spacing_tight * 2.0 | ||
lr_wavelength_light = lr_wavelength_tight * 2.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we use the tuning code with 10^-2 and maybe 10^-4. Maybe not now since the code seems still a bit fragile.... @GardevoirX
cwd = os.path.dirname(os.path.abspath(__file__)) | ||
try: | ||
torch_pme_commit = subprocess.check_output( | ||
["git", "log", "--oneline", "-1"], cwd=cwd | ||
).decode("utf-8") | ||
torch_pme_status = subprocess.check_output( | ||
["git", "status", "--porcelain"], cwd=cwd | ||
).decode("utf-8") | ||
except subprocess.CalledProcessError: | ||
torch_pme_commit = "not found" | ||
torch_pme_status = "not found" | ||
|
||
version = { | ||
"torch": str(torch.__version__), | ||
"torch-pme-commit": torch_pme_commit, | ||
"torch-pme-status": torch_pme_status, | ||
"torch-pme-version": torchpme.__version__, | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be enough
cwd = os.path.dirname(os.path.abspath(__file__)) | |
try: | |
torch_pme_commit = subprocess.check_output( | |
["git", "log", "--oneline", "-1"], cwd=cwd | |
).decode("utf-8") | |
torch_pme_status = subprocess.check_output( | |
["git", "status", "--porcelain"], cwd=cwd | |
).decode("utf-8") | |
except subprocess.CalledProcessError: | |
torch_pme_commit = "not found" | |
torch_pme_status = "not found" | |
version = { | |
"torch": str(torch.__version__), | |
"torch-pme-commit": torch_pme_commit, | |
"torch-pme-status": torch_pme_status, | |
"torch-pme-version": torchpme.__version__, | |
} | |
version = { | |
"torch": str(torch.__version__), | |
"torch-pme": str(torchpme.__version__), | |
} |
def compute_distances(positions, neighbor_indices, cell=None, neighbor_shifts=None): | ||
"""Compute pairwise distances.""" | ||
atom_is = neighbor_indices[:, 0] | ||
atom_js = neighbor_indices[:, 1] | ||
|
||
pos_is = positions[atom_is] | ||
pos_js = positions[atom_js] | ||
|
||
distance_vectors = pos_js - pos_is | ||
|
||
if cell is not None and neighbor_shifts is not None: | ||
shifts = neighbor_shifts.type(cell.dtype) | ||
distance_vectors += shifts @ cell | ||
elif cell is not None and neighbor_shifts is None: | ||
raise ValueError("Provided `cell` but no `neighbor_shifts`.") | ||
elif cell is None and neighbor_shifts is not None: | ||
raise ValueError("Provided `neighbor_shifts` but no `cell`.") | ||
|
||
return torch.linalg.norm(distance_vectors, dim=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason not to use external code or reusing the function we already defined in the testsuite? I am not a big fan of having duplicated code in the repo...
I am even okay moving the distance function as private function in the main repo, if you are afraid of abusing sys.path.append
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add a comment that we run only single precision (because we showed in the paper that it makes no difference!)
return (system, neighbors) | ||
|
||
|
||
def get_calculate_fn(calculator): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we using JAX?
Fixes #17
A continuation of #81 -- see discussion there.
This adds a benchmark script to catch performance regressions.
Contributor (creator of pull-request) checklist
Reviewer checklist
📚 Documentation preview 📚: https://torch-pme--173.org.readthedocs.build/en/173/