Calculators fail to torch.jit.load #178

Open
frostedoyster opened this issue Feb 24, 2025 · 1 comment
Labels
bug Something isn't working help wanted Extra attention is needed

Comments

frostedoyster commented Feb 24, 2025

For example, this fails:

from torchpme.calculators import EwaldCalculator
from torchpme import InversePowerLawPotential
import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.calculator = EwaldCalculator(
            potential=InversePowerLawPotential(
                exponent=1,
                smearing=1.0,
                exclusion_radius=5.0,
            ),
            full_neighbor_list=True,
            lr_wavelength=5.0,
            prefactor=1.0,
        )

    def forward(self, x):
        return self.calculator(x, x, x, x, x)

tensor = torch.randn(10, 3)

model = Model()
model = torch.jit.script(model)
torch.jit.save(model, "model.pt")
loaded_model = torch.jit.load("model.pt")

Sorry for the unrealistic example, but the same also happens when feeding appropriate tensors to the model.

The error:

Traceback (most recent call last):
  File ".../program.py", line 28, in <module>
    torch.jit.save(model, "model.pt")
  File ".../lib/python3.12/site-packages/torch/jit/_serialization.py", line 84, in save
    m.save(f, _extra_files=_extra_files)
  File ".../lib/python3.12/site-packages/torch/jit/_script.py", line 754, in save
    return self._c.save(str(f), **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: 
Could not export Python function call '_CustomExp1'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
  File ".../lib/python3.12/site-packages/torchpme/lib/math.py", line 76
    :return: Exponential integral E1(x)
    """
    return _CustomExp1.apply(x)
           ~~~~~~~~~~~~~~~~~ <--- HERE

E-Rum commented Feb 24, 2025

Good catch, @frostedoyster, thank you very much! After a bit of research, I found that, unfortunately, TorchScript (torch.jit.script) does not support custom autograd functions such as _CustomExp1. We need to figure out how to replace it in this case.
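
The limitation can probably be reproduced without torchpme. The following is only a minimal sketch, not the torchpme code: `_CustomSquare`, `custom_square`, `WithCustomFunction`, `WithPlainOps` and `try_script_and_save` are hypothetical names introduced for illustration, and squaring is just a stand-in for the exponential integral. It compares a module that calls a custom `torch.autograd.Function` with one that uses plain tensor operations, and reports whether each survives a script, save and load round trip through an in-memory buffer.

```python
import io

import torch


class _CustomSquare(torch.autograd.Function):
    """Hypothetical stand-in for _CustomExp1: y = x**2 with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2.0 * x * grad_output


def custom_square(x: torch.Tensor) -> torch.Tensor:
    # Same call pattern as the line flagged in the traceback.
    return _CustomSquare.apply(x)


class WithCustomFunction(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return custom_square(x)


class WithPlainOps(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only built-in tensor ops; autograd derives the gradient itself.
        return x * x


def try_script_and_save(module: torch.nn.Module):
    """Return None if script + save + load succeeds, else the error message."""
    try:
        scripted = torch.jit.script(module)
        buffer = io.BytesIO()
        torch.jit.save(scripted, buffer)
        buffer.seek(0)
        torch.jit.load(buffer)
    except Exception as err:  # scripting or export may raise, depending on version
        return str(err)
    return None


for module in (WithCustomFunction(), WithPlainOps()):
    error = try_script_and_save(module)
    status = "ok" if error is None else "failed"
    print(f"{type(module).__name__}: {status}")
```

If this behaves like the traceback above, the first module reports a failure and the second round-trips cleanly. That would suggest one possible direction: expressing the function through operations TorchScript already understands, so that no Python-side `Function` has to be exported. Whether that is numerically practical for E1 is a separate question.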

@PicoCentauri PicoCentauri added bug Something isn't working help wanted Extra attention is needed labels Feb 24, 2025