Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support torch.jit.script #15

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .flexci/build_and_push.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ docker_build_and_push() {
WAIT_PIDS=""

# PyTorch 1.9 + Python 3.7
docker_build_and_push torch15 \
docker_build_and_push torch19 \
--build-arg base_image="nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04" \
--build-arg python_version="3.6.12" \
--build-arg pip_packages="torch==1.5.* torchvision==0.6.* ${TEST_PIP_PACKAGES}" &
--build-arg python_version="3.7.9" \
--build-arg pip_packages="torch==1.9.1 ${TEST_PIP_PACKAGES}" &
WAIT_PIDS="$! ${WAIT_PIDS}"

# Wait until the build complete.
Expand Down
4 changes: 2 additions & 2 deletions .flexci/pytest_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
set -eu

#IMAGE=pytorch/pytorch:1.5.1-cuda10.1-cudnn7-devel
IMAGE=asia.gcr.io/pfn-public-ci/torch-dftd-ci:torch15
IMAGE=asia.gcr.io/pfn-public-ci/torch-dftd-ci:torch19


main() {
Expand All @@ -20,7 +20,7 @@ main() {
docker run --runtime=nvidia --rm --volume="$(pwd)":/workspace -w /workspace \
${IMAGE} \
bash -x -c "pip install flake8 pytest pytest-cov pytest-xdist pytest-benchmark && \
pip install cupy-cuda102 pytorch-pfn-extras!=0.5.0 && \
pip install cupy-cuda102 pytorch-pfn-extras==0.4.2 && \
pip install -e .[develop] && \
pysen run lint && \
pytest --cov=torch_dftd -n $(nproc) -m 'not slow' tests &&
Expand Down
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ print(f"forces {forces}")
## Dependency

The library is tested under following environment.
- python: 3.6
- python: 3.7
- CUDA: 10.2
```bash
torch==1.5.1
torch==1.9.1
ase==3.21.1
# Below is only for 3-body term
cupy-cuda102==8.6.0
pytorch-pfn-extras==0.3.2
cupy-cuda102==9.5.0
pytorch-pfn-extras==0.4.2
```

## Development tips
Expand Down
2 changes: 1 addition & 1 deletion torch_dftd/dftd3_xc_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def get_dftd3_default_params(
rs6 = 1.1
s18 = 0.0
alp = 20.0
rs18 = None # Not used.
rs18 = 0.0  # It is a DUMMY value; not used.
if xc == "b-lyp":
s6 = 1.2
elif xc == "b-p":
Expand Down
4 changes: 3 additions & 1 deletion torch_dftd/functions/dftd2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch_dftd.functions.smoothing import poly_smoothing


@torch.jit.script
def edisp_d2(
Z: Tensor,
r: Tensor,
Expand Down Expand Up @@ -44,7 +45,7 @@ def edisp_d2(
r2 = r ** 2
r6 = r2 ** 3

idx_i, idx_j = edge_index
idx_i, idx_j = edge_index[0], edge_index[1]
# compute all necessary quantities
Zi = Z[idx_i] # (n_edges,)
Zj = Z[idx_j]
Expand All @@ -71,6 +72,7 @@ def edisp_d2(
# (1,)
g = e6.sum()[None]
else:
assert batch is not None
# (n_graphs,)
if batch.size()[0] == 0:
n_graphs = 1
Expand Down
40 changes: 20 additions & 20 deletions torch_dftd/functions/dftd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@

# conversion factors used in grimme d3 code

d3_autoang = 0.52917726 # for converting distance from bohr to angstrom
d3_autoev = 27.21138505 # for converting a.u. to eV
d3_autoang: float = 0.52917726 # for converting distance from bohr to angstrom
d3_autoev: float = 27.21138505 # for converting a.u. to eV

d3_k1 = 16.000
d3_k2 = 4 / 3
d3_k3 = -4.000
d3_maxc = 5 # maximum number of coordination complexes


@torch.jit.script
def _ncoord(
Z: Tensor,
r: Tensor,
Expand Down Expand Up @@ -53,7 +54,7 @@ def _ncoord(
Zi = Z[idx_i]
Zj = Z[idx_j]
rco = rcov[Zi] + rcov[Zj] # (n_edges,)
rr = rco.type(r.dtype) / r
rr = rco.to(r.dtype) / r
damp = 1.0 / (1.0 + torch.exp(-k1 * (rr - 1.0)))
if cutoff is not None and cutoff_smoothing == "poly":
damp *= poly_smoothing(r, cutoff)
Expand All @@ -66,6 +67,7 @@ def _ncoord(
return g # (n_atoms,)


@torch.jit.script
def _getc6(
Zi: Tensor, Zj: Tensor, nci: Tensor, ncj: Tensor, c6ab: Tensor, k3: float = d3_k3
) -> Tensor:
Expand All @@ -84,7 +86,7 @@ def _getc6(
"""
# gather the relevant entries from the table
# c6ab (95, 95, 5, 5, 3) --> c6ab_ (n_edges, 5, 5, 3)
c6ab_ = c6ab[Zi, Zj].type(nci.dtype)
c6ab_ = c6ab[Zi, Zj].to(nci.dtype)
# calculate c6 coefficients

# cn0, cn1, cn2 (n_edges, 5, 5)
Expand All @@ -104,6 +106,7 @@ def _getc6(
return c6


@torch.jit.script
def edisp(
Z: Tensor,
r: Tensor,
Expand All @@ -120,12 +123,12 @@ def edisp(
shift_pos: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
cell: Optional[Tensor] = None,
r2=None,
r6=None,
r8=None,
k1=d3_k1,
k2=d3_k2,
k3=d3_k3,
r2: Optional[Tensor] = None,
r6: Optional[Tensor] = None,
r8: Optional[Tensor] = None,
k1: float = d3_k1,
k2: float = d3_k2,
k3: float = d3_k3,
cutoff_smoothing: str = "none",
damping: str = "zero",
bidirectional: bool = False,
Expand All @@ -146,7 +149,7 @@ def edisp(
cnthr (float or None): cutoff distance for coordination number calculation in **bohr**
batch (Tensor or None): (n_atoms,)
batch_edge (Tensor or None): (n_edges,)
shift_pos (Tensor or None): (n_atoms,) used to calculate 3-body term when abc=True
shift_pos (Tensor or None): (n_edges, 3) used to calculate 3-body term when abc=True
pos (Tensor): (n_atoms, 3) position in **bohr**
cell (Tensor): (3, 3) cell size in **bohr**
r2 (Tensor or None):
Expand All @@ -171,7 +174,7 @@ def edisp(
if r8 is None:
r8 = r6 * r2

idx_i, idx_j = edge_index
idx_i, idx_j = edge_index[0], edge_index[1]
# compute all necessary quantities
Zi = Z[idx_i] # (n_edges,)
Zj = Z[idx_j]
Expand All @@ -192,7 +195,7 @@ def edisp(
ncj = nc[idx_j]
c6 = _getc6(Zi, Zj, nci, ncj, c6ab=c6ab, k3=k3) # c6 coefficients

c8 = 3 * c6 * r2r4[Zi].type(c6.dtype) * r2r4[Zj].type(c6.dtype) # c8 coefficient
c8 = 3 * c6 * r2r4[Zi].to(c6.dtype) * r2r4[Zj].to(c6.dtype) # c8 coefficient

s6 = params["s6"]
s8 = params["s18"]
Expand Down Expand Up @@ -250,6 +253,7 @@ def edisp(
g = e68.to(torch.float64).sum()[None]
else:
# (n_graphs,)
assert batch is not None
if batch.size()[0] == 0:
n_graphs = 1
else:
Expand All @@ -261,6 +265,7 @@ def edisp(
g *= 2.0

if abc:
assert cnthr is not None
within_cutoff = r <= cnthr
# r_abc = r[within_cutoff]
# r2_abc = r2[within_cutoff]
Expand All @@ -282,12 +287,7 @@ def edisp(
# (n_edges, ) -> (n_edges * 2, )
shift_abc = None if shift_abc is None else torch.cat([shift_abc, -shift_abc], dim=0)
with torch.no_grad():
# triplet_node_index, triplet_edge_index = calc_triplets_cycle(edge_index_abc, n_atoms, shift=shift_abc)
# Type hinting
triplet_node_index: Tensor
multiplicity: Tensor
edge_jk: Tensor
batch_triplets: Optional[Tensor]
assert pos is not None
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(
edge_index_abc,
shift_pos=shift_abc,
Expand All @@ -303,7 +303,6 @@ def edisp(
)
r_jk = calc_distances(pos, torch.stack([idx_j, idx_k], dim=0), cell, shift_jk)
kj_within_cutoff = r_jk <= cnthr
del shift_jk

triplet_node_index = triplet_node_index[kj_within_cutoff]
multiplicity, edge_jk, batch_triplets = (
Expand Down Expand Up @@ -355,5 +354,6 @@ def edisp(
e6abc = e3.to(torch.float64).sum()
g += e6abc
else:
assert batch_triplets is not None
g.scatter_add_(0, batch_triplets, e3.to(torch.float64))
return g # (n_graphs,)
8 changes: 5 additions & 3 deletions torch_dftd/functions/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@
from torch import Tensor


@torch.jit.script
def calc_distances(
pos: Tensor,
edge_index: Tensor,
cell: Optional[Tensor] = None,
shift_pos: Optional[Tensor] = None,
eps=1e-20,
eps: float = 1e-20,
) -> Tensor:
"""Distance calculation function.

Args:
pos (Tensor): (n_atoms, 3) atom positions.
edge_index (Tensor): (2, n_edges) edge_index for graph.
cell (Tensor): cell size, None for non periodic system.
This is NOT USED now; it is left for backward compatibility.
shift_pos (Tensor): (n_edges, 3) position shift vectors of edges owing to the periodic boundary. It should be length unit.
eps (float): Small float value to avoid NaN in backward when the distance is 0.

Expand All @@ -25,11 +27,11 @@ def calc_distances(

"""

idx_i, idx_j = edge_index
idx_i, idx_j = edge_index[0], edge_index[1]
# calculate interatomic distances
Ri = pos[idx_i]
Rj = pos[idx_j]
if cell is not None:
if shift_pos is not None:
Rj += shift_pos
# eps is to avoid Nan in backward when Dij = 0 with sqrt.
Dij = torch.sqrt(torch.sum((Ri - Rj) ** 2, dim=-1) + eps)
Expand Down
3 changes: 2 additions & 1 deletion torch_dftd/functions/smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from torch import Tensor


@torch.jit.script
def poly_smoothing(r: Tensor, cutoff: float) -> Tensor:
"""Computes a smooth step from 1 to 0 starting at 1 bohr before the cutoff

Args:
r (Tensor): (n_edges,)
cutoff (float): ()
cutoff (float): cutoff length

Returns:
r (Tensor): Smoothed `r`
Expand Down
Loading