Merge pull request #128 from lab-cosmo/exponents
Add general integer exponents
E-Rum authored Jan 17, 2025
2 parents e5b1ce8 + 0f8d7e5 commit 3540a81
Showing 9 changed files with 162 additions and 51 deletions.
2 changes: 2 additions & 0 deletions docs/src/references/changelog.rst
@@ -32,6 +32,8 @@ Added
Fixed
#####

* Refactor the ``InversePowerLawPotential`` class to restrict the exponent to integer
values
* Ensured consistency of ``dtype`` and ``device`` in the ``Potential`` and
``Calculator`` classes
* Fixed consistency of ``dtype`` and ``device`` in the ``SplinePotential`` class
4 changes: 2 additions & 2 deletions examples/8-combined-potential.py
@@ -65,8 +65,8 @@
# evaluation, and so one has to set it also for the combined potential, even if it is
# not used explicitly in the evaluation of the combination.

pot_1 = InversePowerLawPotential(exponent=1.0, smearing=smearing)
pot_2 = InversePowerLawPotential(exponent=2.0, smearing=smearing)
pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing)
pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing)

potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing)

1 change: 0 additions & 1 deletion src/torchpme/calculators/ewald.py
Expand Up @@ -138,6 +138,5 @@ def _compute_kspace(
charge_tot = torch.sum(charges, dim=0)
prefac = self.potential.background_correction()
energy -= 2 * prefac * charge_tot * ivolume

# Compensate for double counting of pairs (i,j) and (j,i)
return energy / 2
5 changes: 5 additions & 0 deletions src/torchpme/lib/__init__.py
@@ -4,6 +4,7 @@
generate_kvectors_for_mesh,
get_ns_mesh,
)
from .math import CustomExp1, gamma, gammaincc_over_powerlaw, torch_exp1
from .mesh_interpolator import MeshInterpolator

__all__ = [
@@ -16,4 +17,8 @@
"generate_kvectors_for_ewald",
"generate_kvectors_for_mesh",
"get_ns_mesh",
"gamma",
"CustomExp1",
"gammaincc_over_powerlaw",
"torch_exp1",
]
56 changes: 56 additions & 0 deletions src/torchpme/lib/math.py
@@ -0,0 +1,56 @@
import torch
from scipy.special import exp1
from torch.special import gammaln


def gamma(x: torch.Tensor) -> torch.Tensor:
"""
(Complete) Gamma function.

PyTorch does not implement the commonly used (complete) Gamma function, so we
define it here via ``gammaln`` to keep autograd support, as suggested in
https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122
"""
return torch.exp(gammaln(x))


class CustomExp1(torch.autograd.Function):
"""Custom exponential integral function Exp1(x) to have an autograd-compatible version."""

@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
input_numpy = input.cpu().numpy() if not input.is_cpu else input.numpy()
return torch.tensor(exp1(input_numpy), device=input.device, dtype=input.dtype)

@staticmethod
def backward(ctx, grad_output):
(input,) = ctx.saved_tensors
return -grad_output * torch.exp(-input) / input


def torch_exp1(input):
"""Wrapper for the custom exponential integral function."""
return CustomExp1.apply(input)


def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
"""Function to compute the regularized incomplete gamma function complement for integer exponents."""
if exponent == 1:
return torch.exp(-z) / z
if exponent == 2:
return torch.sqrt(torch.pi / z) * torch.erfc(torch.sqrt(z))
if exponent == 3:
return torch_exp1(z)
if exponent == 4:
return 2 * (
torch.exp(-z) - torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z))
)
if exponent == 5:
return torch.exp(-z) - z * torch_exp1(z)
if exponent == 6:
return (
(2 - 4 * z) * torch.exp(-z)
+ 4 * torch.sqrt(torch.pi * z**3) * torch.erfc(torch.sqrt(z))
) / 3
raise ValueError(f"Unsupported exponent: {exponent}")
31 changes: 12 additions & 19 deletions src/torchpme/potentials/inversepowerlaw.py
@@ -1,20 +1,11 @@
from typing import Optional

import torch
from torch.special import gammainc, gammaincc, gammaln
from torch.special import gammainc

from .potential import Potential


def gamma(x: torch.Tensor) -> torch.Tensor:
"""
(Complete) Gamma function.
from torchpme.lib import gamma, gammaincc_over_powerlaw

pytorch has not implemented the commonly used (complete) Gamma function. We define
it in a custom way to make autograd work as in
https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122
"""
return torch.exp(gammaln(x))
from .potential import Potential


class InversePowerLawPotential(Potential):
@@ -46,16 +37,16 @@ class InversePowerLawPotential(Potential):

def __init__(
self,
exponent: float,
exponent: int,
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
super().__init__(smearing, exclusion_radius, dtype, device)

if exponent <= 0 or exponent > 3:
raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p <= 3")
# validate the exponent: gammaincc_over_powerlaw raises a ValueError for unsupported values
gammaincc_over_powerlaw(exponent, torch.tensor(1.0, dtype=dtype, device=device))
self.register_buffer(
"exponent", torch.tensor(exponent, dtype=self.dtype, device=self.device)
)
@@ -130,9 +121,7 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
# for consistency reasons.
masked = torch.where(x == 0, 1.0, x) # avoid NaNs in backwards, see Coulomb
return torch.where(
k_sq == 0,
0.0,
prefac * gammaincc(peff, masked) / masked**peff * gamma(peff),
k_sq == 0, 0.0, prefac * gammaincc_over_powerlaw(exponent, masked)
)

def self_contribution(self) -> torch.Tensor:
@@ -145,7 +134,11 @@ def self_contribution(self) -> torch.Tensor:
return 1 / gamma(phalf + 1) / (2 * self.smearing**2) ** phalf

def background_correction(self) -> torch.Tensor:
# "charge neutrality" correction for 1/r^p potential
# "charge neutrality" correction for 1/r^p potential diverges for exponent p = 3
# and is not needed for p > 3 , so we set it to zero (see in
# https://doi.org/10.48550/arXiv.2412.03281 SI section)
if self.exponent >= 3:
return torch.tensor(0.0, dtype=self.dtype, device=self.device)
if self.smearing is None:
raise ValueError(
"Cannot compute background correction without specifying `smearing`."
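A hedged usage sketch, not part of the diff: with this refactor, integer exponents beyond the old 0 < p <= 3 restriction construct fine, and unsupported values fail fast inside ``__init__`` through the ``gammaincc_over_powerlaw`` call. The top-level import path is assumed here.

```python
# Hypothetical usage; the top-level import path is an assumption.
from torchpme import InversePowerLawPotential

pot = InversePowerLawPotential(exponent=4, smearing=1.0)  # now allowed (was p <= 3)

try:
    InversePowerLawPotential(exponent=7, smearing=1.0)
except ValueError as err:
    print(err)  # "Unsupported exponent: 7"
```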
6 changes: 3 additions & 3 deletions tests/calculators/test_values_ewald.py
@@ -100,7 +100,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
lr_wavelength = 0.5 * smearing
calc = EwaldCalculator(
InversePowerLawPotential(
exponent=1.0,
exponent=1,
smearing=smearing,
),
lr_wavelength=lr_wavelength,
@@ -111,7 +111,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
smearing = sr_cutoff / 5.0
calc = PMECalculator(
InversePowerLawPotential(
exponent=1.0,
exponent=1,
smearing=smearing,
),
mesh_spacing=smearing / 8,
@@ -198,7 +198,7 @@ def test_wigner(crystal_name, scaling_factor):
# Compute potential and compare against reference
calc = EwaldCalculator(
InversePowerLawPotential(
exponent=1.0,
exponent=1,
smearing=smeareff,
),
lr_wavelength=smeareff / 2,
25 changes: 25 additions & 0 deletions tests/lib/test_math.py
@@ -0,0 +1,25 @@
import numpy as np
import torch
from scipy.special import exp1

from torchpme.lib import torch_exp1


def finite_difference_derivative(func, x, h=1e-5):
return (func(x + h) - func(x - h)) / (2 * h)


def test_torch_exp1_consistency_with_scipy():
x = torch.rand(1000, dtype=torch.float64)
torch_result = torch_exp1(x)
scipy_result = exp1(x.numpy())
assert np.allclose(torch_result.numpy(), scipy_result, atol=1e-6)


def test_torch_exp1_derivative():
x = torch.rand(1, dtype=torch.float64, requires_grad=True)
torch_result = torch_exp1(x)
torch_result.backward()
torch_exp1_prime = x.grad
finite_diff_result = finite_difference_derivative(exp1, x.detach().numpy())
assert np.allclose(torch_exp1_prime.numpy(), finite_diff_result, atol=1e-6)
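A complementary check one could add, offered as a hedged sketch rather than part of this commit: ``torch.autograd.gradcheck`` compares the custom backward of ``CustomExp1`` against numerical gradients in double precision.

```python
# Hedged extra check (not in this commit): gradcheck the custom backward in float64,
# keeping the inputs away from the singularity of E1 at zero.
import torch

from torchpme.lib import torch_exp1

x = (torch.rand(10, dtype=torch.float64) + 0.5).requires_grad_(True)
assert torch.autograd.gradcheck(torch_exp1, (x,))
```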