Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions dptb/data/AtomicData.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator
from ase.calculators.calculator import all_properties as ase_all_properties
from ase.stress import voigt_6_to_full_3x3_stress, full_3x3_to_voigt_6_stress
from ase.neighborlist import NewPrimitiveNeighborList
from ase.data import chemical_symbols
import itertools

import torch
import e3nn.o3
Expand Down Expand Up @@ -971,17 +974,40 @@ def neighbor_list_and_relative_vec(

# ASE dependent part
temp_cell = ase.geometry.complete_cell(temp_cell)

first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list(
"ijS",
pbc,
temp_cell,
temp_pos,
cutoff=float(_r_max),
self_interaction=self_interaction, # we want edges from atom to itself in different periodic images!
use_scaled_positions=False,
##################################################################################
###################################### 新代码 ########################################
elements = np.unique(atomic_numbers).tolist()
pair_cutoffs = {}
for elem1, elem2 in itertools.combinations_with_replacement(elements, 2):
pair_cutoffs[(elem1, elem2)] = max(r_max[chemical_symbols[elem1]], r_max[chemical_symbols[elem2]])

nl = NewPrimitiveNeighborList(
cutoffs=10,
skin=0.0,
self_interaction=self_interaction,
bothways=True,
use_scaled_positions=False
)

nl.cutoffs = pair_cutoffs
nl.update(pbc, temp_cell, temp_pos, atomic_numbers)
first_idex, second_idex, shifts = nl.pair_first, nl.pair_second, nl.offset_vec
mask_r = False
_r_max = max(pair_cutoffs.values())

##################################################################################
##################################### 老代码 #########################################
# first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list(
# "ijS",
# pbc,
# temp_cell,
# temp_pos,
# cutoff=float(_r_max),
# self_interaction=self_interaction, # we want edges from atom to itself in different periodic images!
# use_scaled_positions=False,
# )

##################################################################################
##################################################################################
# Eliminate true self-edges that don't cross periodic boundaries
# if not self_interaction:
# bad_edge = first_idex == second_idex
Expand Down
11 changes: 7 additions & 4 deletions dptb/nn/deeptb.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
dtype: Union[str, torch.dtype] = torch.float32,
device: Union[str, torch.device] = torch.device("cpu"),
transform: bool = True,
scale_type: str = 'scale_w_back_grad',
**kwargs,
):

Expand Down Expand Up @@ -103,7 +104,7 @@ def __init__(
self.device = device
self.model_options = {"embedding": embedding.copy(), "prediction": prediction.copy()}
self.transform = transform

self.scale_type = scale_type

self.method = prediction.get("method", "e3tb")
# self.soc = prediction.get("soc", False)
Expand Down Expand Up @@ -298,9 +299,11 @@ def forward(self, data: AtomicDataDict.Type):
data = self.embedding(data)
if hasattr(self, "overlap") and self.method == "sktb":
data[AtomicDataDict.EDGE_OVERLAP_KEY] = data[AtomicDataDict.EDGE_FEATURES_KEY]

data = self.node_prediction_h(data)
data = self.edge_prediction_h(data)

if self.scale_type != 'no_scale':
data = self.node_prediction_h(data)
data = self.edge_prediction_h(data)

if hasattr(self, "overlap"):
data = self.edge_prediction_s(data)
data[AtomicDataDict.NODE_OVERLAP_KEY] = self.overlaponsite_param[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]
Expand Down
20 changes: 14 additions & 6 deletions dptb/nn/embedding/lem.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
avg_num_neighbors: Optional[float] = None,
# cutoffs
r_start_cos_ratio: float = 0.8,
norm_eps: float = 1e-8,
PolynomialCutoff_p: float = 6,
cutoff_type: str = "polynomial",
# general hyperparameters:
Expand Down Expand Up @@ -131,6 +132,7 @@ def __init__(
cutoff_type=cutoff_type,
device=device,
dtype=dtype,
norm_eps=norm_eps
)

self.layers = torch.nn.ModuleList()
Expand Down Expand Up @@ -235,6 +237,7 @@ def __init__(
latent_dim: int=128,
# cutoffs
r_start_cos_ratio: float = 0.8,
norm_eps: float = 1e-8,
PolynomialCutoff_p: float = 6,
cutoff_type: str = "polynomial",
device: Union[str, torch.device] = torch.device("cpu"),
Expand Down Expand Up @@ -290,7 +293,7 @@ def __init__(

self.sln_n = SeperableLayerNorm(
irreps=self.irreps_out,
eps=1e-5,
eps=norm_eps,
affine=True,
normalization='component',
std_balance_degrees=True,
Expand All @@ -300,7 +303,7 @@ def __init__(

self.sln_e = SeperableLayerNorm(
irreps=self.irreps_out,
eps=1e-5,
eps=norm_eps,
affine=True,
normalization='component',
std_balance_degrees=True,
Expand Down Expand Up @@ -438,6 +441,7 @@ def __init__(
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
latent_dim: int,
norm_eps: float = 1e-8,
radial_emb: bool=False,
radial_channels: list=[128, 128],
res_update: bool = True,
Expand Down Expand Up @@ -470,7 +474,7 @@ def __init__(

self.sln = SeperableLayerNorm(
irreps=self.irreps_in,
eps=1e-5,
eps=norm_eps,
affine=True,
normalization='component',
std_balance_degrees=True,
Expand All @@ -480,7 +484,7 @@ def __init__(

self.sln_e = SeperableLayerNorm(
irreps=self.edge_irreps_in,
eps=1e-5,
eps=norm_eps,
affine=True,
normalization='component',
std_balance_degrees=True,
Expand Down Expand Up @@ -614,6 +618,7 @@ def __init__(
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
latent_dim: int,
norm_eps: float = 1e-8,
latent_channels: list=[128, 128],
radial_emb: bool=False,
radial_channels: list=[128, 128],
Expand Down Expand Up @@ -675,7 +680,7 @@ def __init__(

self.sln_e = SeperableLayerNorm(
irreps=self.irreps_in,
eps=1e-5,
eps=norm_eps,
affine=True,
normalization='component',
std_balance_degrees=True,
Expand All @@ -685,7 +690,7 @@ def __init__(

self.sln_n = SeperableLayerNorm(
irreps=self.irreps_in,
eps=1e-5,
eps=norm_eps,
affine=True,
normalization='component',
std_balance_degrees=True,
Expand Down Expand Up @@ -806,6 +811,7 @@ def __init__(
tp_radial_emb: bool=False,
tp_radial_channels: list=[128, 128],
# MLP parameters:
norm_eps: float = 1e-8,
latent_channels: list=[128, 128],
latent_dim: int=128,
res_update: bool = True,
Expand Down Expand Up @@ -842,6 +848,7 @@ def __init__(
res_update_ratios_learnable=res_update_ratios_learnable,
dtype=dtype,
device=device,
norm_eps=norm_eps
)

self.node_update = UpdateNode(
Expand All @@ -857,6 +864,7 @@ def __init__(
avg_num_neighbors=avg_num_neighbors,
dtype=dtype,
device=device,
norm_eps=norm_eps
)

def forward(self, latents, node_features, edge_features, node_onehot, edge_index, edge_vector, atom_type, cutoff_coeffs, active_edges):
Expand Down
35 changes: 29 additions & 6 deletions dptb/nn/rescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def __init__(
shifts_trainable: bool = False,
dtype: Union[str, torch.dtype] = torch.float32,
device: Union[str, torch.device] = torch.device("cpu"),
scale_type: str = 'scale_w_back_grad',
**kwargs,
):
"""Sum edges into nodes."""
Expand All @@ -233,6 +234,8 @@ def __init__(
self.dtype = dtype
self.shift_index = []
self.scale_index = []
self.scale_type = scale_type
self.scales_trainable = scales_trainable

start = 0
start_scalar = 0
Expand Down Expand Up @@ -293,7 +296,6 @@ def set_scale_shift(self, scales: torch.Tensor=None, shifts: torch.Tensor=None):
self.register_buffer("shifts", shifts)



def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:

if not (self.has_scales or self.has_shifts):
Expand All @@ -305,22 +307,31 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
in_field = data[self.field][mask]
species_idx = data[AtomicDataDict.EDGE_TYPE_KEY].flatten()[mask]



assert len(in_field) == len(
edge_center[mask]
), "in_field doesnt seem to have correct per-edge shape"

if self.has_scales:
in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field
scales = self.scales[species_idx][:, self.scale_index].view(-1, self.irreps_in.dim)
if self.scale_type == 'scale_w_back_grad':
in_field = scales * in_field
elif self.scale_type == 'scale_wo_back_grad':
if self.scales_trainable:
in_field = in_field + in_field.detach() * (scales - 1.0)
else:
in_field = in_field + (in_field * (scales - 1.0)).detach()
else:
raise NotImplementedError

if self.has_shifts:
shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar)
in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0]

data[self.out_field][mask] = in_field

return data


class E3PerSpeciesScaleShift(torch.nn.Module):
"""Scale and/or shift a predicted per-atom property based on (learnable) per-species/type parameters.

Expand Down Expand Up @@ -358,6 +369,7 @@ def __init__(
shifts_trainable: bool = False,
dtype: Union[str, torch.dtype] = torch.float32,
device: Union[str, torch.device] = torch.device("cpu"),
scale_type: str = 'scale_w_back_grad',
**kwargs,
):
super().__init__()
Expand All @@ -370,6 +382,8 @@ def __init__(
self.scale_index = []
self.dtype = dtype
self.device = device
self.scale_type = scale_type
self.scales_trainable = scales_trainable

start = 0
start_scalar = 0
Expand Down Expand Up @@ -442,7 +456,16 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
species_idx
), "in_field doesnt seem to have correct per-atom shape"
if self.has_scales:
in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field
scales = self.scales[species_idx][:, self.scale_index].view(-1, self.irreps_in.dim)
if self.scale_type == 'scale_w_back_grad':
in_field = scales * in_field
elif self.scale_type == 'scale_wo_back_grad':
if self.scales_trainable:
in_field = in_field + in_field.detach() * (scales - 1.0)
else:
in_field = in_field + (in_field * (scales - 1.0)).detach()
else:
raise NotImplementedError
if self.has_shifts:
shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar)
in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0]
Expand Down
Loading
Loading