diff --git a/dptb/entrypoints/train.py b/dptb/entrypoints/train.py index b70ac7fc..03d8dc50 100644 --- a/dptb/entrypoints/train.py +++ b/dptb/entrypoints/train.py @@ -4,6 +4,7 @@ from dptb.plugins.monitor import TrainLossMonitor, LearningRateMonitor, Validationer, TensorBoardMonitor from dptb.plugins.train_logger import Logger from dptb.utils.argcheck import normalize, collect_cutoffs, chk_avg_per_iter +from dptb.utils.orbital_parser import parse_orbital_file from dptb.plugins.saver import Saver from typing import Dict, List, Optional, Any from dptb.utils.tools import j_loader, setup_seed, j_must_have @@ -90,6 +91,35 @@ def train( jdata = j_loader(INPUT) jdata = normalize(jdata) + + # Validate and process orbital files in basis + if jdata.get("common_options") and jdata["common_options"].get("basis"): + orbital_files_content = {} + for elem, value in jdata["common_options"]["basis"].items(): + if isinstance(value, str) and os.path.isfile(value): + # strict check for e3tb method + # Check if model_options exists and has prediction method + model_opts = jdata.get("model_options", {}) + pred_opts = model_opts.get("prediction", {}) + # normalize might handle defaults, but safely check here + if pred_opts.get("method", "e3tb") != "e3tb": + raise ValueError(f"Orbital files in 'basis' are only supported for the 'e3tb' method. Found method: {pred_opts.get('method')}") + + # Also checking if we are in a mixed model which might be different, + # but usually orbital files for basis imply the main basis handling. 
+ # transform logic + try: + parsed_basis = parse_orbital_file(value) + with open(value, 'r') as f: + orbital_files_content[elem] = f.read() + + jdata["common_options"]["basis"][elem] = parsed_basis + log.info(f"Parsed orbital file for {elem}: {value} -> {parsed_basis}") + except Exception as e: + raise ValueError(f"Failed to parse orbital file {value} for element {elem}: {e}") + + if orbital_files_content: + jdata["common_options"]["orbital_files_content"] = orbital_files_content # update basis if init_model or restart # update jdata # this is not necessary, because if we init model from checkpoint, the build_model will load the model_options from checkpoints if not provided diff --git a/dptb/nn/deeptb.py b/dptb/nn/deeptb.py index 13fbf984..ba8caa61 100644 --- a/dptb/nn/deeptb.py +++ b/dptb/nn/deeptb.py @@ -290,6 +290,8 @@ def __init__( device=self.device, ) + if kwargs.get("orbital_files_content"): + self.orbital_files_content = kwargs["orbital_files_content"] def forward(self, data: AtomicDataDict.Type): if data.get(AtomicDataDict.EDGE_TYPE_KEY, None) is None: diff --git a/dptb/nn/energy.py b/dptb/nn/energy.py index 6103bad0..c1c2797e 100644 --- a/dptb/nn/energy.py +++ b/dptb/nn/energy.py @@ -11,6 +11,16 @@ from dptb.data.transforms import OrbitalMapper from dptb.data import AtomicDataDict import logging +try: + from dptb.utils.pardiso_wrapper import PyPardisoSolver + from dptb.utils.feast_wrapper import FeastSolver + from scipy.sparse.linalg import eigsh, LinearOperator +except ImportError: + PyPardisoSolver = None + FeastSolver = None + eigsh = None + LinearOperator = None + log = logging.getLogger(__name__) class Eigenvalues(nn.Module): @@ -113,6 +123,178 @@ def forward(self, data[AtomicDataDict.KPOINT_KEY] = kpoints return data + +class PardisoEig: + def __init__(self, sigma: float = 0.0, neig: int = 10, mode: str = 'normal'): + """ + Solver using Pardiso for shift-invert eigenvalue problems. + + Args: + sigma: Shift value (target energy). 
+ neig: Number of eigenvalues to solve for. + mode: Eigsh mode ('normal', 'buckling', 'cayley'). + """ + if PyPardisoSolver is None or eigsh is None: + raise ImportError("PardisoEig requires MKL (pypardiso) and scipy.sparse.linalg") + + self.sigma = sigma + self.neig = neig + self.mode = mode + + + def solve(self, h_container, s_container, kpoints: Union[list, torch.Tensor, np.ndarray], return_eigenvectors: bool = False): + """ + Solve eigenvalues for given k-points. + + Args: + h_container: vbcsr.ImageContainer for Hamiltonian. + s_container: vbcsr.ImageContainer for Overlap (can be None). + kpoints: Array of k-points (Nk, 3). + return_eigenvectors: If True, return (eigenvalues, eigenvectors). Default False. + + Returns: + list of eigenvalues arrays (and eigenvectors arrays if return_eigenvectors=True). + """ + + # Ensure kpoints is numpy array + if isinstance(kpoints, torch.Tensor): + kpoints = kpoints.cpu().numpy() + + eigvals_list = [] + eigvecs_list = [] + + for k in kpoints: + hk = h_container.sample_k(k, symm=True) + + if s_container is not None: + sk = s_container.sample_k(k, symm=True) + hk -= self.sigma * sk + A = hk.to_scipy(format="csr") + M = sk + else: + hk.shift(-self.sigma) + A = hk.to_scipy(format="csr") + M = None + + A.sort_indices() + A.sum_duplicates() + N = A.shape[0] + + # Try PARDISO first, fall back to scipy SuperLU if PARDISO fails + # (MKL PARDISO has a known bug with certain block-structured patterns) + solver = PyPardisoSolver(mtype=13) + solver.factorize(A) + + def matvec(b): + return solver.solve(A, b) + + Op = LinearOperator((N, N), matvec=matvec, dtype=A.dtype) + + try: + # Use larger NCV to help convergence, especially for clustered eigenvalues + ncv = max(2*self.neig + 1, 20) + vals, vecs = eigsh(A=hk, M=M, k=self.neig, sigma=0.0, OPinv=Op, mode=self.mode, which="LM", ncv=ncv) + except Exception: + # Retry with larger NCV if ARPACK fails (e.g. 
error 3: No shifts could be applied) + # This often happens when eigenvalues are clustered near the shift + ncv = max(5*self.neig, 50) + vals, vecs = eigsh(A=hk, M=M, k=self.neig, sigma=0.0, OPinv=Op, mode=self.mode, which="LM", ncv=ncv) + + eigvals_list.append(vals + self.sigma) + if return_eigenvectors: + eigvecs_list.append(vecs) + + if return_eigenvectors: + return eigvals_list, eigvecs_list + else: + return eigvals_list + +class FEASTEig: + def __init__(self, emin: float = -1.0, emax: float = 1.0, m0: Optional[int] = None, + max_refinement: int = 3, uplo: str = 'U', extract_triangular: bool = True): + """ + Solver using FEAST algorithm for finding eigenvalues in a given interval. + + Args: + emin, emax: Energy interval [emin, emax]. + m0: Initial subspace size estimate. + max_refinement: Number of refinements if subspace is too small. + uplo: 'U' (Upper) or 'L' (Lower) triangular part to use. + extract_triangular: Whether to extract triangular part automatically. + """ + + if FeastSolver is None: + raise ImportError("FEAST solver not available") + + self.emin = emin + self.emax = emax + self.m0 = m0 + self.max_refinement = max_refinement + self.uplo = uplo + self.extract_triangular = extract_triangular + + # Initialize solver to check availability + try: + self.solver = FeastSolver() + except ImportError as e: + raise ImportError(f"FEAST solver not available: {e}") from e + except Exception as e: + raise RuntimeError(f"Failed to initialize FeastSolver: {e}") from e + + def solve(self, h_container, s_container, kpoints: Union[list, torch.Tensor, np.ndarray], return_eigenvectors: bool = False): + """ + Solve eigenvalues for given k-points using FEAST. + + Args: + h_container: Container for Hamiltonian (must support sample_k().to_scipy()). + s_container: Container for Overlap (can be None). + kpoints: Array of k-points. + return_eigenvectors: If True, return (eigenvalues, eigenvectors). Default False. 
+ + Returns: + list of eigenvalues arrays (and eigenvectors arrays if return_eigenvectors=True). + """ + if isinstance(kpoints, torch.Tensor): + kpoints = kpoints.cpu().numpy() + + eigvals_list = [] + eigvecs_list = [] + + for k in kpoints: + # Get Hamiltonian and Overlap at k + # Assuming h_container.sample_k returns object with .to_scipy() + hk_obj = h_container.sample_k(k, symm=True) + if hasattr(hk_obj, 'to_scipy'): + hk = hk_obj.to_scipy(format="csr") + else: + # Fallback if it checks sparse type + hk = hk_obj + + if s_container is not None: + sk_obj = s_container.sample_k(k, symm=True) + if hasattr(sk_obj, 'to_scipy'): + sk = sk_obj.to_scipy(format="csr") + else: + sk = sk_obj + else: + sk = None + + # Solve + evals, vecs = self.solver.solve( + hk, M=sk, emin=self.emin, emax=self.emax, + m0=self.m0, max_refinement=self.max_refinement, + uplo=self.uplo, extract_triangular=self.extract_triangular + ) + + eigvals_list.append(evals) + if return_eigenvectors: + eigvecs_list.append(vecs) + + if return_eigenvectors: + return eigvals_list, eigvecs_list + else: + return eigvals_list + class Eigh(nn.Module): def __init__( diff --git a/dptb/nn/hr2hR.py b/dptb/nn/hr2hR.py new file mode 100644 index 00000000..d467f3e2 --- /dev/null +++ b/dptb/nn/hr2hR.py @@ -0,0 +1,167 @@ +import torch +from dptb.utils.constants import h_all_types, anglrMId, atomic_num_dict, atomic_num_dict_r +from typing import Tuple, Union, Dict +from dptb.data.transforms import OrbitalMapper +from dptb.data import AtomicDataDict +import re +from dptb.utils.tools import float2comlex, tdtype2ndtype +import numpy as np +try: + from vbcsr import ImageContainer + from vbcsr import AtomicData as AtomicData_vbcsr +except ImportError: + ImageContainer = None + AtomicData_vbcsr = None + +class Hr2HR: + def __init__( + self, + basis: Dict[str, Union[str, list]]=None, + idp: Union[OrbitalMapper, None]=None, + edge_field: str = AtomicDataDict.EDGE_FEATURES_KEY, + node_field: str = AtomicDataDict.NODE_FEATURES_KEY, 
+ overlap: bool = False, + dtype: Union[str, torch.dtype] = torch.float32, + device: Union[str, torch.device] = torch.device("cpu") + ): + + if ImageContainer is None or AtomicData_vbcsr is None: + raise ImportError("vbcsr") + + if isinstance(dtype, str): + dtype = getattr(torch, dtype) + self.dtype = dtype + self.device = device + self.overlap = overlap + self.ctype = float2comlex(dtype) + + if basis is not None: + self.idp = OrbitalMapper(basis, method="e3tb", device=self.device) + if idp is not None: + assert idp == self.idp, "The basis of idp and basis should be the same." + else: + assert idp is not None, "Either basis or idp should be provided." + assert idp.method == "e3tb", "The method of idp should be e3tb." + self.idp = idp + + self.basis = self.idp.basis + self.idp.get_orbpair_maps() + self.idp.get_orbpair_soc_maps() + + self.edge_field = edge_field + self.node_field = node_field + + def __call__(self, data: AtomicDataDict.Type): + + # construct bond wise hamiltonian block from obital pair wise node/edge features + # we assume the edge feature have the similar format as the node feature, which is reduced from orbitals index oj-oi with j>i + if AtomicDataDict.ATOM_TYPE_KEY not in data: + self.idp(data) + + orbpair_hopping = data[self.edge_field] + orbpair_onsite = data.get(self.node_field) + natom = orbpair_onsite.shape[0] + nedge = orbpair_hopping.shape[0] + + bondwise_hopping = torch.zeros((nedge, self.idp.full_basis_norb, self.idp.full_basis_norb), dtype=self.dtype, device=self.device) + onsite_block = torch.zeros((natom, self.idp.full_basis_norb, self.idp.full_basis_norb,), dtype=self.dtype, device=self.device) + + atom_type = data[AtomicDataDict.ATOM_TYPE_KEY].flatten() + edge_shift_vec = data[AtomicDataDict.EDGE_CELL_SHIFT_KEY] + + soc = data.get(AtomicDataDict.NODE_SOC_SWITCH_KEY, False) + ndtype = np.float64 + if isinstance(soc, torch.Tensor): + soc = soc.all() + + if soc: + # this soc only support sktb. 
+ orbpair_soc = data[AtomicDataDict.NODE_SOC_KEY] + soc_upup_block = torch.zeros((natom, self.idp.full_basis_norb, self.idp.full_basis_norb), dtype=self.ctype, device=self.device) + soc_updn_block = torch.zeros((natom, self.idp.full_basis_norb, self.idp.full_basis_norb), dtype=self.ctype, device=self.device) + ndtype = np.complex128 + + spin_factor = 2 if soc else 1 + + with torch.no_grad(): + ist = 0 + for i,iorb in enumerate(self.idp.full_basis): + jst = 0 + li = anglrMId[re.findall(r"[a-zA-Z]+", iorb)[0]] + for j,jorb in enumerate(self.idp.full_basis): + orbpair = iorb + "-" + jorb + lj = anglrMId[re.findall(r"[a-zA-Z]+", jorb)[0]] + + # constructing hopping blocks + if iorb == jorb: + factor = 1.0 + else: + factor = 2.0 + + if i <= j: + bondwise_hopping[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_hopping[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1) + onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1) + + if soc and i==j and not self.overlap: + # For now, The SOC part is only added to Hamiltonian, not overlap matrix. + # For now, The SOC only has onsite contribution. + soc_updn_tmp = orbpair_soc[:,self.idp.orbpair_soc_maps[orbpair]].reshape(-1, 2*li+1, 2*(2*lj+1)) + soc_upup_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1,:2*lj+1] + soc_updn_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1,2*lj+1:] + + jst += 2*lj+1 + ist += 2*li+1 + onsite_block = onsite_block.to("cpu") + bondwise_hopping = bondwise_hopping.to("cpu") + if soc and not self.overlap: + # store for later use + # for now, soc only contribute to Hamiltonain, thus for overlap not store soc parts. 
+ soc_upup_block = soc_upup_block.to("cpu") + soc_updn_block = soc_updn_block.to("cpu") + + adata = AtomicData_vbcsr.from_distributed( + natom, natom, 0, nedge, nedge, + list(range(natom)), data[AtomicDataDict.ATOM_TYPE_KEY], data[AtomicDataDict.EDGE_INDEX_KEY].T, self.idp.atom_norb, data[AtomicDataDict.EDGE_CELL_SHIFT_KEY], + data[AtomicDataDict.CELL_KEY], data[AtomicDataDict.POSITIONS_KEY] + ) + image_container = ImageContainer(adata, ndtype) + + for i, oblock in enumerate(onsite_block): + mask = self.idp.mask_to_basis[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[i]] + masked_oblock = oblock[mask][:,mask] + nrol, ncol = masked_oblock.shape + full_block = np.zeros([nrol*spin_factor, ncol*spin_factor], dtype=ndtype) + if soc: + full_block[:nrol,:ncol] = masked_oblock + full_block[nrol:,ncol:] = masked_oblock + if not self.overlap: + full_block[:nrol,ncol:] = soc_updn_block[i,mask][:,mask] + full_block[nrol:,:ncol] = soc_updn_block[i,mask][:,mask].conj() + full_block[:nrol,:ncol] += soc_upup_block[i,mask][:,mask] + full_block[nrol:,ncol:] += soc_upup_block[i,mask][:,mask].conj() + else: + full_block[:nrol,:ncol] = masked_oblock + full_block = np.ascontiguousarray(full_block) + + image_container.add_block(g_row=i, g_col=i, data=full_block, R=None, mode="insert") + + for i, hblock in enumerate(bondwise_hopping): + iatom = data[AtomicDataDict.EDGE_INDEX_KEY][0][i].item() + jatom = data[AtomicDataDict.EDGE_INDEX_KEY][1][i].item() + imask = self.idp.mask_to_basis[atom_type[iatom]] + jmask = self.idp.mask_to_basis[atom_type[jatom]] + masked_hblock = hblock[imask][:,jmask] + nrol, ncol = masked_hblock.shape + full_block = np.zeros([nrol*spin_factor, ncol*spin_factor], dtype=ndtype) + if soc: + full_block[:nrol,:ncol] = masked_hblock + full_block[nrol:,ncol:] = masked_hblock + else: + full_block[:nrol,:ncol] = masked_hblock + full_block = np.ascontiguousarray(full_block) + + image_container.add_block(g_row=iatom, g_col=jatom, data=full_block, R=edge_shift_vec[i], 
mode="insert") + + image_container.assemble() + + return image_container \ No newline at end of file diff --git a/dptb/postprocess/ovp2c.py b/dptb/postprocess/ovp2c.py index 246623be..76d98691 100644 --- a/dptb/postprocess/ovp2c.py +++ b/dptb/postprocess/ovp2c.py @@ -1,12 +1,21 @@ -import op2c import torch from typing import Union - from dptb.data.transforms import OrbitalMapper from dptb.data import AtomicDataDict from dptb.data import _keys from dptb.data.interfaces import block_to_feature from dptb.data.AtomicDataDict import with_edge_vectors +import numpy as np +try: + from vbcsr import ImageContainer + from vbcsr import AtomicData as AtomicData_vbcsr +except ImportError: + print("VBCSR is not installed, therefore, the compute_overlap_image is not supported.") + +try: + import op2c +except ImportError: + print("OP2C is not installed, therefore, the compute_overlap is not supported.") def compute_overlap(data: AtomicDataDict, idp: OrbitalMapper, orb_dir, orb_names): ntype = idp.num_types @@ -56,5 +65,56 @@ def compute_overlap(data: AtomicDataDict, idp: OrbitalMapper, orb_dir, orb_names data[_keys.EDGE_OVERLAP_KEY].to(device) return data + +def compute_overlap_image(data: AtomicDataDict, idp: OrbitalMapper, orb_dir, orb_names): + ntype = idp.num_types + op = op2c.Op2C( + ntype=ntype, + nspin=1, # for current usage + lspinorb=False, + orb_dir=orb_dir, + orb_name=orb_names + ) + + if _keys.EDGE_VECTORS_KEY not in data: + data = with_edge_vectors(data) + if _keys.ATOM_TYPE_KEY not in data: + idp(data) + + edge_index = data[_keys.EDGE_INDEX_KEY].cpu() + atom_types = data[_keys.ATOM_TYPE_KEY].cpu() + edge_vectors = data[_keys.EDGE_VECTORS_KEY].cpu() + cell_shifts = data[_keys.EDGE_CELL_SHIFT_KEY].cpu() + + natom = len(atom_types) + nedge = len(edge_index[0]) + adata = AtomicData_vbcsr.from_distributed( + natom, natom, 0, nedge, nedge, + list(range(natom)), atom_types, edge_index.T, idp.atom_norb, cell_shifts, + data[_keys.CELL_KEY].cpu(), 
data[_keys.POSITIONS_KEY].cpu() + ) + image_container = ImageContainer(adata, np.float64) + + for k in range(nedge): + i, j = edge_index[:, k].tolist() + itype, jtype = atom_types[i].item(), atom_types[j].item() + + Rij = edge_vectors[k] * 1.8897259886 # angstrom to bohr + Rij = Rij.tolist() + Rvec = cell_shifts[k].int().tolist() + + inorb = idp.atom_norb[itype] + jnorb = idp.atom_norb[jtype] + + S = op.overlap(itype, jtype, Rij, is_transpose=False) + image_container.add_block(i, j, S.reshape(inorb, jnorb), Rvec) - \ No newline at end of file + for k in range(natom): + itype = atom_types[k].item() + S = op.overlap(itype, itype, [0, 0, 0], is_transpose=False) + inorb = idp.atom_norb[itype] + image_container.add_block(k, k, S.reshape(inorb, inorb), [0, 0, 0]) + + image_container.assemble() + + return image_container diff --git a/dptb/postprocess/unified/calculator.py b/dptb/postprocess/unified/calculator.py index 39c85ec0..1da1f492 100644 --- a/dptb/postprocess/unified/calculator.py +++ b/dptb/postprocess/unified/calculator.py @@ -7,6 +7,11 @@ from dptb.nn.energy import Eigenvalues, Eigh from dptb.data.interfaces.ham_to_feature import feature_to_block from dptb.nn.hr2hk import HR2HK +from dptb.nn.hr2hR import Hr2HR +import logging +import os + +log = logging.getLogger(__name__) class HamiltonianCalculator(ABC): """Abstract Base Class defining the interface for a Hamiltonian calculator.""" @@ -39,6 +44,19 @@ def get_hr(self, atomic_data: dict) -> Tuple[Any, Any]: Tuple of (H_blocks, S_blocks). S_blocks can be None. """ pass + + @abstractmethod + def get_hR(self, atomic_data: dict) -> Tuple[Any, Any]: + """ + Get the Hamiltonian (and Overlap) as vbcsr.ImageContainer. + + Args: + atomic_data: The input atomic data. + + Returns: + Tuple of (H_container, S_container). S_container can be None. 
+ """ + pass @abstractmethod def get_eigenvalues(self, atomic_data: dict) -> Tuple[dict, torch.Tensor]: @@ -136,13 +154,35 @@ def model_forward(self, atomic_data: dict) -> dict: # Run model forward pass to get H/S blocks atomic_data = self.model(atomic_data) + + """ + Overlap logic: + overlap sources: a. the model inference b. the user defined overlap files c. the two center integrals + priority: b > c > a + """ # Restore overlap if it was an override # We only need to do this if the model actually has overlap capability (and thus might have overwritten it) if self.overlap and override_edge is not None: atomic_data[AtomicDataDict.EDGE_OVERLAP_KEY] = override_edge if override_node is not None: - atomic_data[AtomicDataDict.NODE_OVERLAP_KEY] = override_node + atomic_data[AtomicDataDict.NODE_OVERLAP_KEY] = override_node + elif self.overlap and hasattr(self.model, "orbital_files_content"): + from dptb.postprocess.ovp2c import compute_overlap + # write the orbital files content to a temporary file + orb_names = [] + for sym in self.model.idp.type_names: + with open(f"./temp_{sym}.orb", "w") as f: + f.write(self.model.orbital_files_content[sym]) + orb_names.append(f"temp_{sym}.orb") + atomic_data = compute_overlap(atomic_data, self.model.idp, "./", orb_names) + # remove the temporary files + for orb_name in orb_names: + os.remove(orb_name) + elif self.overlap: + if hasattr(self.model, "method") and getattr(self.model, "method", None) == "e3tb": + log.warning("The overlap inferenced from model is not stable in singular basis, please ensure this is what you want.") + return atomic_data def get_hr(self, atomic_data): @@ -155,6 +195,33 @@ def get_hr(self, atomic_data): return Hblocks, Sblocks + def get_hR(self, atomic_data): + # Initialize hR converters + h2R = Hr2HR( + idp=self.model.idp, + edge_field=AtomicDataDict.EDGE_FEATURES_KEY, + node_field=AtomicDataDict.NODE_FEATURES_KEY, + overlap=False, + dtype=self.model.dtype, + device=self.device + ) + if self.overlap: + s2R = 
Hr2HR( + idp=self.model.idp, + edge_field=AtomicDataDict.EDGE_OVERLAP_KEY, + node_field=AtomicDataDict.NODE_OVERLAP_KEY, + overlap=True, + dtype=self.model.dtype, + device=self.device + ) + atomic_data = self.model_forward(atomic_data) + h_container = h2R(atomic_data) + if self.overlap: + s_container = s2R(atomic_data) + else: + s_container = None + return h_container, s_container + def get_eigenvalues(self, atomic_data: dict, nk: Optional[int]=None, diff --git a/dptb/postprocess/unified/system.py b/dptb/postprocess/unified/system.py index 55d75a0d..188cf64f 100644 --- a/dptb/postprocess/unified/system.py +++ b/dptb/postprocess/unified/system.py @@ -326,6 +326,15 @@ def get_dos(self, kmesh: Optional[Union[list,np.ndarray]] = None, is_gamma_cente self.has_dos = True return self._dos + def get_hR(self): + """ + Get the Hamiltonian (and Overlap) as vbcsr.ImageContainer. + + Returns: + Tuple of (H_container, S_container). S_container can be None. + """ + return self._calculator.get_hR(self._atomic_data) + def to_pardiso(self, output_dir: Optional[str] = "pardiso_input"): """ diff --git a/dptb/tests/test_energy_feast.py b/dptb/tests/test_energy_feast.py new file mode 100644 index 00000000..954857e1 --- /dev/null +++ b/dptb/tests/test_energy_feast.py @@ -0,0 +1,107 @@ + +import pytest +import numpy as np +import scipy.sparse as sp +from scipy.linalg import eigh +from dptb.nn.energy import FEASTEig +from dptb.utils.feast_wrapper import _MKL_RT + +class MockMat: + def __init__(self, mat): + self.mat = mat + def to_scipy(self, format="csr"): + return self.mat + +class MockContainer: + def __init__(self, mat): + self.obj = MockMat(mat) + def sample_k(self, k, symm=True): + return self.obj + +@pytest.mark.skipif(_MKL_RT is None, reason="MKL runtime not found") +class TestFEASTEig: + def test_solve_standard(self): + N = 50 + np.random.seed(42) + A = np.random.rand(N, N) + A = A + A.T + Asp = sp.csr_matrix(A) + + h_container = MockContainer(Asp) + kpoints = np.array([[0,0,0]]) 
# Dummy kpoint + + evals_ref = eigh(A, eigvals_only=True) + emin, emax = evals_ref[0]-0.1, evals_ref[-1]+0.1 + + solver = FEASTEig(emin=emin, emax=emax, m0=N) + evals_list = solver.solve(h_container, None, kpoints) + + assert len(evals_list) == 1 + evals = evals_list[0] + assert len(evals) == N + np.testing.assert_allclose(np.sort(evals), np.sort(evals_ref), atol=1e-8) + + def test_solve_generalized(self): + N = 30 + np.random.seed(43) + A = np.random.rand(N, N) + A = A + A.T + M = np.random.rand(N, N) + M = M @ M.T + np.eye(N) + + Asp = sp.csr_matrix(A) + Msp = sp.csr_matrix(M) + + h_container = MockContainer(Asp) + s_container = MockContainer(Msp) + kpoints = np.array([[0,0,0]]) + + evals_ref = eigh(A, M, eigvals_only=True) + emin, emax = np.min(evals_ref)-0.1, np.max(evals_ref)+0.1 + + solver = FEASTEig(emin=emin, emax=emax, m0=N) + evals_list = solver.solve(h_container, s_container, kpoints) + + assert len(evals_list) == 1 + evals = evals_list[0] + assert len(evals) == N + np.testing.assert_allclose(np.sort(evals), np.sort(evals_ref), atol=1e-8) + + def test_solve_with_vectors(self): + """Test returning eigenvectors.""" + N = 40 + np.random.seed(44) + A = np.random.rand(N, N) + A = A + A.T + Asp = sp.csr_matrix(A) + + h_container = MockContainer(Asp) + kpoints = np.array([[0,0,0]]) + + # Reference evals + evals_ref = eigh(A, eigvals_only=True) + emin, emax = evals_ref[0]-0.1, evals_ref[-1]+0.1 + + solver = FEASTEig(emin=emin, emax=emax, m0=N) + + # Test return_eigenvectors=True + evals_list, evecs_list = solver.solve(h_container, None, kpoints, return_eigenvectors=True) + + assert len(evals_list) == 1 + assert len(evecs_list) == 1 + + evals = evals_list[0] + evecs = evecs_list[0] + + assert len(evals) == N + assert evecs.shape == (N, N) # Found all + + # Check residual for first few + for i in range(min(5, N)): + val = evals[i] + vec = evecs[:, i] + resid = Asp @ vec - val * vec + assert np.linalg.norm(resid) < 1e-8 + +if __name__ == "__main__": + 
pytest.main([__file__]) diff --git a/dptb/tests/test_feast_wrapper.py b/dptb/tests/test_feast_wrapper.py new file mode 100644 index 00000000..b54811d2 --- /dev/null +++ b/dptb/tests/test_feast_wrapper.py @@ -0,0 +1,196 @@ + +import pytest +import numpy as np +import scipy.sparse as sp +from scipy.linalg import eigh +from dptb.utils.feast_wrapper import FeastSolver, _MKL_RT + +@pytest.mark.skipif(_MKL_RT is None, reason="MKL runtime not found") +class TestFeastWrapper: + + def test_standard_hermitian(self): + """Test finding eigenvalues in interval for standard Hermitian problem.""" + N = 100 + np.random.seed(42) + # Create random Hermitian matrix + A = np.random.rand(N, N) + 1j * np.random.rand(N, N) + A = A + A.T.conj() + # Make sparse + A[np.abs(A) < 0.8] = 0 + Asp = sp.csr_matrix(A) + + # Reference solution (dense) + evals_ref, evecs_ref = eigh(A) + + # Define interval covering middle 10 eigenvalues + mid = N // 2 + emin = evals_ref[mid-5] - 0.1 + emax = evals_ref[mid+5] + 0.1 + expected_indices = np.where((evals_ref >= emin) & (evals_ref <= emax))[0] + n_expected = len(expected_indices) + + solver = FeastSolver() + # Initial guess explicitly smaller to test resizing logic if needed, + # but here we use conservative m0=20 > 11 so it should work first try. 
+ evals, X = solver.solve(Asp, emin=emin, emax=emax, m0=max(n_expected + 5, 20)) + + assert len(evals) == n_expected + np.testing.assert_allclose(np.sort(evals), evals_ref[expected_indices], atol=1e-8) + + # Check eigenvector residual ||Ax - \lambda x|| + for i in range(len(evals)): + val = evals[i] + vec = X[:, i] + resid = Asp @ vec - val * vec + assert np.linalg.norm(resid) < 1e-8 + + def test_generalized_hermitian(self): + """Test generalized problem Ax = \lambda Mx where M is positive definite.""" + N = 50 + np.random.seed(123) + + # A: Hermitian + A = np.random.rand(N, N) + 1j * np.random.rand(N, N) + A = A + A.T.conj() + A[np.abs(A) < 0.5] = 0 + Asp = sp.csr_matrix(A) + + # M: Hermitian Positive Definite + M = np.random.rand(N, N) + 1j * np.random.rand(N, N) + M = M @ M.T.conj() + np.eye(N) # Ensure pos def + M[np.abs(M) < 0.1] = 0 + Msp = sp.csr_matrix(M) + + # Dense reference: eigh(a, b) solves generalized + evals_ref = eigh(A, M, eigvals_only=True) + + # Interval + emin, emax = np.min(evals_ref) - 0.1, np.max(evals_ref) + 0.1 + + solver = FeastSolver() + evals, X = solver.solve(Asp, M=Msp, emin=emin, emax=emax) + + assert len(evals) == N + np.testing.assert_allclose(np.sort(evals), np.sort(evals_ref), atol=1e-8) + + # Check generalized residual ||Ax - \lambda Mx|| + for i in range(len(evals)): + val = evals[i] + vec = X[:, i] + resid = Asp @ vec - val * (Msp @ vec) + assert np.linalg.norm(resid) < 1e-7 + + def test_subspace_resize(self): + """Test if wrapper correctly handles small initial m0 (info=3).""" + N = 50 + np.random.seed(999) + A = np.random.rand(N, N) + A = A + A.T + Asp = sp.csr_matrix(A) + + evals_ref = eigh(A, eigvals_only=True) + emin, emax = -100, 100 # All eigenvalues + + solver = FeastSolver() + # Start with ridiculously small m0=2 + # Expected eigenvalues = 50. + # FEAST should return info=3 and wrapper should retry. + + # Capture stdout to see print message? Or just verify result. 
+ evals, X = solver.solve(Asp, emin=emin, emax=emax, m0=2, max_refinement=10) + + assert len(evals) == N + np.testing.assert_allclose(np.sort(evals), np.sort(evals_ref), atol=1e-8) + + + def test_real_symmetric(self): + """Test finding eigenvalues for real symmetric problem.""" + N = 100 + np.random.seed(42) + A = np.random.rand(N, N) + A = A + A.T + Asp = sp.csr_matrix(A) + + evals_ref = eigh(A, eigvals_only=True) + emin, emax = evals_ref[0]-0.1, evals_ref[-1]+0.1 + + solver = FeastSolver() + evals, X = solver.solve(Asp, emin=emin, emax=emax, m0=max(N, 20)) + + assert len(evals) == N + np.testing.assert_allclose(np.sort(evals), np.sort(evals_ref), atol=1e-8) + + # Check eigenvector residual + for i in range(len(evals)): + val = evals[i] + vec = X[:, i] + resid = Asp @ vec - val * vec + assert np.linalg.norm(resid) < 1e-8 + + def test_generalized_real_symmetric(self): + """Test generalized real symmetric Ax = lambda Mx.""" + N = 50 + np.random.seed(123) + A = np.random.rand(N, N) + A = A + A.T + Asp = sp.csr_matrix(A) + + M = np.random.rand(N, N) + M = M @ M.T + np.eye(N) + Msp = sp.csr_matrix(M) + + evals_ref = eigh(A, M, eigvals_only=True) + emin, emax = np.min(evals_ref)-0.1, np.max(evals_ref)+0.1 + + solver = FeastSolver() + evals, X = solver.solve(Asp, M=Msp, emin=emin, emax=emax, m0=N, max_refinement=10) + + assert len(evals) == N + np.testing.assert_allclose(np.sort(evals), np.sort(evals_ref), atol=1e-8) + + + for i in range(len(evals)): + val = evals[i] + vec = X[:, i] + resid = Asp @ vec - val * (Msp @ vec) + assert np.linalg.norm(resid) < 1e-7 + + def test_lower_triangular(self): + """Test with uplo='L' and automatic extraction.""" + N = 50 + np.random.seed(99) + A = np.random.rand(N, N) + A = A + A.T + Asp = sp.csr_matrix(A) + + evals_ref = eigh(A, eigvals_only=True) + emin, emax = np.min(evals_ref)-0.1, np.max(evals_ref)+0.1 + + solver = FeastSolver() + # Find ALL eigenvalues + evals, X = solver.solve(Asp, emin=emin, emax=emax, m0=N, uplo='L', 
extract_triangular=True) + + assert len(evals) == N + np.testing.assert_allclose(np.sort(evals), np.sort(evals_ref), atol=1e-8) + + def test_manual_triangular(self): + """Test with pre-processed triangular matrix and extract_triangular=False.""" + N = 50 + np.random.seed(101) + A = np.random.rand(N, N) + A = A + A.T + # Manually extract upper + A_triu = sp.triu(A, format='csr') + + evals_ref = eigh(A, eigvals_only=True) + emin, emax = np.min(evals_ref)-0.1, np.max(evals_ref)+0.1 + + solver = FeastSolver() + # Pass triangular matrix, disable extraction, set uplo='U' + evals, X = solver.solve(A_triu, emin=emin, emax=emax, m0=N, uplo='U', extract_triangular=False) + + assert len(evals) == N + np.testing.assert_allclose(np.sort(evals), np.sort(evals_ref), atol=1e-8) + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/dptb/tests/test_hr2hr.py b/dptb/tests/test_hr2hr.py new file mode 100644 index 00000000..e4155236 --- /dev/null +++ b/dptb/tests/test_hr2hr.py @@ -0,0 +1,42 @@ +from dptb.postprocess.bandstructure.band import Band +from dptb.utils.tools import j_loader +from dptb.nn.build import build_model +from dptb.data import build_dataset, AtomicData, AtomicDataDict +from dptb.nn.hr2hR import Hr2HR +import torch +import pytest + + +@pytest.fixture(scope='session', autouse=True) +def root_directory(request): + return str(request.config.rootdir) + +def test_hr2hr(root_directory): + # build the trained e3_band hamiltonian and overlap model + model = build_model(checkpoint=f"{root_directory}/dptb/tests/data/e3_band/ref_model/nnenv.ep1474.pth") + + dataset = build_dataset.from_model( + model=model, + root=f"{root_directory}/dptb/tests/data/e3_band/data/", + prefix="Si64" + ) + + adata = dataset[0] + adata = AtomicData.to_AtomicDataDict(adata) + + hr2hr = Hr2HR( + idp=model.idp, + edge_field=AtomicDataDict.EDGE_FEATURES_KEY, + node_field=AtomicDataDict.NODE_FEATURES_KEY, + overlap=False, + dtype=torch.float32, + device=torch.device("cpu") + ) + + adata = 
model(adata) + image_c = hr2hr(adata) + vb = image_c.sample_k([0,0,0]) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/dptb/tests/test_orbital_parser.py b/dptb/tests/test_orbital_parser.py new file mode 100644 index 00000000..9da421ba --- /dev/null +++ b/dptb/tests/test_orbital_parser.py @@ -0,0 +1,42 @@ +import unittest +import os +from dptb.utils.orbital_parser import parse_orbital_file +from unittest.mock import patch, mock_open + +class TestOrbitalParser(unittest.TestCase): + def test_parse_valid_file(self): + content = """ +Element Si +Number of Sorbital--> 2 +Number of Porbital--> 2 +Number of Dorbital--> 1 +""" + with patch("builtins.open", mock_open(read_data=content)): + with patch("os.path.exists", return_value=True): + basis = parse_orbital_file("dummy.orb") + self.assertEqual(basis, "2s2p1d") + + def test_parse_zero_count(self): + content = """ +Number of Sorbital--> 1 +Number of Porbital--> 0 +""" + with patch("builtins.open", mock_open(read_data=content)): + with patch("os.path.exists", return_value=True): + basis = parse_orbital_file("dummy.orb") + self.assertEqual(basis, "1s") + + def test_parse_invalid_file(self): + content = "Invalid content" + with patch("builtins.open", mock_open(read_data=content)): + with patch("os.path.exists", return_value=True): + with self.assertRaisesRegex(ValueError, "No valid orbital counts found"): + parse_orbital_file("dummy.orb") + + def test_file_not_found(self): + with patch("os.path.exists", return_value=False): + with self.assertRaises(FileNotFoundError): + parse_orbital_file("nonexistent.orb") + +if __name__ == '__main__': + unittest.main() diff --git a/dptb/tests/test_pardisoeig.py b/dptb/tests/test_pardisoeig.py new file mode 100644 index 00000000..dbda533f --- /dev/null +++ b/dptb/tests/test_pardisoeig.py @@ -0,0 +1,76 @@ +from dptb.postprocess.bandstructure.band import Band +from dptb.utils.tools import j_loader +from dptb.nn.build import build_model 
+from dptb.data import build_dataset, AtomicData, AtomicDataDict +from dptb.nn.hr2hR import Hr2HR +from dptb.nn.energy import PardisoEig +import torch +import pytest +from dptb.utils.pardiso_wrapper import _MKL_RT_HANDLE + + + +@pytest.fixture(scope='session', autouse=True) +def root_directory(request): + return str(request.config.rootdir) + +@pytest.mark.skipif(_MKL_RT_HANDLE is None, reason="MKL runtime not found") +def test_hr2hr(root_directory): + # build the trained e3_band hamiltonian and overlap model + model = build_model(checkpoint=f"{root_directory}/dptb/tests/data/e3_band/ref_model/nnenv.ep1474.pth") + + dataset = build_dataset.from_model( + model=model, + root=f"{root_directory}/dptb/tests/data/e3_band/data/", + prefix="Si64" + ) + + adata = dataset[0] + adata = AtomicData.to_AtomicDataDict(adata) + + hr2hr = Hr2HR( + idp=model.idp, + edge_field=AtomicDataDict.EDGE_FEATURES_KEY, + node_field=AtomicDataDict.NODE_FEATURES_KEY, + overlap=False, + dtype=torch.float32, + device=torch.device("cpu") + ) + + sr2sr = Hr2HR( + idp=model.idp, + edge_field=AtomicDataDict.EDGE_OVERLAP_KEY, + node_field=AtomicDataDict.NODE_OVERLAP_KEY, + overlap=True, + dtype=torch.float32, + device=torch.device("cpu") + ) + + adata = model(adata) + image_h = hr2hr(adata) + image_s = sr2sr(adata) + + # h = image_h.sample_k([0,0,0], symm=True).to_scipy(format="csr") + + peig = PardisoEig( + sigma=0.0, + neig=10, + mode='normal' + ) + + + peig.solve(image_h, image_s, [[0,0,0],[1,1,1],[0.5,0.5,0.5]]) + + # Test eigenvector return + evals, evecs = peig.solve(image_h, image_s, [[0,0,0]], return_eigenvectors=True) + assert len(evals) == 1 + assert len(evecs) == 1 + + # Check shape + N = image_h.sample_k([0,0,0], symm=True).to_scipy(format="csr").shape[0] + assert evecs[0].shape == (N, 10) + + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/dptb/tests/test_postprocess_hr.py b/dptb/tests/test_postprocess_hr.py new file mode 100644 index 
00000000..fcd88e8d --- /dev/null +++ b/dptb/tests/test_postprocess_hr.py @@ -0,0 +1,47 @@ +import sys +import os +import unittest +# from unittest.mock import MagicMock, patch + +# Mock vbcsr before importing dptb to avoid crash if vbcsr is not installed or broken +# sys.modules["vbcsr"] = MagicMock() + +import torch +from dptb.postprocess.unified.system import TBSystem +from dptb.postprocess.unified.calculator import HamiltonianCalculator, DeePTBAdapter +from dptb.data import AtomicDataDict + +# @patch('dptb.nn.hr2hR.ImageContainer') +# @patch('dptb.nn.hr2hR.AtomicData_vbcsr') +class TestPostProcessHR(unittest.TestCase): + def setUp(self): + self.root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + self.model_path = os.path.join(self.root, "tests/data/e3_band/ref_model/nnenv.ep1474.pth") + self.struc_path = os.path.join(self.root, "tests/data/e3_band/data/Si64.vasp") + + def test_get_hR_real_model(self): + # Initialize System with real model + system = TBSystem( + data=self.struc_path, + calculator=self.model_path, + device='cpu' + ) + + # Call get_hR + h, s = system.get_hR() + + if s is not None: + sK = s.sample_k([0,0,0], symm=True) + # print(sK.to_dense()) + self.assertIsNotNone(sK) + + # Verification + # The model in test_to_pardiso might or might not have overlap. + # Based on name "nnsk.iter_ovp0.000.pth", it likely does? Or maybe not. + # Let's check if s is None or not based on model config. + # But at least the code ran and returned what Hr2HR produced. 
+ + + +if __name__ == '__main__': + unittest.main() diff --git a/dptb/tests/test_train_orb.py b/dptb/tests/test_train_orb.py new file mode 100644 index 00000000..d176c3f5 --- /dev/null +++ b/dptb/tests/test_train_orb.py @@ -0,0 +1,107 @@ +import unittest +from unittest.mock import patch, MagicMock +import os +import tempfile +import sys +from dptb.entrypoints.train import train +train_module = sys.modules['dptb.entrypoints.train'] + +class TestTrainOrbitalIntegration(unittest.TestCase): + + @patch.object(train_module, 'j_loader') + @patch.object(train_module, 'normalize') + @patch.object(train_module, 'build_dataset') + @patch.object(train_module, 'build_model') + @patch.object(train_module, 'Trainer') + @patch.object(train_module, 'set_log_handles') + @patch.object(train_module, 'collect_cutoffs') + @patch.object(train_module, 'setup_seed') + @patch('os.makedirs') # Mock makedirs to prevent creating directories + @patch('pathlib.Path.mkdir') + def test_train_orbital_parsing(self, mock_mkdir, mock_makedirs, mock_setup_seed, mock_collect_cutoffs, mock_set_log, mock_trainer, mock_build_model, mock_build_dataset, mock_normalize, mock_j_loader): + # Use the specific orbital file requested + orb_file_path = "./dptb/tests/data/e3_band/data/Si_gga_7au_100Ry_2s2p1d.orb" + + # Verify file exists before running test + if not os.path.exists(orb_file_path): + self.skipTest(f"Test file not found: {orb_file_path}") + + try: + # Mock j_loader to return our config + mock_config = { + "common_options": { + "basis": {"Si": orb_file_path}, + "dtype": "float32", + "seed": 123 + }, + "model_options": { + "prediction": {"method": "e3tb"} + }, + "data_options": {"train": {}}, + "train_options": {} + } + mock_j_loader.return_value = mock_config + mock_normalize.return_value = mock_config # Assume normalize returns it as is for this test + + # Run train + # passing output=None prevents file creation attempts + try: + # We expect it to eventually crash or finish. 
+ # Since we mocked build_model etc, it might proceed until it tries to use the mocked objects. + # However, the orbital parsing happens early. + # raising an exception in build_model allows us to stop execution after parsing + mock_build_model.side_effect = InterruptedError("Verify point reached") + + train(INPUT="dummy.json", init_model=None, restart=None, output=None, log_level=20, log_path=None) + except InterruptedError: + pass + except Exception as e: + self.fail(f"Train raised unexpected exception: {e}") + + # Verify jdata was modified + # Since jdata is a local variable in train, we can't inspect it directly. + # But we can inspect the calls to build_model or build_dataset, which receive jdata components. + + # Check build_model call args + args, kwargs = mock_build_model.call_args + common_options = kwargs.get('common_options', {}) + + # 1. Check basis string + self.assertEqual(common_options['basis']['Si'], '2s2p1d') + + # 2. Check orbital_files_content + self.assertIn('orbital_files_content', common_options) + self.assertIn('Si', common_options['orbital_files_content']) + # Verify some content from the real file + self.assertIn("Number of Sorbital--> 2", common_options['orbital_files_content']['Si']) + + finally: + pass + + @patch.object(train_module, 'j_loader') + @patch.object(train_module, 'normalize') + def test_train_orbital_parsing_wrong_method(self, mock_normalize, mock_j_loader): + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.orb') as tmp_orb: + tmp_orb_path = tmp_orb.name + try: + mock_config = { + "common_options": { + "basis": {"Si": tmp_orb_path}, + "dtype": "float32" + }, + "model_options": { + "prediction": {"method": "sktb"} # Not e3tb + } + } + mock_j_loader.return_value = mock_config + mock_normalize.return_value = mock_config + + with self.assertRaisesRegex(ValueError, "only supported for the 'e3tb' method"): + train(INPUT="dummy.json", init_model=None, restart=None, output=None, log_level=20, log_path=None) + + 
finally: + if os.path.exists(tmp_orb_path): + os.remove(tmp_orb_path) + +if __name__ == '__main__': + unittest.main() diff --git a/dptb/utils/feast_wrapper.py b/dptb/utils/feast_wrapper.py new file mode 100644 index 00000000..c5eb7c91 --- /dev/null +++ b/dptb/utils/feast_wrapper.py @@ -0,0 +1,354 @@ +import ctypes +import os +import sys +import glob +import site +import numpy as np +import scipy.sparse as sp +from ctypes.util import find_library +import warnings + +# Use same MKL finding logic as pardiso_wrapper for consistency +def _find_mkl_rt(): + """Find and load mkl_rt shared library.""" + libmkl = None + mkl_rt = os.environ.get('PYPARDISO_MKL_RT') + if mkl_rt is None: + mkl_rt = find_library('mkl_rt') + if mkl_rt is None: + mkl_rt = find_library('mkl_rt.1') + + if mkl_rt is None: + globs = glob.glob(f'{sys.prefix}/[Ll]ib*/**/*mkl_rt*', recursive=True) or \ + glob.glob(f'{site.USER_BASE}/[Ll]ib*/**/*mkl_rt*', recursive=True) + for path in sorted(globs, key=len): + try: + libmkl = ctypes.CDLL(path) + break + except (OSError, ImportError): + pass + else: + try: + libmkl = ctypes.CDLL(mkl_rt) + except (OSError, ImportError): + pass + return libmkl + +_MKL_RT = _find_mkl_rt() + +if _MKL_RT: + # Define function signatures for FEAST + + # void feastinit(int* fpm); + _MKL_RT.feastinit.argtypes = [ctypes.POINTER(ctypes.c_int)] + _MKL_RT.feastinit.restype = None + + # void zfeast_hcsrev( + # const char* uplo, const int* n, const void* a, const int* ia, const int* ja, + # const int* fpm, double* epsout, int* loop, const double* emin, const double* emax, + # int* m0, double* E, void* X, int* M, double* res, int* info + # ); + _MKL_RT.zfeast_hcsrev.argtypes = [ + ctypes.POINTER(ctypes.c_char), # uplo + ctypes.POINTER(ctypes.c_int), # n + ctypes.c_void_p, # a (complex128*) + ctypes.POINTER(ctypes.c_int), # ia + ctypes.POINTER(ctypes.c_int), # ja + ctypes.POINTER(ctypes.c_int), # fpm + ctypes.POINTER(ctypes.c_double),# epsout + ctypes.POINTER(ctypes.c_int), # loop + 
ctypes.POINTER(ctypes.c_double),# emin + ctypes.POINTER(ctypes.c_double),# emax + ctypes.POINTER(ctypes.c_int), # m0 + ctypes.c_void_p, # E (double*) + ctypes.c_void_p, # X (complex128*) + ctypes.POINTER(ctypes.c_int), # M (found eigs) + ctypes.POINTER(ctypes.c_double),# res + ctypes.POINTER(ctypes.c_int) # info + ] + _MKL_RT.zfeast_hcsrev.restype = None + + # void zfeast_hcsrgv(...) see below + + # void dfeast_scsrev( + # const char* uplo, const int* n, const double* a, const int* ia, const int* ja, + # const int* fpm, double* epsout, int* loop, const double* emin, const double* emax, + # int* m0, double* E, double* X, int* M, double* res, int* info + # ); + _MKL_RT.dfeast_scsrev.argtypes = [ + ctypes.POINTER(ctypes.c_char), # uplo + ctypes.POINTER(ctypes.c_int), # n + ctypes.c_void_p, # a (double*) + ctypes.POINTER(ctypes.c_int), # ia + ctypes.POINTER(ctypes.c_int), # ja + ctypes.POINTER(ctypes.c_int), # fpm + ctypes.POINTER(ctypes.c_double),# epsout + ctypes.POINTER(ctypes.c_int), # loop + ctypes.POINTER(ctypes.c_double),# emin + ctypes.POINTER(ctypes.c_double),# emax + ctypes.POINTER(ctypes.c_int), # m0 + ctypes.c_void_p, # E (double*) + ctypes.c_void_p, # X (double*) + ctypes.POINTER(ctypes.c_int), # M (found eigs) + ctypes.POINTER(ctypes.c_double),# res + ctypes.POINTER(ctypes.c_int) # info + ] + _MKL_RT.dfeast_scsrev.restype = None + + # void dfeast_scsrgv( + # const char* uplo, const int* n, const double* a, const int* ia, const int* ja, + # const double* b, const int* ib, const int* jb, + # const int* fpm, double* epsout, int* loop, const double* emin, const double* emax, + # int* m0, double* E, double* X, int* M, double* res, int* info + # ); + _MKL_RT.dfeast_scsrgv.argtypes = [ + ctypes.POINTER(ctypes.c_char), # uplo + ctypes.POINTER(ctypes.c_int), # n + ctypes.c_void_p, # a (double*) + ctypes.POINTER(ctypes.c_int), # ia + ctypes.POINTER(ctypes.c_int), # ja + ctypes.c_void_p, # b (double*) + ctypes.POINTER(ctypes.c_int), # ib + 
ctypes.POINTER(ctypes.c_int), # jb + ctypes.POINTER(ctypes.c_int), # fpm + ctypes.POINTER(ctypes.c_double),# epsout + ctypes.POINTER(ctypes.c_int), # loop + ctypes.POINTER(ctypes.c_double),# emin + ctypes.POINTER(ctypes.c_double),# emax + ctypes.POINTER(ctypes.c_int), # m0 + ctypes.c_void_p, # E (double*) + ctypes.c_void_p, # X (double*) + ctypes.POINTER(ctypes.c_int), # M (found eigs) + ctypes.POINTER(ctypes.c_double),# res + ctypes.POINTER(ctypes.c_int) # info + ] + _MKL_RT.dfeast_scsrgv.restype = None + + # void zfeast_hcsrgv( + # const char* uplo, const int* n, const void* a, const int* ia, const int* ja, + # const void* b, const int* ib, const int* jb, + # const int* fpm, double* epsout, int* loop, const double* emin, const double* emax, + # int* m0, double* E, void* X, int* M, double* res, int* info + # ); + _MKL_RT.zfeast_hcsrgv.argtypes = [ + ctypes.POINTER(ctypes.c_char), # uplo + ctypes.POINTER(ctypes.c_int), # n + ctypes.c_void_p, # a (complex128*) + ctypes.POINTER(ctypes.c_int), # ia + ctypes.POINTER(ctypes.c_int), # ja + ctypes.c_void_p, # b (complex128*) + ctypes.POINTER(ctypes.c_int), # ib + ctypes.POINTER(ctypes.c_int), # jb + ctypes.POINTER(ctypes.c_int), # fpm + ctypes.POINTER(ctypes.c_double),# epsout + ctypes.POINTER(ctypes.c_int), # loop + ctypes.POINTER(ctypes.c_double),# emin + ctypes.POINTER(ctypes.c_double),# emax + ctypes.POINTER(ctypes.c_int), # m0 + ctypes.c_void_p, # E (double*) + ctypes.c_void_p, # X (complex128*) + ctypes.POINTER(ctypes.c_int), # M (found eigs) + ctypes.POINTER(ctypes.c_double),# res + ctypes.POINTER(ctypes.c_int) # info + ] + _MKL_RT.zfeast_hcsrgv.restype = None + +class FeastSolver: + """ + Wrapper for MKL FEAST solver (zfeast_hcsrev) for complex Hermitian matrices. + Finds all eigenvalues in a given interval [emin, emax]. + """ + + def __init__(self): + if _MKL_RT is None: + raise ImportError("MKL runtime library (mkl_rt) not found. 
Cannot use FEAST.") + + # Initialize default FPM + self.fpm = np.zeros(128, dtype=np.int32) + _MKL_RT.feastinit(self.fpm.ctypes.data_as(ctypes.POINTER(ctypes.c_int))) + + # Standard defaults suitable for contour integration + # fpm[0] = 1 (Enable logging/print) - Turn off by default + self.fpm[0] = 0 + + def _prepare_matrix(self, mat, dtype, uplo_char, extract_triangular, name="Matrix"): + """Prepare matrix: ensure CSR, check dtype, extract triangular part.""" + # Check dtype + if mat.dtype != dtype: + warnings.warn(f"Converting {name} to {dtype} for FEAST") + mat = mat.astype(dtype) + + # Ensure CSR + if not sp.isspmatrix_csr(mat): + mat = mat.tocsr() + + if extract_triangular: + if uplo_char == b'U': + return sp.triu(mat, format='csr') + elif uplo_char == b'L': + return sp.tril(mat, format='csr') + else: + raise ValueError(f"Invalid uplo: {uplo_char}") + else: + return mat + + def solve(self, A, M=None, emin=-1.0, emax=1.0, m0=None, max_refinement=3, uplo='U', extract_triangular=True): + """ + Solve eigenvalue problem Ax = \lambda x (or Ax = \lambda Mx) for \lambda in [emin, emax]. + + Args: + A: Scipy sparse CSR matrix (Complex Hermitian / Real Symmetric) + M: Scipy sparse CSR matrix (Hermitian Positive Definite), optional. + emin, emax: Eigenvalue interval. + m0: Initial subspace size. If None, defaults to 10 or 1.5x expected if passed. + max_refinement: Number of retries if subspace is too small (info=3). + uplo: Upper ('U') or Lower ('L') triangle to use. Default 'U'. + extract_triangular: If True (default), automatically extracts the specified triangular part + using sp.triu/sp.tril. Set False if input is already triangular. + + Returns: + evals: Array of found eigenvalues. + X: Array of eigenvectors (column-wise). 
+ """ + if not sp.isspmatrix_csr(A): + A = A.tocsr() + + N = A.shape[0] + if A.shape[1] != N: + raise ValueError("Matrix A must be square") + + # Check dtype + # Prepare pointers + if isinstance(uplo, str): + uplo_char = uplo.upper().encode('ascii') + elif isinstance(uplo, bytes): + uplo_char = uplo.upper() + else: + uplo_char = b'U' + + uplo_c = ctypes.create_string_buffer(uplo_char) + + # Detect Complexity + is_complex = np.iscomplexobj(A) or (M is not None and np.iscomplexobj(M)) + + if is_complex: + # Complex Hermitian + dtype = np.complex128 + fn_std = _MKL_RT.zfeast_hcsrev + fn_gen = _MKL_RT.zfeast_hcsrgv + else: + # Real Symmetric + dtype = np.float64 + fn_std = _MKL_RT.dfeast_scsrev + fn_gen = _MKL_RT.dfeast_scsrgv + + # Prepare A + A_triu = self._prepare_matrix(A, dtype, uplo_char, extract_triangular, name="A") + + # Prepare M (if generalized) + if M is not None: + M_triu = self._prepare_matrix(M, dtype, uplo_char, extract_triangular, name="M") + + if m0 is None: + # If finding ALL eigenvalues (emin very small, emax very large), m0 should span the space. + # But usually we find interval. Conservative estimate: 1.5x expected or somewhat large number. + # For robustness, start with larger default if N small. + m0 = min(N, max(N // 2, 20)) + + # Generic preparation + ia = A_triu.indptr.astype(np.int32) + 1 + ja = A_triu.indices.astype(np.int32) + 1 + a_data = A_triu.data + + + loop = ctypes.c_int(0) + epsout = ctypes.c_double(0.0) + emin_c = ctypes.c_double(emin) + emax_c = ctypes.c_double(emax) + + # Retry loop for m0 refinement + for attempt in range(max_refinement + 1): + + # Prepare output arrays + E = np.zeros(m0, dtype=np.float64) + # MKL expects column-major (Fortran) storage for dense matrices. 
+ X = np.zeros((N, m0), dtype=dtype, order='F') + res = np.zeros(m0, dtype=np.float64) + info = ctypes.c_int(0) + M_found = ctypes.c_int(0) + m0_c = ctypes.c_int(m0) + + if M is None: + # Standard problem + fn_std( + uplo_c, + ctypes.byref(ctypes.c_int(N)), + a_data.ctypes.data_as(ctypes.c_void_p), + ia.ctypes.data_as(ctypes.POINTER(ctypes.c_int)), + ja.ctypes.data_as(ctypes.POINTER(ctypes.c_int)), + self.fpm.ctypes.data_as(ctypes.POINTER(ctypes.c_int)), + ctypes.byref(epsout), + ctypes.byref(loop), + ctypes.byref(emin_c), + ctypes.byref(emax_c), + ctypes.byref(m0_c), + E.ctypes.data_as(ctypes.c_void_p), + X.ctypes.data_as(ctypes.c_void_p), + ctypes.byref(M_found), + res.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), + ctypes.byref(info) + ) + else: + # Generalized problem + m_data = M_triu.data + ib = M_triu.indptr.astype(np.int32) + 1 + jb = M_triu.indices.astype(np.int32) + 1 + + fn_gen( + uplo_c, + ctypes.byref(ctypes.c_int(N)), + a_data.ctypes.data_as(ctypes.c_void_p), + ia.ctypes.data_as(ctypes.POINTER(ctypes.c_int)), + ja.ctypes.data_as(ctypes.POINTER(ctypes.c_int)), + m_data.ctypes.data_as(ctypes.c_void_p), + ib.ctypes.data_as(ctypes.POINTER(ctypes.c_int)), + jb.ctypes.data_as(ctypes.POINTER(ctypes.c_int)), + self.fpm.ctypes.data_as(ctypes.POINTER(ctypes.c_int)), + ctypes.byref(epsout), + ctypes.byref(loop), + ctypes.byref(emin_c), + ctypes.byref(emax_c), + ctypes.byref(m0_c), + E.ctypes.data_as(ctypes.c_void_p), + X.ctypes.data_as(ctypes.c_void_p), + ctypes.byref(M_found), + res.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), + ctypes.byref(info) + ) + + if info.value == 0: + # Success + n_eig = M_found.value + return E[:n_eig], X[:, :n_eig] + + elif info.value == 3: + # Warning: Size of the subspace M0 is too small + # MKL usually suggests a new size in m0_c + new_m0 = m0_c.value + if new_m0 <= m0: # If suggestion is not larger, force doubling + new_m0 = m0 * 2 + + if new_m0 > N: + new_m0 = N + if m0 == N: + raise RuntimeError(f"FEAST info=3: Subspace 
too small even at N={N}.") + + print(f"FEAST info=3 (Subspace too small/bad estimate). MKL suggested {m0_c.value}. Increasing m0 from {m0} to {new_m0} and retrying.") + m0 = new_m0 + continue + + else: + # Other errors + raise RuntimeError(f"FEAST failed with info={info.value}. Check MKL documentation.") + + raise RuntimeError(f"FEAST failed to converge after {max_refinement} refinements of m0.") diff --git a/dptb/utils/orbital_parser.py b/dptb/utils/orbital_parser.py new file mode 100644 index 00000000..006fb65f --- /dev/null +++ b/dptb/utils/orbital_parser.py @@ -0,0 +1,51 @@ +import re +import os + +def parse_orbital_file(filepath: str) -> str: + """ + Parses an orbital file to extract the basis set definition string (e.g., "2s2p1d"). + + The function looks for lines in the format: + "Number of [OrbitalType]orbital--> [Count]" + + Args: + filepath (str): The path to the orbital file. + + Returns: + str: The basis string representing the orbital counts (e.g., "2s2p1d"). + + Raises: + ValueError: If the file content doesn't match the expected format or no orbitals are found. + FileNotFoundError: If the file does not exist. 
+ """ + if not os.path.exists(filepath): + raise FileNotFoundError(f"Orbital file not found: {filepath}") + + with open(filepath, 'r') as f: + content = f.read() + + basis_parts = [] + + # Define the order of orbitals to check + orbital_types = ['s', 'p', 'd', 'f', 'g', 'h'] + + found_any = False + + for orb_type in orbital_types: + # Regex to find "Number of Sorbital--> 2" (case insensitive for S/P/D/F part of "Sorbital") + # The file example shows "Sorbital", "Porbital", "Dorbital" + pattern = re.compile(rf"Number of {orb_type}orbital-->\s+(\d+)", re.IGNORECASE) + match = pattern.search(content) + + if match: + count = int(match.group(1)) + if count > 0: + basis_parts.append(f"{count}{orb_type}") + found_any = True + + if not found_any: + # Fallback or error if no known orbital pattern is found + # It's possible the file format is completely different, but per user request, we targeting this specific format. + raise ValueError(f"No valid orbital counts found in {filepath}. Expected lines like 'Number of Sorbital--> 2'.") + + return "".join(basis_parts) diff --git a/dptb/utils/pardiso_wrapper.py b/dptb/utils/pardiso_wrapper.py new file mode 100644 index 00000000..3dfd52af --- /dev/null +++ b/dptb/utils/pardiso_wrapper.py @@ -0,0 +1,380 @@ +# coding: utf-8 +import os +import sys +import glob +import ctypes +import warnings +import hashlib +import site +from ctypes.util import find_library + +import numpy as np +import scipy.sparse as sp +from scipy.sparse import SparseEfficiencyWarning + +# Complex matrix types in PARDISO +_COMPLEX_MTYPES = {3, 4, -4, 6, 13} +# Symmetric matrix types in PARDISO +_SYMMETRIC_MTYPES = {-2, 2, -4, 4, 6} + +def _find_mkl_rt(): + """Find and load mkl_rt shared library. + + Returns the ctypes handle to mkl_rt, or None if not found. 
+ """ + libmkl = None + + mkl_rt = os.environ.get('PYPARDISO_MKL_RT') + if mkl_rt is None: + mkl_rt = find_library('mkl_rt') + if mkl_rt is None: + mkl_rt = find_library('mkl_rt.1') + + if mkl_rt is None: + globs = glob.glob( + f'{sys.prefix}/[Ll]ib*/**/*mkl_rt*', recursive=True + ) or glob.glob( + f'{site.USER_BASE}/[Ll]ib*/**/*mkl_rt*', recursive=True + ) + for path in sorted(globs, key=len): + try: + libmkl = ctypes.CDLL(path) + break + except (OSError, ImportError): + pass + else: + try: + libmkl = ctypes.CDLL(mkl_rt) + except (OSError, ImportError): + pass + + return libmkl + + +# Module-level cached handle (avoids re-searching on every solver instantiation) +_MKL_RT_HANDLE = _find_mkl_rt() + + +class PyPardisoSolver: + """ + Python interface to the Intel MKL PARDISO library for solving large sparse linear systems of equations Ax=b. + + Pardiso documentation: https://software.intel.com/en-us/node/470282 + + --- Basic usage --- + matrix type: real (float64) and nonsymetric + methods: solve, factorize + + - use the "solve(A,b)" method to solve Ax=b for x, where A is a sparse CSR (or CSC) matrix and b is a numpy array + - use the "factorize(A)" method first, if you intend to solve the system more than once for different right-hand + sides, the factorization will be reused automatically afterwards + + + --- Advanced usage --- + methods: get_iparm, get_iparms, set_iparm, set_matrix_type, set_phase + + - additional options can be accessed by setting the iparms (see Pardiso documentation for description) + - other matrix types can be chosen with the "set_matrix_type" method. complex matrix types are currently not + supported. 
pypardiso is only tested for mtype=11 (real and nonsymmetric) + - the solving phases can be set with the "set_phase" method + - The out-of-core (OOC) solver either fails or crashes my computer, be careful with iparm[60] + + + --- Statistical info --- + methods: set_statistical_info_on, set_statistical_info_off + + - the Pardiso solver writes statistical info to the C stdout if desired + - if you use pypardiso from within a jupyter notebook you can turn the statistical info on and capture the output + real-time by wrapping your call to "solve" with wurlitzer.sys_pipes() (https://github.com/minrk/wurlitzer, + https://pypi.python.org/pypi/wurlitzer/) + - wurlitzer doesn't work on windows, info appears in notebook server console window if used from jupyter notebook + + + --- Memory usage --- + methods: remove_stored_factorization, free_memory + + - remove_stored_factorization can be used to delete the wrapper's copy of matrix A + - free_memory releases the internal memory of the solver + + """ + + def __init__(self, mtype=11, phase=13, size_limit_storage=5e7): + + # Reuse module-level pre-loaded mkl_rt handle (loaded with RTLD_GLOBAL + # to prevent conflicts with other MKL-linked extensions like vbcsr). + if _MKL_RT_HANDLE is not None: + self.libmkl = _MKL_RT_HANDLE + else: + raise ImportError( + 'Shared library mkl_rt not found. ' + 'Use environment variable PYPARDISO_MKL_RT to provide a custom path.' 
+ ) + + self._mkl_pardiso = self.libmkl.pardiso + + # determine 32bit or 64bit architecture + if ctypes.sizeof(ctypes.c_void_p) == 8: + self._pt_type = (ctypes.c_int64, np.int64) + else: + self._pt_type = (ctypes.c_int32, np.int32) + + self._mkl_pardiso.argtypes = [ctypes.POINTER(self._pt_type[0]), # pt + ctypes.POINTER(ctypes.c_int32), # maxfct + ctypes.POINTER(ctypes.c_int32), # mnum + ctypes.POINTER(ctypes.c_int32), # mtype + ctypes.POINTER(ctypes.c_int32), # phase + ctypes.POINTER(ctypes.c_int32), # n + ctypes.POINTER(None), # a + ctypes.POINTER(ctypes.c_int32), # ia + ctypes.POINTER(ctypes.c_int32), # ja + ctypes.POINTER(ctypes.c_int32), # perm + ctypes.POINTER(ctypes.c_int32), # nrhs + ctypes.POINTER(ctypes.c_int32), # iparm + ctypes.POINTER(ctypes.c_int32), # msglvl + ctypes.POINTER(None), # b + ctypes.POINTER(None), # x + ctypes.POINTER(ctypes.c_int32)] # error + + self._mkl_pardiso.restype = None + + self.pt = np.zeros(64, dtype=self._pt_type[1]) + self.iparm = np.zeros(64, dtype=np.int32) + self.perm = np.zeros(0, dtype=np.int32) + + self.mtype = mtype + self.phase = phase + self.msglvl = False + + self.factorized_A = sp.csr_matrix((0, 0)) + self.size_limit_storage = size_limit_storage + self._solve_transposed = False + + @property + def _is_complex(self): + """Return True if the current mtype is a complex matrix type.""" + return self.mtype in _COMPLEX_MTYPES + + @property + def _dtype(self): + """Return the expected numpy dtype for the current mtype.""" + return np.complex128 if self._is_complex else np.float64 + + def factorize(self, A): + """ + Factorize the matrix A, the factorization will automatically be used if the same matrix A is passed to the + solve method. 
This will drastically increase the speed of solve, if solve is called more than once for the + same matrix A + + --- Parameters --- + A: sparse square CSR matrix (scipy.sparse.csr.csr_matrix), CSC matrix also possible + """ + + self._check_A(A) + + if A.nnz > self.size_limit_storage: + self.factorized_A = self._hash_csr_matrix(A) + else: + self.factorized_A = A.copy() + + self.set_phase(12) + b = np.zeros((A.shape[0], 1), dtype=self._dtype) + self._call_pardiso(A, b) + + def solve(self, A, b): + """ + solve Ax=b for x + + --- Parameters --- + A: sparse square CSR matrix (scipy.sparse.csr.csr_matrix), CSC matrix also possible + b: numpy ndarray + right-hand side(s), b.shape[0] needs to be the same as A.shape[0] + + --- Returns --- + x: numpy ndarray + solution of the system of linear equations, same shape as input b + """ + + self._check_A(A) + b = self._check_b(A, b) + + if self._is_already_factorized(A): + self.set_phase(33) + else: + self.set_phase(13) + + x = self._call_pardiso(A, b) + + return x + + def _is_already_factorized(self, A): + if isinstance(self.factorized_A, str): + return self._hash_csr_matrix(A) == self.factorized_A + else: + return self._csr_matrix_equal(A, self.factorized_A) + + def _csr_matrix_equal(self, a1, a2): + return all((np.array_equal(a1.indptr, a2.indptr), + np.array_equal(a1.indices, a2.indices), + np.array_equal(a1.data, a2.data))) + + def _hash_csr_matrix(self, matrix): + return (hashlib.sha1(matrix.indices).hexdigest() + + hashlib.sha1(matrix.indptr).hexdigest() + + hashlib.sha1(matrix.data).hexdigest()) + + def _check_A(self, A): + if A.shape[0] != A.shape[1]: + raise ValueError('Matrix A needs to be square, but has shape: {}'.format(A.shape)) + + if sp.issparse(A) and A.format == "csr": + self._solve_transposed = False + self.set_iparm(12, 0) + elif sp.issparse(A) and A.format == "csc": + self._solve_transposed = True + self.set_iparm(12, 1) + else: + msg = 'PyPardiso requires matrix A to be in CSR or CSC format, but matrix A is: 
{}'.format(type(A)) + raise TypeError(msg) + + # scipy allows unsorted csr-indices, which lead to completely wrong pardiso results + if not A.has_sorted_indices: + A.sort_indices() + + # scipy allows csr matrices with empty rows. a square matrix with an empty row is singular. calling + # pardiso with a matrix A that contains empty rows leads to a segfault, same applies for csc with + # empty columns + if not np.diff(A.indptr).all(): + row_col = 'column' if self._solve_transposed else 'row' + raise ValueError('Matrix A is singular, because it contains empty {}(s)'.format(row_col)) + + expected_dtype = self._dtype + if A.dtype != expected_dtype: + raise TypeError( + 'Matrix A has dtype {}, but mtype={} requires dtype {}'.format( + A.dtype, self.mtype, expected_dtype)) + + def _check_b(self, A, b): + if sp.issparse(b): + warnings.warn('PyPardiso requires the right-hand side b to be a dense array for maximum efficiency', + SparseEfficiencyWarning) + b = b.todense() + + # pardiso expects fortran (column-major) order for b + if not b.flags.f_contiguous: + b = np.asfortranarray(b) + + if b.shape[0] != A.shape[0]: + raise ValueError("Dimension mismatch: Matrix A {} and array b {}".format(A.shape, b.shape)) + + expected_dtype = self._dtype + if b.dtype != expected_dtype: + if self._is_complex: + # For complex types, try to cast compatible dtypes + if np.issubdtype(b.dtype, np.complexfloating) or np.issubdtype(b.dtype, np.floating): + warnings.warn( + "Array b's data type was converted from {} to {}".format(b.dtype, expected_dtype), + PyPardisoWarning) + b = b.astype(expected_dtype) + else: + raise TypeError('Dtype {} for array b is not supported'.format(b.dtype)) + else: + if b.dtype in [np.float16, np.float32, np.int16, np.int32, np.int64]: + warnings.warn("Array b's data type was converted from {} to float64".format(str(b.dtype)), + PyPardisoWarning) + b = b.astype(np.float64) + else: + raise TypeError('Dtype {} for array b is not supported'.format(str(b.dtype))) + + 
return b + + def _call_pardiso(self, A, b): + + x = np.zeros_like(b) + pardiso_error = ctypes.c_int32(0) + c_int32_p = ctypes.POINTER(ctypes.c_int32) + + # 1-based indexing + ia = A.indptr.astype(np.int32) + 1 + ja = A.indices.astype(np.int32) + 1 + + self._mkl_pardiso(self.pt.ctypes.data_as(ctypes.POINTER(self._pt_type[0])), # pt + ctypes.byref(ctypes.c_int32(1)), # maxfct + ctypes.byref(ctypes.c_int32(1)), # mnum + ctypes.byref(ctypes.c_int32(self.mtype)), # mtype -> 11 for real-nonsymetric + ctypes.byref(ctypes.c_int32(self.phase)), # phase -> 13 + ctypes.byref(ctypes.c_int32(A.shape[0])), # N -> number of equations/size of matrix + A.data.ctypes.data_as(ctypes.c_void_p), # A -> non-zero entries in matrix + ia.ctypes.data_as(c_int32_p), # ia -> csr-indptr + ja.ctypes.data_as(c_int32_p), # ja -> csr-indices + self.perm.ctypes.data_as(c_int32_p), # perm -> empty + ctypes.byref(ctypes.c_int32(1 if b.ndim == 1 else b.shape[1])), # nrhs + self.iparm.ctypes.data_as(c_int32_p), # iparm-array + ctypes.byref(ctypes.c_int32(self.msglvl)), # msg-level -> 1: statistical info is printed + b.ctypes.data_as(ctypes.c_void_p), # b -> right-hand side vector/matrix + x.ctypes.data_as(ctypes.c_void_p), # x -> output + ctypes.byref(pardiso_error)) # pardiso error + + if pardiso_error.value != 0: + raise PyPardisoError(pardiso_error.value) + else: + return np.ascontiguousarray(x) # change memory-layout back from fortran to c order + + def get_iparms(self): + """Returns a dictionary of iparms""" + return dict(enumerate(self.iparm, 1)) + + def get_iparm(self, i): + """Returns the i-th iparm (1-based indexing)""" + return self.iparm[i-1] + + def set_iparm(self, i, value): + """set the i-th iparm to 'value' (1-based indexing)""" + if i not in {1, 2, 4, 5, 6, 8, 10, 11, 12, 13, 18, 19, 21, 24, 25, 27, 28, 31, 34, 35, 36, 37, 56, 60}: + warnings.warn('{} is no input iparm. 
See the Pardiso documentation.'.format(value), PyPardisoWarning) + self.iparm[i-1] = value + + def set_matrix_type(self, mtype): + """Set the matrix type (see Pardiso documentation)""" + self.mtype = mtype + + def set_statistical_info_on(self): + """Display statistical info (appears in notebook server console window if pypardiso is + used from jupyter notebook, use wurlitzer to redirect info to the notebook)""" + self.msglvl = 1 + + def set_statistical_info_off(self): + """Turns statistical info off""" + self.msglvl = 0 + + def set_phase(self, phase): + """Set the phase(s) for the solver. See the Pardiso documentation for details.""" + self.phase = phase + + def remove_stored_factorization(self): + """removes the stored factorization, this will free the memory in python, but the factorization in pardiso + is still accessible with a direct call to self._call_pardiso(A,b) with phase=33""" + self.factorized_A = sp.csr_matrix((0, 0)) + + def free_memory(self, everything=False): + """release mkl's internal memory, either only for the factorization (ie the LU-decomposition) or all of + mkl's internal memory if everything=True""" + self.remove_stored_factorization() + A = sp.csr_matrix((0, 0)) + b = np.zeros(0) + self.set_phase(-1 if everything else 0) + self._call_pardiso(A, b) + self.set_phase(13) + + +class PyPardisoWarning(UserWarning): + pass + + +class PyPardisoError(Exception): + + def __init__(self, value): + self.value = value + + def __str__(self): + return ('The Pardiso solver failed with error code {}. ' + 'See Pardiso documentation for details.'.format(self.value)) \ No newline at end of file diff --git a/dptb/utils/tools.py b/dptb/utils/tools.py index 8c8e5f1b..4132cbb4 100644 --- a/dptb/utils/tools.py +++ b/dptb/utils/tools.py @@ -56,6 +56,22 @@ def float2comlex(dtype): raise ValueError("the dtype is not supported! 
now only float64, float32 is supported!") return cdtype +def tdtype2ndtype(dtype): + if isinstance(dtype, str): + dtype = getattr(torch, dtype) + + if dtype is torch.float32: + ndtype = np.float32 + elif dtype is torch.float64: + ndtype = np.float64 + elif dtype is torch.complex64: + ndtype = np.complex64 + elif dtype is torch.complex128: + ndtype = np.complex128 + else: + raise ValueError("the dtype is not supported! now only float64, float32, complex64, complex128 is supported!") + return ndtype + def flatten_dict(dictionary): queue = list(dictionary.items()) diff --git a/pyproject.toml b/pyproject.toml index 4a3827a5..eb3d9f96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "tensorboard", "seekpath", "rich>=13.0.0", + "vbcsr>=0.2.2" ] [project.optional-dependencies]