diff --git a/dptb/data/AtomicData.py b/dptb/data/AtomicData.py
index 46b5dd37..a90a85bc 100644
--- a/dptb/data/AtomicData.py
+++ b/dptb/data/AtomicData.py
@@ -15,6 +15,9 @@ from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator
 from ase.calculators.calculator import all_properties as ase_all_properties
 from ase.stress import voigt_6_to_full_3x3_stress, full_3x3_to_voigt_6_stress
+from ase.neighborlist import NewPrimitiveNeighborList
+from ase.data import chemical_symbols
+import itertools
 
 import torch
 
 import e3nn.o3
@@ -971,17 +974,40 @@ def neighbor_list_and_relative_vec(
 
     # ASE dependent part
     temp_cell = ase.geometry.complete_cell(temp_cell)
-
-    first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list(
-        "ijS",
-        pbc,
-        temp_cell,
-        temp_pos,
-        cutoff=float(_r_max),
-        self_interaction=self_interaction,  # we want edges from atom to itself in different periodic images!
-        use_scaled_positions=False,
+##################################################################################
+###################################### New code ######################################
+    elements = np.unique(atomic_numbers).tolist()
+    pair_cutoffs = {}
+    for elem1, elem2 in itertools.combinations_with_replacement(elements, 2):
+        pair_cutoffs[(elem1, elem2)] = max(r_max[chemical_symbols[elem1]], r_max[chemical_symbols[elem2]])
+
+    nl = NewPrimitiveNeighborList(
+        cutoffs=10,  # placeholder value; replaced below by the per-pair cutoff dict
+        skin=0.0,
+        self_interaction=self_interaction,
+        bothways=True,
+        use_scaled_positions=False
     )
-
+    nl.cutoffs = pair_cutoffs
+    nl.update(pbc, temp_cell, temp_pos, atomic_numbers)
+    first_idex, second_idex, shifts = nl.pair_first, nl.pair_second, nl.offset_vec
+    mask_r = False
+    _r_max = max(pair_cutoffs.values())
+
+    ##################################################################################
+    ##################################### Old code #########################################
+    # first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list(
+    #     "ijS",
+    #     pbc,
+    #     temp_cell,
+    #     temp_pos,
+    #     cutoff=float(_r_max),
+    #     self_interaction=self_interaction,  # we want edges from atom to itself in different periodic images!
+    #     use_scaled_positions=False,
+    # )
+
+    ##################################################################################
+    ##################################################################################
 
     # Eliminate true self-edges that don't cross periodic boundaries
     # if not self_interaction:
     #     bad_edge = first_idex == second_idex
diff --git a/dptb/nn/deeptb.py b/dptb/nn/deeptb.py
index 13fbf984..7bdd481f 100644
--- a/dptb/nn/deeptb.py
+++ b/dptb/nn/deeptb.py
@@ -68,6 +68,7 @@ def __init__(
         dtype: Union[str, torch.dtype] = torch.float32,
         device: Union[str, torch.device] = torch.device("cpu"),
         transform: bool = True,
+        scale_type: str = 'scale_w_back_grad',
         **kwargs,
     ):
 
@@ -103,7 +104,7 @@ def __init__(
         self.device = device
         self.model_options = {"embedding": embedding.copy(), "prediction": prediction.copy()}
         self.transform = transform
-
+        self.scale_type = scale_type
         self.method = prediction.get("method", "e3tb")
         # self.soc = prediction.get("soc", False)
 
@@ -298,9 +299,11 @@ def forward(self, data: AtomicDataDict.Type):
         data = self.embedding(data)
         if hasattr(self, "overlap") and self.method == "sktb":
             data[AtomicDataDict.EDGE_OVERLAP_KEY] = data[AtomicDataDict.EDGE_FEATURES_KEY]
-
-        data = self.node_prediction_h(data)
-        data = self.edge_prediction_h(data)
+
+        if self.scale_type != 'no_scale':
+            data = self.node_prediction_h(data)
+            data = self.edge_prediction_h(data)
+
         if hasattr(self, "overlap"):
             data = self.edge_prediction_s(data)
             data[AtomicDataDict.NODE_OVERLAP_KEY] = self.overlaponsite_param[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]
diff --git a/dptb/nn/embedding/lem.py b/dptb/nn/embedding/lem.py
index 2987e71c..abdc1b06 100644
--- a/dptb/nn/embedding/lem.py
+++ b/dptb/nn/embedding/lem.py
@@ -39,6 +39,7 @@ def __init__(
         avg_num_neighbors: Optional[float] = None,
         # cutoffs
         r_start_cos_ratio: float = 0.8,
+        norm_eps: float = 1e-8,
         PolynomialCutoff_p: float = 6,
         cutoff_type: str = "polynomial",
         # general hyperparameters:
@@ -131,6 +132,7 @@ def __init__(
             cutoff_type=cutoff_type,
             device=device,
             dtype=dtype,
+            norm_eps=norm_eps
         )
 
         self.layers = torch.nn.ModuleList()
@@ -235,6 +237,7 @@ def __init__(
         latent_dim: int=128,
         # cutoffs
         r_start_cos_ratio: float = 0.8,
+        norm_eps: float = 1e-8,
         PolynomialCutoff_p: float = 6,
         cutoff_type: str = "polynomial",
         device: Union[str, torch.device] = torch.device("cpu"),
@@ -290,7 +293,7 @@ def __init__(
 
         self.sln_n = SeperableLayerNorm(
             irreps=self.irreps_out,
-            eps=1e-5,
+            eps=norm_eps,
             affine=True,
             normalization='component',
             std_balance_degrees=True,
@@ -300,7 +303,7 @@ def __init__(
 
         self.sln_e = SeperableLayerNorm(
             irreps=self.irreps_out,
-            eps=1e-5,
+            eps=norm_eps,
             affine=True,
             normalization='component',
             std_balance_degrees=True,
@@ -438,6 +441,7 @@ def __init__(
         irreps_in: o3.Irreps,
         irreps_out: o3.Irreps,
         latent_dim: int,
+        norm_eps: float = 1e-8,
         radial_emb: bool=False,
         radial_channels: list=[128, 128],
         res_update: bool = True,
@@ -470,7 +474,7 @@ def __init__(
 
         self.sln = SeperableLayerNorm(
             irreps=self.irreps_in,
-            eps=1e-5,
+            eps=norm_eps,
             affine=True,
             normalization='component',
             std_balance_degrees=True,
@@ -480,7 +484,7 @@ def __init__(
 
         self.sln_e = SeperableLayerNorm(
             irreps=self.edge_irreps_in,
-            eps=1e-5,
+            eps=norm_eps,
             affine=True,
             normalization='component',
             std_balance_degrees=True,
@@ -614,6 +618,7 @@ def __init__(
         irreps_in: o3.Irreps,
         irreps_out: o3.Irreps,
         latent_dim: int,
+        norm_eps: float = 1e-8,
         latent_channels: list=[128, 128],
         radial_emb: bool=False,
         radial_channels: list=[128, 128],
@@ -675,7 +680,7 @@ def __init__(
 
         self.sln_e = SeperableLayerNorm(
             irreps=self.irreps_in,
-            eps=1e-5,
+            eps=norm_eps,
             affine=True,
             normalization='component',
             std_balance_degrees=True,
@@ -685,7 +690,7 @@ def __init__(
 
         self.sln_n = SeperableLayerNorm(
             irreps=self.irreps_in,
-            eps=1e-5,
+            eps=norm_eps,
             affine=True,
             normalization='component',
             std_balance_degrees=True,
@@ -806,6 +811,7 @@ def __init__(
         tp_radial_emb: bool=False,
         tp_radial_channels: list=[128, 128],
         # MLP parameters:
+        norm_eps: float = 1e-8,
         latent_channels: list=[128, 128],
         latent_dim: int=128,
         res_update: bool = True,
@@ -842,6 +848,7 @@ def __init__(
             res_update_ratios_learnable=res_update_ratios_learnable,
             dtype=dtype,
             device=device,
+            norm_eps=norm_eps
         )
 
         self.node_update = UpdateNode(
@@ -857,6 +864,7 @@ def __init__(
             avg_num_neighbors=avg_num_neighbors,
             dtype=dtype,
             device=device,
+            norm_eps=norm_eps
         )
 
     def forward(self, latents, node_features, edge_features, node_onehot, edge_index, edge_vector, atom_type, cutoff_coeffs, active_edges):
diff --git a/dptb/nn/rescale.py b/dptb/nn/rescale.py
index 8aca26aa..8d935140 100644
--- a/dptb/nn/rescale.py
+++ b/dptb/nn/rescale.py
@@ -220,6 +220,7 @@ def __init__(
         shifts_trainable: bool = False,
         dtype: Union[str, torch.dtype] = torch.float32,
         device: Union[str, torch.device] = torch.device("cpu"),
+        scale_type: str = 'scale_w_back_grad',
         **kwargs,
     ):
         """Sum edges into nodes."""
@@ -233,6 +234,8 @@ def __init__(
         self.dtype = dtype
         self.shift_index = []
         self.scale_index = []
+        self.scale_type = scale_type
+        self.scales_trainable = scales_trainable
 
         start = 0
         start_scalar = 0
@@ -293,7 +296,6 @@ def set_scale_shift(self, scales: torch.Tensor=None, shifts: torch.Tensor=None):
             self.register_buffer("shifts", shifts)
 
-
     def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
 
         if not (self.has_scales or self.has_shifts):
@@ -305,22 +307,31 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
 
         in_field = data[self.field][mask]
         species_idx = data[AtomicDataDict.EDGE_TYPE_KEY].flatten()[mask]
-
-
         assert len(in_field) == len(
             edge_center[mask]
         ), "in_field doesnt seem to have correct per-edge shape"
 
         if self.has_scales:
-            in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field
+            scales = self.scales[species_idx][:, self.scale_index].view(-1, self.irreps_in.dim)
+            if self.scale_type == 'scale_w_back_grad':
+                in_field = scales * in_field
+            elif self.scale_type == 'scale_wo_back_grad':
+                if self.scales_trainable:
+                    in_field = in_field + in_field.detach() * (scales - 1.0)
+                else:
+                    in_field = in_field + (in_field * (scales - 1.0)).detach()
+            else:
+                raise NotImplementedError
+
         if self.has_shifts:
             shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar)
             in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0]
-
+
         data[self.out_field][mask] = in_field
 
         return data
 
+
 class E3PerSpeciesScaleShift(torch.nn.Module):
     """Scale and/or shift a predicted per-atom property based on (learnable) per-species/type parameters.
@@ -358,6 +369,7 @@ def __init__(
         shifts_trainable: bool = False,
         dtype: Union[str, torch.dtype] = torch.float32,
         device: Union[str, torch.device] = torch.device("cpu"),
+        scale_type: str = 'scale_w_back_grad',
         **kwargs,
     ):
         super().__init__()
@@ -370,6 +382,8 @@ def __init__(
         self.scale_index = []
         self.dtype = dtype
         self.device = device
+        self.scale_type = scale_type
+        self.scales_trainable = scales_trainable
 
         start = 0
         start_scalar = 0
@@ -442,7 +456,16 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
             species_idx
         ), "in_field doesnt seem to have correct per-atom shape"
         if self.has_scales:
-            in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field
+            scales = self.scales[species_idx][:, self.scale_index].view(-1, self.irreps_in.dim)
+            if self.scale_type == 'scale_w_back_grad':
+                in_field = scales * in_field
+            elif self.scale_type == 'scale_wo_back_grad':
+                if self.scales_trainable:
+                    in_field = in_field + in_field.detach() * (scales - 1.0)
+                else:
+                    in_field = in_field + (in_field * (scales - 1.0)).detach()
+            else:
+                raise NotImplementedError
         if self.has_shifts:
             shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar)
             in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0]
diff --git a/dptb/nnops/loss.py b/dptb/nnops/loss.py
index 4e8d6fcd..055987ee 100644
--- a/dptb/nnops/loss.py
+++ b/dptb/nnops/loss.py
@@ -643,6 +643,95 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
 
         return 0.5 * (onsite_loss + hopping_loss)
 
+
+@Loss.register("hamil_abs_mae")
+class HamilLossAbsMAE(nn.Module):
+    def __init__(
+        self,
+        basis: Dict[str, Union[str, list]] = None,
+        idp: Union[OrbitalMapper, None] = None,
+        overlap: bool = False,
+        onsite_shift: bool = False,
+        dtype: Union[str, torch.dtype] = torch.float32,
+        device: Union[str, torch.device] = torch.device("cpu"),
+        **kwargs,
+    ):
+
+        super(HamilLossAbsMAE, self).__init__()
+        self.loss1 = nn.L1Loss()
+        self.loss2 = nn.MSELoss()
+        self.overlap = overlap
+        self.device = device
+        self.onsite_shift = onsite_shift
+
+        if basis is not None:
+            self.idp = OrbitalMapper(basis, method="e3tb", device=self.device)
+            if idp is not None:
+                assert idp == self.idp, "The basis of idp and basis should be the same."
+        else:
+            assert idp is not None, "Either basis or idp should be provided."
+            self.idp = idp
+
+    def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
+        # mask the data
+
+        # data[AtomicDataDict.NODE_FEATURES_KEY].masked_fill(~self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY]], 0.)
+        # data[AtomicDataDict.EDGE_FEATURES_KEY].masked_fill(~self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY]], 0.)
+
+        if self.onsite_shift:
+            batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
+            # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
+            mu = data[AtomicDataDict.NODE_FEATURES_KEY][
+                self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
+                ref_data[AtomicDataDict.NODE_FEATURES_KEY][
+                self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            if batch.max() == 0:  # when the batch holds a single structure (batchsize = 1)
+                mu = mu.mean().detach()
+                ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[
+                    AtomicDataDict.NODE_OVERLAP_KEY]
+                ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[
+                    AtomicDataDict.EDGE_OVERLAP_KEY]
+            elif batch.max() >= 1:
+                slices = [data["__slices__"]["pos"][i] - data["__slices__"]["pos"][i - 1] for i in
+                          range(1, len(data["__slices__"]["pos"]))]
+                slices = [0] + slices
+                ndiag_batch = torch.stack([i.sum() for i in
+                                           self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(
+                                               slices)])
+                ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
+                mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i + 1]].mean() for i in range(len(ndiag_batch) - 1)])
+                mu = mu.detach()
+                ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[
+                    batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
+                edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long,
+                                            device=self.device)
+                for i in range(1, batch.max().item() + 1):
+                    edge_mu_index[data["__slices__"]["edge_index"][i]:data["__slices__"]["edge_index"][i + 1]] += i
+                ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu[
+                    edge_mu_index, None] * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
+
+        # onsite loss
+        pre_onsite = data[AtomicDataDict.NODE_FEATURES_KEY][
+            self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]
+        ]
+        tgt_onsite = ref_data[AtomicDataDict.NODE_FEATURES_KEY][
+            self.idp.mask_to_nrme[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]
+        ]
+
+        # hopping loss
+        pre_hopping = data[AtomicDataDict.EDGE_FEATURES_KEY][
+            self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]
+        ]
+        tgt_hopping = ref_data[AtomicDataDict.EDGE_FEATURES_KEY][
+            self.idp.mask_to_erme[ref_data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]
+        ]
+
+        pre = torch.cat([pre_onsite, pre_hopping], dim=0)
+        tgt = torch.cat([tgt_onsite, tgt_hopping], dim=0)
+
+        total_loss = self.loss1(pre, tgt)
+        return total_loss
+
+
 @Loss.register("hamil_wt")
 class HamilLossWT(nn.Module):
     def __init__(
diff --git a/dptb/postprocess/__init__.py b/dptb/postprocess/__init__.py
index be849ddf..6f51a652 100644
--- a/dptb/postprocess/__init__.py
+++ b/dptb/postprocess/__init__.py
@@ -1,11 +1,11 @@
 from .bandstructure import Band
 from .totbplas import TBPLaS
 from .write_block import write_block
-
+from .write_abacus_csr_file import write_blocks_to_abacus_csr
 
 __all__ = [
     Band,
     TBPLaS,
     write_block,
-
+    write_blocks_to_abacus_csr
 ]
\ No newline at end of file
diff --git a/dptb/postprocess/write_abacus_csr_file.py b/dptb/postprocess/write_abacus_csr_file.py
new file mode 100644
index 00000000..d8c8aae4
--- /dev/null
+++ b/dptb/postprocess/write_abacus_csr_file.py
@@ -0,0 +1,207 @@
+import os
+import lmdb
+import pickle
+import re
+import numpy as np
+from scipy.sparse import csr_matrix, coo_matrix
+from collections import defaultdict
+import ase.data
+from scipy.linalg import block_diag
+from dftio.constants import ABACUS2DFTIO
+
+# DFTIO -> ABACUS
+DFTIO2ABACUS = {l: M.T.astype(np.float32) for l, M in ABACUS2DFTIO.items()}
+
+ORBITAL_MAP = {'s': 0, 'p': 1, 'd': 2, 'f': 3, 'g': 4, 'h': 5}
+KEY_RE = re.compile(r'^\s*(-?\d+)[ _](-?\d+)[ _](-?\d+)[ _](-?\d+)[ _](-?\d+)\s*$')
+H_FACTOR = 13.605698  # Ryd -> eV factor for Hamiltonian
+
+
+def parse_basis_to_l_list(basis_str):
+    """'2s2p1d' or 'spd' -> [0,0,1,1,2]."""
+    if basis_str is None:
+        return []
+    s = str(basis_str).strip().lower()
+    if s == "":
+        return []
+    tokens = re.findall(r'(\d*)([spdfgh])', s)
+    lst = []
+    for num, ch in tokens:
+        cnt = int(num) if num else 1
+        if ch not in ORBITAL_MAP:
+            raise ValueError(f"Unsupported orbital '{ch}' in '{basis_str}'")
+        lst.extend([ORBITAL_MAP[ch]] * cnt)
+    return lst
+
+
+def find_basis_for_Z_or_symbol(basis_dict, Z):
+    """Find basis string for atomic number Z (multiple key forms)."""
+    if Z in basis_dict:
+        return basis_dict[Z]
+    sym = ase.data.chemical_symbols[Z]
+    for key_try in (sym, sym.capitalize(), sym.upper(), str(Z)):
+        if key_try in basis_dict:
+            return basis_dict[key_try]
+    for k, v in basis_dict.items():
+        if isinstance(k, str) and k.lower() == sym.lower():
+            return v
+    return None
+
+
+def transform_2_ABACUS(mat, l_lefts, l_rights):
+    """Transform block from DFTIO ordering to ABACUS ordering."""
+    if max(*(list(l_lefts) + list(l_rights))) > 5:
+        raise NotImplementedError("Only support l = s..h.")
+    left_mats = [DFTIO2ABACUS[l] for l in l_lefts]
+    right_mats = [DFTIO2ABACUS[l] for l in l_rights]
+    left = block_diag(*left_mats) if left_mats else np.eye(0, dtype=np.float32)
+    right = block_diag(*right_mats) if right_mats else np.eye(0, dtype=np.float32)
+    return left @ mat @ right.T
+
+
+def write_abacus_csr_format(matrix_dict, matrix_symbol, output_path, step=0):
+    """Write mapping 'Rx_Ry_Rz' -> csr_matrix into ABACUS text CSR."""
+    if not matrix_dict:
+        print(f"Warning: empty matrix_dict for {matrix_symbol}")
+        return
+    first = next(iter(matrix_dict))
+    norbits = matrix_dict[first].shape[0]
+    num_blocks = len(matrix_dict)
+    with open(output_path, 'w') as f:
+        f.write(f"STEP: {step}\n")
+        f.write(f"Matrix Dimension of {matrix_symbol}(R): {norbits}\n")
+        f.write(f"Matrix number of {matrix_symbol}(R): {num_blocks}\n")
+        for r_key, sparse_mat in matrix_dict.items():
+            r_vector_str = r_key.replace('_', ' ')
+            nnz = int(sparse_mat.nnz)
+            f.write(f"{r_vector_str} {nnz}\n")
+            if nnz > 0:
+                np.savetxt(f, sparse_mat.data.reshape(1, -1), fmt='%.8e')
+                np.savetxt(f, sparse_mat.indices.reshape(1, -1), fmt='%d')
+                np.savetxt(f, sparse_mat.indptr.reshape(1, -1), fmt='%d')
+            else:
+                f.write("\n\n\n")
+    # print(f"Wrote {num_blocks} blocks to {output_path}")
+
+
+def write_blocks_to_abacus_csr(atomic_numbers, basis_dict, blocks_dict, matrix_symbol, output_path, step=0):
+    """
+    Entry function:
+      atomic_numbers: per-site Z array-like
+      basis_dict: parse_orbital_files result
+      blocks_dict: mapping 'i_j_Rx_Ry_Rz' -> small block (DFTIO ordering)
+      matrix_symbol: 'H'/'S'/'D'
+    """
+    atomic_numbers = np.asarray(atomic_numbers, dtype=int)
+    if atomic_numbers.size == 0:
+        raise ValueError("empty atomic_numbers")
+
+    # choose factor
+    factor = H_FACTOR if str(matrix_symbol).upper() == 'H' else 1.0
+
+    # element -> l-list
+    element_l_lists = {}
+    for Z in np.unique(atomic_numbers):
+        basis_str = find_basis_for_Z_or_symbol(basis_dict, int(Z))
+        if basis_str is None:
+            element_l_lists[int(Z)] = [0]
+        else:
+            ll = parse_basis_to_l_list(basis_str)
+            element_l_lists[int(Z)] = ll if ll else [0]
+
+    # site norbits
+    site_norbits = np.array([sum(2 * l + 1 for l in element_l_lists[int(Z)]) for Z in atomic_numbers], dtype=int)
+    site_norbits_cumsum = np.cumsum(site_norbits)
+    norbits = int(site_norbits_cumsum[-1])
+
+    # aggregate COO data per R
+    r_vector_coo = defaultdict(lambda: {'data': [], 'rows': [], 'cols': []})
+
+    for raw_key, small_block in blocks_dict.items():
+        key = raw_key.decode() if isinstance(raw_key, (bytes, bytearray)) else str(raw_key)
+        m = KEY_RE.match(key)
+        if not m:
+            # skip unparseable keys
+            continue
+        i_site = int(m.group(1)); j_site = int(m.group(2))
+        Rx = int(m.group(3)); Ry = int(m.group(4)); Rz = int(m.group(5))
+        r_str = f"{Rx}_{Ry}_{Rz}"
+
+        # l-lists
+        l_lefts = element_l_lists[int(atomic_numbers[i_site])]
+        l_rights = element_l_lists[int(atomic_numbers[j_site])]
+
+        # get ndarray (support sparse objects)
+        if hasattr(small_block, "toarray"):
+            block_arr = small_block.toarray()
+        elif "torch" in str(type(small_block)):
+            if small_block.is_cuda:
+                block_arr = small_block.detach().cpu().numpy()
+            else:
+                block_arr = small_block.detach().numpy()
+        else:
+            block_arr = np.asarray(small_block)
+        if block_arr.size == 0:
+            continue
+
+        # transform DFTIO -> ABACUS
+        transformed = transform_2_ABACUS(block_arr.astype(np.float32), l_lefts, l_rights)
+
+        # offsets
+        row_offset = int(site_norbits_cumsum[i_site] - site_norbits[i_site])
+        col_offset = int(site_norbits_cumsum[j_site] - site_norbits[j_site])
+
+        coo = coo_matrix(transformed)
+        if coo.nnz == 0:
+            continue
+
+        # apply factor (H vs others)
+        r_vector_coo[r_str]['data'].append((coo.data.astype(np.float32) / factor))
+        r_vector_coo[r_str]['rows'].append((coo.row + row_offset).astype(int))
+        r_vector_coo[r_str]['cols'].append((coo.col + col_offset).astype(int))
+
+    # build final CSR dict
+    reassembled = {}
+    for r_str, parts in r_vector_coo.items():
+        if not parts['data']:
+            full = csr_matrix((norbits, norbits), dtype=np.float32)
+        else:
+            data = np.concatenate(parts['data']).astype(np.float32)
+            rows = np.concatenate(parts['rows']).astype(int)
+            cols = np.concatenate(parts['cols']).astype(int)
+            full = csr_matrix((data, (rows, cols)), shape=(norbits, norbits))
+        reassembled[r_str] = full
+
+    write_abacus_csr_format(reassembled, matrix_symbol, output_path, step=step)
+    return reassembled, norbits
+
+
+# demo main
+if __name__ == "__main__":
+    LMDB_PATH = r'E:\deeptb\large_DeepTB\0909\0910_lmdb\train\data.28400.lmdb'
+    ORBITAL_PATH = r'E:\deeptb\basis_set_test\production_use_dzp\orb_upf\public'
+
+    from dprep.dptb_dpdispatcher import parse_orbital_files
+    _, basis_dict = parse_orbital_files(ORBITAL_PATH)
+
+    env = lmdb.open(LMDB_PATH, readonly=True, lock=False)
+    with env.begin() as txn:
+        rec = txn.get((0).to_bytes(length=4, byteorder='big'))
+        if rec is None:
+            raise RuntimeError("No record at index 0")
+        data = pickle.loads(rec)
+    env.close()
+
+    atomic_numbers = np.array(data['atomic_numbers'], dtype=int)
+
+    if 'hamiltonian' in data and data['hamiltonian']:
+        write_blocks_to_abacus_csr(
+            atomic_numbers=atomic_numbers,
+            basis_dict=basis_dict,
+            blocks_dict=data['hamiltonian'],
+            matrix_symbol='H',
+            output_path='data-HR-sparse_SPIN0.csr',
+            step=0
+        )
+    else:
+        print("No hamiltonian in record 0.")
diff --git a/dptb/utils/argcheck.py b/dptb/utils/argcheck.py
index 8598156f..884c585a 100644
--- a/dptb/utils/argcheck.py
+++ b/dptb/utils/argcheck.py
@@ -115,9 +115,10 @@ def train_options():
     doc_sliding_win_size = "Sliding window size for the average of the latest iterations' loss. Used for the reduce on plateau learning rate scheduler in case of the pairing of large dataset and small batch size. Default: `50`"
 
     doc_optimizer = "\
-    The optimizer setting for selecting the gradient optimizer of model training. Optimizer supported includes `Adam`, `SGD` and `LBFGS` \n\n\
+    The optimizer setting for selecting the gradient optimizer of model training. Supported optimizers include `Adam`, `AdamW`, `SGD` and `LBFGS` \n\n\
     For more information about these optmization algorithm, we refer to:\n\n\
     - `Adam`: [Adam: A Method for Stochastic Optimization.](https://arxiv.org/abs/1412.6980)\n\n\
+    - `AdamW`: [AdamW: Decoupled Weight Decay Regularization.](https://arxiv.org/abs/1711.05101)\n\n\
     - `SGD`: [Stochastic Gradient Descent.](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html)\n\n\
     - `LBFGS`: [On the limited memory BFGS method for large scale optimization.](http://users.iems.northwestern.edu/~nocedal/PDFfiles/limited-memory.pdf) \n\n\
     "
@@ -231,10 +232,11 @@ def LBFGS():
     ]
 
 def optimizer():
-    doc_type = "select type of optimizer, support type includes: `Adam`, `SGD` and `LBFGS`. Default: `Adam`"
+    doc_type = "select type of optimizer, supported types include: `Adam`, `AdamW`, `SGD` and `LBFGS`. Default: `Adam`"
 
     return Variant("type", [
             Argument("Adam", dict, Adam()),
+            Argument("AdamW", dict, Adam()),
             Argument("SGD", dict, SGD()),
             Argument("RMSprop", dict, RMSprop()),
             Argument("LBFGS", dict, LBFGS()),
@@ -635,6 +637,7 @@ def slem():
         Argument("res_update_ratios", float, optional=True, default=0.5, doc="The ratios of residual update, should in (0,1)."),
         Argument("res_update_ratios_learnable", bool, optional=True, default=False, doc="Whether to make the ratios of residual update learnable."),
         Argument("universal", bool, optional=True, default=False, doc=doc_universal),
+        Argument("norm_eps", float, optional=True, default=1e-8, doc="eps in SeperableLayerNorm."),
     ]
 
 
@@ -662,17 +665,21 @@ def sktb_prediction():
 
 def e3tb_prediction():
-    doc_scales_trainable = "whether to scale the trianing target."
-    doc_shifts_trainable = "whether to shift the training target."
+    doc_scales_trainable = "The scale parameter is initialized from the dataset statistics. Whether to train this parameter."
+    doc_shifts_trainable = "The shift parameter is initialized from the dataset statistics. Whether to train this parameter."
     doc_neurons = "neurons in the neural network."
     doc_activation = "activation function."
     doc_if_batch_normalized = "if to turn on batch normalization"
+    doc_scale_type = ("Which scaling method to use. Can be `no_scale`, "
+                      "`scale_wo_back_grad` (the scale parameter is kept out of the backward computation graph), "
+                      "or `scale_w_back_grad` (the scale parameter participates in the backward computation graph).")
 
     nn = [
         Argument("scales_trainable", bool, optional=True, default=False, doc=doc_scales_trainable),
        Argument("shifts_trainable", bool, optional=True, default=False, doc=doc_shifts_trainable),
         Argument("neurons", list, optional=True, default=None, doc=doc_neurons),
         Argument("activation", str, optional=True, default="tanh", doc=doc_activation),
+        Argument("scale_type", str, optional=True, default="scale_w_back_grad", doc=doc_scale_type),
         Argument("if_batch_normalized", bool, optional=True, default=False, doc=doc_if_batch_normalized),
     ]
 
@@ -1756,9 +1763,10 @@ def normalize_skf2nnsk(data):
     doc_lr_scheduler = "The learning rate scheduler tools settings, the lr scheduler is used to scales down the learning rate during the training process. Proper setting can make the training more stable and efficient. The supported lr schedular includes: `Exponential Decaying (exp)`, `Linear multiplication (linear)`"
     doc_optimizer = "\
-    The optimizer setting for selecting the gradient optimizer of model training. Optimizer supported includes `Adam`, `SGD` and `LBFGS` \n\n\
+    The optimizer setting for selecting the gradient optimizer of model training. Supported optimizers include `Adam`, `AdamW`, `SGD` and `LBFGS` \n\n\
     For more information about these optmization algorithm, we refer to:\n\n\
     - `Adam`: [Adam: A Method for Stochastic Optimization.](https://arxiv.org/abs/1412.6980)\n\n\
+    - `AdamW`: [AdamW: Decoupled Weight Decay Regularization.](https://arxiv.org/abs/1711.05101)\n\n\
     - `SGD`: [Stochastic Gradient Descent.](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html)\n\n\
     - `LBFGS`: [On the limited memory BFGS method for large scale optimization.](http://users.iems.northwestern.edu/~nocedal/PDFfiles/limited-memory.pdf) \n\n\
     "
diff --git a/dptb/utils/tools.py b/dptb/utils/tools.py
index 8c8e5f1b..d276be9f 100644
--- a/dptb/utils/tools.py
+++ b/dptb/utils/tools.py
@@ -125,8 +125,6 @@ def update_dict_with_warning(dict_input, update_list, update_value):
 
     return reconstruct_dict(flatten_input_dict)
 
-
-
 def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
@@ -138,6 +136,8 @@ def setup_seed(seed):
 def get_optimizer(type: str, model_param, lr: float, **options: dict):
     if type == 'Adam':
         optimizer = optim.Adam(params=model_param, lr=lr, **options)
+    elif type == 'AdamW':
+        optimizer = optim.AdamW(params=model_param, lr=lr, **options)
     elif type == 'SGD':
         optimizer = optim.SGD(params=model_param, lr=lr, **options)
     elif type == 'RMSprop':
@@ -145,7 +145,7 @@ def get_optimizer(type: str, model_param, lr: float, **options: dict):
     elif type == 'LBFGS':
         optimizer = optim.LBFGS(params=model_param, lr=lr, **options)
     else:
-        raise RuntimeError("Optimizer should be Adam/SGD/RMSprop, not {}".format(type))
+        raise RuntimeError("Optimizer should be Adam/AdamW/SGD/RMSprop/LBFGS, not {}".format(type))
     return optimizer
 
 def get_lr_scheduler(type: str, optimizer: optim.Optimizer, **sch_options):
diff --git a/pyproject.toml b/pyproject.toml
index 0e455e5f..6dc62de6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,7 +44,7 @@ pyyaml = "*"
 future = "*"
 dargs = "0.4.4"
 xitorch = "0.3.0"
-e3nn = ">=0.5.1"
+e3nn = ">=0.5.1,<=0.5.3"
 torch-runstats = "0.2.0"
 torch_scatter = "2.1.2"
 torch_geometric = ">=2.4.0"
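For reference, a minimal standalone sketch of the per-pair cutoff neighbor list that the AtomicData.py hunk above builds with NewPrimitiveNeighborList. The Si crystal and the 5.0 Angstrom cutoff are illustrative values only, and it assumes an ASE version whose primitive_neighbor_list accepts a dict of per-pair cutoffs (as the patched code does):

import itertools
import numpy as np
from ase.build import bulk
from ase.data import chemical_symbols
from ase.neighborlist import NewPrimitiveNeighborList

atoms = bulk("Si", "diamond", a=5.43)   # toy structure (assumption, not from the patch)
r_max = {"Si": 5.0}                     # per-element cutoff in Angstrom (example value)

elements = np.unique(atoms.numbers).tolist()
pair_cutoffs = {
    (z1, z2): max(r_max[chemical_symbols[z1]], r_max[chemical_symbols[z2]])
    for z1, z2 in itertools.combinations_with_replacement(elements, 2)
}

nl = NewPrimitiveNeighborList(cutoffs=10, skin=0.0, self_interaction=False,
                              bothways=True, use_scaled_positions=False)
nl.cutoffs = pair_cutoffs  # override the scalar placeholder with the pair-keyed dict
nl.update(atoms.pbc, np.asarray(atoms.cell), atoms.positions, atoms.numbers)

print(len(nl.pair_first), "directed edges within the pair-wise cutoffs")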
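A small sketch of the gradient routing behind the new scale_type option in rescale.py: every branch yields the same forward value scales * x, but with 'scale_wo_back_grad' the upstream tensor only receives a unit gradient while trainable scales still get their own gradient. The tensors below are toy values:

import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)
scales = torch.tensor([10.0, 10.0], requires_grad=True)

# scale_w_back_grad: plain product, the gradient of x is multiplied by scales
(scales * x).sum().backward()
print(x.grad)        # tensor([10., 10.])

x.grad = None
scales.grad = None

# scale_wo_back_grad with scales_trainable=True: identical forward value,
# but x only sees a unit gradient and scales are still updated
(x + x.detach() * (scales - 1.0)).sum().backward()
print(x.grad)        # tensor([1., 1.])
print(scales.grad)   # tensor([2., 3.])  (equals x.detach())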
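And a toy call of the low-level writer added in write_abacus_csr_file.py, to see the ABACUS text CSR layout it emits (STEP header, matrix dimension/number lines, then data/indices/indptr rows per R vector). It assumes the patch is installed; the 2x2 block and the 'S' label are arbitrary example inputs:

import numpy as np
from scipy.sparse import csr_matrix
from dptb.postprocess.write_abacus_csr_file import write_abacus_csr_format

# one R vector "0_0_0" mapped to a small sparse block
blocks = {"0_0_0": csr_matrix(np.array([[1.0, 0.0], [0.0, 2.0]], dtype=np.float32))}
write_abacus_csr_format(blocks, matrix_symbol="S", output_path="toy-SR-sparse.csr", step=0)
print(open("toy-SR-sparse.csr").read())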