Caching destroys computational graph #62

Closed
PicoCentauri opened this issue Sep 25, 2024 · 6 comments · Fixed by #64

Comments

@PicoCentauri
Contributor

The caching recently introduced in #43 makes it impossible to call backward a second time on a PMEPotential instance when the cell does not change. The first call to backward on a calculator works fine:

import torch
from torchpme import PMEPotential

sr_cutoff = 5.5

calculator = PMEPotential(
    exponent=1.0,
    atomic_smearing=sr_cutoff / 5.0,
    mesh_spacing=None,
    interpolation_order=3,
    subtract_interior=False,
)

positions = 5 * torch.rand((100, 3), requires_grad=True)
cell = 5 * torch.eye(3, requires_grad=True)
charges = torch.rand((100, 1))
neighbor_indices = torch.zeros((0, 2), dtype=torch.int64)
neighbor_distances = torch.zeros((0))

output = calculator(positions, charges, cell, neighbor_indices, neighbor_distances)
output = output.sum()
output.backward()

But, if you reuse the calculator and compute the potential a second time

positions = 5 * torch.rand((100, 3), requires_grad=True)
cell = 5 * torch.eye(3, requires_grad=True)
charges = torch.rand((100, 1))
neighbor_indices = torch.zeros((0, 2), dtype=torch.int64)
neighbor_distances = torch.zeros((0))

output = calculator(positions, charges, cell, neighbor_indices, neighbor_distances)
output = output.sum()
output.backward()

You get a RuntimeError stating that

Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

This is caused by the caching: if one changes the cell to something like cell = 4 * torch.eye(3, requires_grad=True), the backward() call works without problems.
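
For illustration, a minimal sketch of that workaround, continuing from the snippets above (same calculator instance); this is illustrative, not something reported verbatim in the issue:

# Continuing from the snippets above: a cell with different dimensions is a
# cache miss, so this second backward() works.
positions = 5 * torch.rand((100, 3), requires_grad=True)
cell = 4 * torch.eye(3, requires_grad=True)  # different cell dimensions
charges = torch.rand((100, 1))
neighbor_indices = torch.zeros((0, 2), dtype=torch.int64)
neighbor_distances = torch.zeros((0))

output = calculator(positions, charges, cell, neighbor_indices, neighbor_distances)
output.sum().backward()  # no RuntimeError, a new computational graph was built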

I am not sure, but this problem might also be related to solving #57.

Thanks to @Luthaf, who tracked this down.

@sirmarcel
Contributor

What happens if you set retain_graph=True?

@Luthaf
Contributor

Luthaf commented Sep 25, 2024

this might make the error go away, but might also give wrong results, since the new cell is not part of the computational graph and only the old one is
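
One way to probe this concern, sketched as a continuation of the snippets above; the fresh `checker` instance and the gradient inspection at the end are illustrative assumptions, not something verified in this issue:

# Hedged sketch: call a fresh calculator with two distinct cell tensors of the
# same dimensions, then inspect which one actually received gradients.
checker = PMEPotential(
    exponent=1.0,
    atomic_smearing=sr_cutoff / 5.0,
    mesh_spacing=None,
    interpolation_order=3,
    subtract_interior=False,
)

old_cell = 5 * torch.eye(3, requires_grad=True)
new_cell = 5 * torch.eye(3, requires_grad=True)  # same values, different tensor

out1 = checker(positions, charges, old_cell, neighbor_indices, neighbor_distances)
out1.sum().backward(retain_graph=True)

out2 = checker(positions, charges, new_cell, neighbor_indices, neighbor_distances)
out2.sum().backward(retain_graph=True)

print(old_cell.grad)  # if the cached graph still references old_cell, gradients accumulate here
print(new_cell.grad)  # may stay None if new_cell never entered the graph, i.e. the "wrong results" above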

@sirmarcel
Contributor

Yes, I wasn't thinking clearly. It seems that some more plumbing is required to make torch understand that the gradients have to be "forwarded" to the new instance. No idea how to do that, but you're the torch-🧙!

@PicoCentauri
Contributor Author

Yes, .backward(retain_graph=True) fixes the error. If the new cell is not in the graph, it might not be a problem: if the cell dimensions are the same, the values we get are the same...

I have to check if the forces are really correct, though. But if we run the code inside a metatensor atomistic model, we are not doing the gradients ourselves and can't keep the graph. Additionally, I think it is a bad experience to be faced with an error when you try to call backward.

I would argue to remove the caching or make it optional. What was the speedup you got from the caching, @ceriottm? We didn't report it in the PR.

If we make it optional, I would turn it off by default, and we would have to state that one has to keep the graph if one calls forward a second time. HOWEVER, we can only do this if the forces are correct, which is something to be tested.
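
A possible shape for that test, sketched as a continuation of the snippets above; the compute_forces helper and the fresh-instance comparison are illustrative assumptions, not existing torch-pme API:

# Hedged test sketch: position gradients (forces, up to a sign) from a calculator
# that reuses its cache should match those from a freshly constructed calculator.
def compute_forces(calc, positions, charges, cell):
    positions = positions.clone().detach().requires_grad_(True)
    cell = cell.detach()  # no cell gradients here; we only compare position gradients
    out = calc(positions, charges, cell, neighbor_indices, neighbor_distances)
    out.sum().backward()
    return positions.grad

forces_cached = compute_forces(calculator, positions, charges, cell)  # calculator was already used above
fresh_calculator = PMEPotential(
    exponent=1.0,
    atomic_smearing=sr_cutoff / 5.0,
    mesh_spacing=None,
    interpolation_order=3,
    subtract_interior=False,
)
forces_fresh = compute_forces(fresh_calculator, positions, charges, cell)
torch.testing.assert_close(forces_cached, forces_fresh)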

@PicoCentauri
Contributor Author

I tried torch.autograd.gradcheck to check if the gradients are correct, but I am not sure if the function keeps the graph alive.
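
For reference, a minimal sketch of such a gradcheck, continuing with the tensors defined above and assuming the calculator accepts float64 inputs (gradcheck is numerically fragile in float32):

# Hedged sketch: gradcheck evaluates the wrapped function many times, so a stale
# cached graph (or wrong cached values) should surface as a failure here.
pos64 = positions.detach().double().requires_grad_(True)
cell64 = cell.detach().double().requires_grad_(True)
charges64 = charges.double()
neighbor_distances64 = neighbor_distances.double()

def potential_sum(pos, cell):
    out = calculator(pos, charges64, cell, neighbor_indices, neighbor_distances64)
    return out.sum()

torch.autograd.gradcheck(potential_sum, (pos64, cell64))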

It would also be better for the tests to have some structures with the same cell but different positions, and consequently different forces.

@PicoCentauri
Contributor Author

retain_graph=True seems to work, and the error only appears if one also requires gradients on the cell.
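
A minimal illustration of that observation, continuing with the setup from the first snippet (illustrative sketch, not taken from the issue):

# With requires_grad left off the cell, repeated forward/backward on the same
# calculator runs without the RuntimeError (only position gradients are requested).
cell_no_grad = 5 * torch.eye(3)
for _ in range(2):
    positions = 5 * torch.rand((100, 3), requires_grad=True)
    output = calculator(positions, charges, cell_no_grad, neighbor_indices, neighbor_distances)
    output.sum().backward()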

I vote for making it optional. Other opinions?
