From e339d518af5c2315e8c542f73c022328d45d8ac8 Mon Sep 17 00:00:00 2001
From: floatingCatty
Date: Mon, 21 Jul 2025 00:14:48 +0800
Subject: [PATCH 1/2] fix small typo in doc analysis update

---
 docs/advanced/e3tb/loss_analysis.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/advanced/e3tb/loss_analysis.md b/docs/advanced/e3tb/loss_analysis.md
index 0a3bd592..14b3e9f4 100644
--- a/docs/advanced/e3tb/loss_analysis.md
+++ b/docs/advanced/e3tb/loss_analysis.md
@@ -39,7 +39,7 @@ loader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=0)
 for data in tqdm(loader, desc="doing error analysis"):
     with torch.no_grad():
         data = data.to("cuda")
-        batch_info = data.get_batch_info()
+        batch_info = data.get_batchinfo()
         ref_data = AtomicData.to_AtomicDataDict(data)
         data = model(ref_data)
         data.update(batch_info)

From 098c88a03cfb89cf362e78008b43cde685596d03 Mon Sep 17 00:00:00 2001
From: floatingCatty
Date: Wed, 7 Jan 2026 12:04:58 -0500
Subject: [PATCH 2/2] add distributed AtomicData support and graph
 partitioning via a Hilbert space-filling curve or METIS

---
 dptb/data/AtomicData.py         | 252 +++++++++++++++++++++++++++++++-
 dptb/data/_keys.py              |  14 ++
 dptb/postprocess/write_block.py |   5 +-
 dptb/tests/test_dist_data.py    |  38 +++++
 4 files changed, 306 insertions(+), 3 deletions(-)
 create mode 100644 dptb/tests/test_dist_data.py

diff --git a/dptb/data/AtomicData.py b/dptb/data/AtomicData.py
index 46b5dd37..7fd1852b 100644
--- a/dptb/data/AtomicData.py
+++ b/dptb/data/AtomicData.py
@@ -54,6 +56,7 @@
     AtomicDataDict.NODE_ATTRS_KEY,
     AtomicDataDict.ATOMIC_NUMBERS_KEY,
     AtomicDataDict.ATOM_TYPE_KEY,
+    AtomicDataDict.ATOM_INDEX_KEY,
     AtomicDataDict.FORCE_KEY,
     AtomicDataDict.PER_ATOM_ENERGY_KEY,
     AtomicDataDict.NODE_HAMILTONIAN_KEY,
@@ -110,7 +113,12 @@
     AtomicDataDict.EIGENVECTOR_KEY, # new # should be nested
     AtomicDataDict.ENERGY_WINDOWS_KEY, # new,
     AtomicDataDict.BAND_WINDOW_KEY, # new,
-    AtomicDataDict.NODE_SOC_SWITCH_KEY # new
+    AtomicDataDict.NODE_SOC_SWITCH_KEY, # new
+    AtomicDataDict.CLUSTER_CONNECTIVITY,
+    AtomicDataDict.CLUSTER_GHOST_LIST,
+    AtomicDataDict.CLUSTER_NODE_RANGE,
+    AtomicDataDict.CLUSTER_EDGE_RANGE,
+    AtomicDataDict.PARTITION
 }
 
 _NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS)
@@ -811,6 +819,248 @@ def __cat_dim__(self, key, value):
             return None
         else:
             return 0 # cat along node/edge dimension
+
+    def partition_graph_hilbert(self, npart, comm=None, split=False):
+        """
+        Partition the graph by sorting atoms along a Z-order (Morton) space-filling
+        curve, a cheap approximation of a Hilbert curve. Unlike ``partition_graph``,
+        this never gathers the full adjacency structure on rank 0.
+        """
+        from mpi4py import MPI
+        if comm is None:
+            comm = MPI.COMM_WORLD
+
+        # 1. Get positions. AtomicData currently holds the full ``pos`` on every
+        #    rank, so this step is cheap; a fully distributed version would shard it.
+        pos = self.pos
+
+        # 2. Discretize positions to an integer grid: normalize to [0, 1], then
+        #    scale to n_bits-bit integer coordinates.
+        pos_min = pos.min(dim=0).values
+        pos_max = pos.max(dim=0).values
+        pos_norm = (pos - pos_min) / (pos_max - pos_min + 1e-9)
+
+        n_bits = 21
+        max_val = (1 << n_bits) - 1
+        coords = (pos_norm * max_val).type(torch.long)
+
+        # 3. Compute the Z-curve index (Morton code) by interleaving the bits of
+        #    (x, y, z): ... z1 y1 x1 z0 y0 x0
+        x = coords[:, 0]
+        y = coords[:, 1]
+        z = coords[:, 2]
+
+        code = torch.zeros(x.shape, dtype=torch.long)
+        for i in range(n_bits):
+            code |= ((x >> i) & 1) << (3 * i)
+            code |= ((y >> i) & 1) << (3 * i + 1)
+            code |= ((z >> i) & 1) << (3 * i + 2)
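+
+        # For example, with n_bits=2: x=0b10, y=0b01, z=0b11 interleave to
+        # code = 0b(z1 y1 x1 z0 y0 x0) = 0b101110 = 46, so atoms in nearby grid
+        # cells get nearby codes; this is what gives the chunks below their
+        # spatial locality.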
+
+        # 4. Sort atoms by Z-curve index.
+        sort_index = torch.argsort(code).type(torch.int32)
+
+        # 5. Split the sorted order into npart near-equal contiguous chunks.
+        partition = torch.zeros(self.num_nodes, dtype=torch.int32)
+        chunk_size = self.num_nodes // npart
+        for i in range(npart):
+            start = i * chunk_size
+            end = (i + 1) * chunk_size if i < npart - 1 else self.num_nodes
+            partition[sort_index[start:end]] = i
+
+        # We do not map the partition array back to the original atom order;
+        # instead the atoms themselves are re-ordered (via sort_index) for
+        # locality. sort_index already sorts by partition ID, since partition is
+        # monotonic along the curve, so after this gather partition[k] is the
+        # partition of the k-th atom in the new order. No broadcast is needed:
+        # every rank computes the same ordering from the same positions.
+        partition = partition[sort_index]
+
+        self._finalize_partition(sort_index, partition, npart, comm, split)
+
+        return True
+
+    def partition_graph(self, npart, node_weights=None, comm=None, split=False):
+        """
+        Partition the graph with METIS on rank 0 and broadcast the result to all ranks.
+        """
+        import pymetis
+        from mpi4py import MPI
+
+        if comm is None:
+            comm = MPI.COMM_WORLD
+
+        rank = comm.Get_rank()
+
+        if rank == 0:
+            idx_dtype = self.edge_index.dtype
+
+            # Deduplicate edges: group identical (row, col) columns together
+            # (a torch replacement for np.lexsort on the edge list), then keep
+            # the first column of each group. Only the sparsity pattern matters
+            # to METIS, so no per-edge data needs to be carried along.
+            order = torch.argsort(torch.unique(self.edge_index, dim=-1, sorted=True, return_inverse=True)[1])
+            coords = tuple(idx[order] for idx in self.edge_index)
+            diffs = torch.stack([(idx[1:] != idx[:-1]) for idx in coords], dim=0)
+            unique_mask_tail = torch.any(diffs, dim=0)  # (nnz-1,) bool
+            unique_mask = torch.cat(
+                [torch.tensor([True], device=unique_mask_tail.device), unique_mask_tail],
+                dim=0
+            )
+            coords = tuple(idx[unique_mask] for idx in coords)
+
+            major, minor = coords
+
+            # Build the CSR structure (xadj/adjncy in METIS terms) of the
+            # deduplicated graph.
+            M = self.num_nodes
+            counts = torch.bincount(major.to(torch.int64), minlength=M)
+            indptr = torch.empty(M + 1, dtype=idx_dtype, device=major.device)
+            indptr[0] = 0
+            indptr[1:] = torch.cumsum(counts.to(indptr.dtype), dim=0)
+            indices = minor.to(dtype=idx_dtype).contiguous()
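+            # e.g. edges 0->1, 0->2, 1->2, 2->0 on 3 nodes give
+            # indptr = [0, 2, 3, 4] and indices = [1, 2, 2, 0] -- exactly the
+            # xadj/adjncy pair pymetis expects below.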
+
+            if node_weights is not None:
+                assert node_weights.numel() == M and node_weights.ndim == 1
+                assert int(node_weights.min().item()) > 0
+                assert node_weights.dtype == torch.int32
+                # Heuristic edge weights: the product of the endpoint node weights.
+                eweights = (node_weights[major] * node_weights[minor]).contiguous()
+            else:
+                eweights = None
+
+            # METIS tuning: more cut attempts and refinement iterations, minimize
+            # inter-cluster connectivity, allow ~10% load imbalance (ufactor=100).
+            opt = pymetis.Options()
+            opt.ncuts = 40
+            opt.niter = 50
+            opt.minconn = 1
+            opt.ufactor = 100
+            # opt.contig = 1
+
+            indptr_np = indptr.detach().cpu().numpy()
+            indices_np = indices.detach().cpu().numpy()
+            vweights_np = None if node_weights is None else node_weights.detach().cpu().numpy()
+            eweights_np = None if eweights is None else eweights.detach().cpu().numpy()
+
+            _, partition = pymetis.part_graph(npart, xadj=indptr_np, adjncy=indices_np, vweights=vweights_np, eweights=eweights_np, options=opt)
+
+            # Sort atoms by partition ID; sort_index is also the new atom_index.
+            partition = np.asarray(partition, dtype=np.int32)
+            sort_index = np.argsort(partition).astype(np.int32).ravel()
+            partition = partition[sort_index].ravel()
+        else:
+            sort_index = np.zeros(self.num_nodes, dtype=np.int32)
+            partition = np.zeros(self.num_nodes, dtype=np.int32)
+
+        comm.Bcast(sort_index, root=0)
+        comm.Bcast(partition, root=0)
+
+        sort_index = torch.from_numpy(sort_index)
+        partition = torch.from_numpy(partition)
+
+        self._finalize_partition(sort_index, partition, npart, comm, split)
+
+        return True
+
+    def _finalize_partition(self, sort_index, partition, npart, comm=None, split=False):
+        if comm is None:
+            from mpi4py import MPI
+            comm = MPI.COMM_WORLD
+        rank = comm.Get_rank()
+
+        # Re-order all per-node attributes so atoms of a cluster are contiguous.
+        for field in _NODE_FIELDS:
+            if hasattr(self, field):
+                setattr(self, field, getattr(self, field)[sort_index])
+        setattr(self, AtomicDataDict.ATOM_INDEX_KEY, sort_index)
+        setattr(self, AtomicDataDict.PARTITION, partition)
+
+        if hasattr(self, AtomicDataDict.EDGE_INDEX_KEY):
+            # Map old atom indices to their new (sorted) positions, relabel the
+            # edges, then sort edges by source atom so clusters stay contiguous.
+            atom_index_map = torch.zeros(len(sort_index), dtype=torch.long, device=self.edge_index.device)
+            atom_index_map[sort_index] = torch.arange(len(sort_index), dtype=torch.long, device=self.edge_index.device)
+
+            self.edge_index = atom_index_map[self.edge_index]
+            edge_sort_index = torch.argsort(self.edge_index[0])
+            self.edge_index = self.edge_index[:, edge_sort_index]
+
+            for field in _EDGE_FIELDS:
+                if hasattr(self, field):
+                    setattr(self, field, getattr(self, field)[edge_sort_index])
+
+        # Record the [start, end) node and edge index ranges of each cluster.
+        n_range = torch.zeros((npart, 2), dtype=torch.int32, device=self.edge_index.device)
+        e_range = torch.zeros((npart, 2), dtype=torch.int32, device=self.edge_index.device)
+
+        # Connectivity table: which clusters each cluster touches, and which
+        # out-of-cluster (ghost) atoms its edges reference.
+        connectivity = [[] for _ in range(npart)]
+        ghost_list = [[] for _ in range(npart)]
+        edge_part = partition[self.edge_index[0]]
+        nstart = 0
+        estart = 0
+
+        for i in range(npart):
+            natom = int((partition == i).sum())
+            n_range[i][0] = nstart
+            n_range[i][1] = nstart + natom
+            nstart += natom
+
+            if hasattr(self, AtomicDataDict.EDGE_INDEX_KEY):
+                cst_nbr = self.edge_index[1][edge_part == i]
+                cst_slf = self.edge_index[0][edge_part == i]
+                cst_slf = torch.unique(cst_slf)
+                e_range[i][0] = estart
+                e_range[i][1] = estart + len(cst_nbr)
+                estart += len(cst_nbr)
+                cst_nbr = torch.unique(cst_nbr)
+                # Remove the cluster's own atoms from its ghost and connectivity sets.
+                ghost_list[i] = list(set(cst_nbr.detach().numpy().tolist()) - set(cst_slf.detach().numpy().tolist()))
+                connectivity[i] = list(set(torch.unique(partition[cst_nbr]).detach().numpy().tolist()) - {i})
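+                # e.g. with npart=2: if cluster 0 owns atoms {0, 1} and its only
+                # edge into another cluster is 1->2, then ghost_list[0] == [2]
+                # and connectivity[0] == [1].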
+
+        setattr(self, AtomicDataDict.CLUSTER_NODE_RANGE, n_range)
+        if hasattr(self, AtomicDataDict.EDGE_INDEX_KEY):
+            setattr(self, AtomicDataDict.CLUSTER_EDGE_RANGE, e_range)
+            setattr(self, AtomicDataDict.CLUSTER_CONNECTIVITY, connectivity)
+            setattr(self, AtomicDataDict.CLUSTER_GHOST_LIST, ghost_list)
+
+        if split:
+            assert comm.Get_size() == npart, "split=True requires npart to equal the MPI world size"
+            self.split(rank)
+
+    def split(self, i):
+        """
+        Restrict this graph in place to the i-th cluster: its nodes, and the
+        edges whose source atom lies in that cluster.
+        """
+        assert hasattr(self, AtomicDataDict.PARTITION), "The graph needs to be partitioned before calling split"
+
+        nstart, nend = getattr(self, AtomicDataDict.CLUSTER_NODE_RANGE)[i]
+        estart, eend = getattr(self, AtomicDataDict.CLUSTER_EDGE_RANGE)[i]
+
+        for field in _NODE_FIELDS:
+            if hasattr(self, field):
+                setattr(self, field, getattr(self, field)[nstart:nend])
+        self.__num_nodes__ = int(nend - nstart)
+
+        for field in _EDGE_FIELDS:
+            if hasattr(self, field):
+                setattr(self, field, getattr(self, field)[estart:eend])
+
+        setattr(self, AtomicDataDict.EDGE_INDEX_KEY, getattr(self, AtomicDataDict.EDGE_INDEX_KEY)[:, estart:eend])
+
+        return True
 
     def without_nodes(self, which_nodes):
         """Return a copy of ``self`` with ``which_nodes`` removed.
diff --git a/dptb/data/_keys.py b/dptb/data/_keys.py
index 2710798c..32379c9f 100644
--- a/dptb/data/_keys.py
+++ b/dptb/data/_keys.py
@@ -31,6 +31,19 @@
 # [n_kpoints, 3] or [n_batch, nkpoints, 3] tensor
 KPOINT_KEY = "kpoint"
 
+
+# keys for distributed AtomicData
+
+PARTITION = "partition"
+
+CLUSTER_CONNECTIVITY = "cluster_connectivity"
+
+CLUSTER_NODE_RANGE = "cluster_node_range"
+
+CLUSTER_EDGE_RANGE = "cluster_edge_range"
+
+CLUSTER_GHOST_LIST = "cluster_ghost_list"
+
 HAMILTONIAN_KEY = "hamiltonian"
 OVERLAP_KEY = "overlap"
@@ -40,6 +53,7 @@
 ATOMIC_NUMBERS_KEY: Final[str] = "atomic_numbers"
 # [n_atom, 1] long tensor
 ATOM_TYPE_KEY: Final[str] = "atom_types"
+ATOM_INDEX_KEY: Final[str] = "atom_index"
 # [n_batch, n_kpoint, n_orb]
 ENERGY_EIGENVALUE_KEY: Final[str] = "eigenvalue"
 EIGENVECTOR_KEY: Final[str] = "eigenvector"
diff --git a/dptb/postprocess/write_block.py b/dptb/postprocess/write_block.py
index 812b26b1..1f4f0366 100644
--- a/dptb/postprocess/write_block.py
+++ b/dptb/postprocess/write_block.py
@@ -17,7 +17,8 @@
 def write_block(
         data: Union[AtomicData, ase.Atoms, str],
         model: torch.nn.Module,
         AtomicData_options: dict={},
-        device: Union[str, torch.device]=None
+        device: Union[str, torch.device]=None,
+        overlap: bool=False,
         ):
     model.eval()
@@ -40,7 +41,7 @@
     # set the kpoint of the AtomicData
     data = model(data)
 
-    block = feature_to_block(data=data, idp=model.idp)
+    block = feature_to_block(data=data, idp=model.idp, overlap=overlap)
 
     return block
diff --git a/dptb/tests/test_dist_data.py b/dptb/tests/test_dist_data.py
new file mode 100644
index 00000000..bbe14b23
--- /dev/null
+++ b/dptb/tests/test_dist_data.py
@@ -0,0 +1,38 @@
+import os
+from pathlib import Path
+
+import pytest
+import ase.io as io
+
+from dptb.data import AtomicData, _keys
+
+try:
+    import mpi4py  # noqa: F401
+    _MPI = True
+except ImportError:
+    _MPI = False
+
+try:
+    import pymetis  # noqa: F401  (required by partition_graph)
+    _METIS = True
+except ImportError:
+    _METIS = False
+
+rootdir = os.path.join(Path(os.path.abspath(__file__)).parent, "data")
+
+
+@pytest.mark.skipif(not _MPI, reason="mpi4py is not available")
+def test_hilbert_part():
+    data = AtomicData.from_ase(io.read(os.path.join(rootdir, "hBN", "hBN.vasp")), r_max=4.0)
+    data.partition_graph_hilbert(2, split=False)
+    print(data[_keys.CLUSTER_CONNECTIVITY])
+    print(data[_keys.CLUSTER_NODE_RANGE])
+    print(data[_keys.CLUSTER_GHOST_LIST])
+
+
+@pytest.mark.skipif(not (_MPI and _METIS), reason="mpi4py or pymetis is not available")
+def test_metis_part():
+    data = AtomicData.from_ase(io.read(os.path.join(rootdir, "hBN", "hBN.vasp")), r_max=4.0)
+    data.partition_graph(2, split=False)
+    print(data[_keys.CLUSTER_CONNECTIVITY])
+    print(data[_keys.CLUSTER_NODE_RANGE])
+    print(data[_keys.CLUSTER_GHOST_LIST])
+
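+# NOTE: both tests only exercise the partition bookkeeping (split=False).
+# The split=True path requires comm.Get_size() == npart, e.g.
+# `mpirun -n 2 pytest -s dptb/tests/test_dist_data.py`.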
+# if __name__ == "__main__":
+#     test_hilbert_part()
+#     test_metis_part()
\ No newline at end of file
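
For reference, a minimal sketch of how the new API is meant to be driven end to end (not part of the patch): every rank builds the same full graph, partitions it into one cluster per rank, and with split=True keeps only its own cluster. The `hBN.vasp` path and the 2-rank launch (`mpirun -n 2 python driver.py`) are illustrative assumptions.

    import ase.io as io
    from mpi4py import MPI
    from dptb.data import AtomicData, _keys

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # Build the full graph on every rank, then partition into one cluster per
    # rank and keep only ours (split=True internally calls self.split(rank)).
    data = AtomicData.from_ase(io.read("hBN.vasp"), r_max=4.0)
    data.partition_graph_hilbert(comm.Get_size(), comm=comm, split=True)

    # Cluster bookkeeping survives the split: local node range, the peer
    # clusters we touch, and the ghost atoms we must fetch from them.
    print(rank, data[_keys.CLUSTER_NODE_RANGE][rank])
    print(rank, data[_keys.CLUSTER_CONNECTIVITY][rank])
    print(rank, data[_keys.CLUSTER_GHOST_LIST][rank])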