Skip to content
30 changes: 30 additions & 0 deletions dptb/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dptb.plugins.monitor import TrainLossMonitor, LearningRateMonitor, Validationer, TensorBoardMonitor
from dptb.plugins.train_logger import Logger
from dptb.utils.argcheck import normalize, collect_cutoffs, chk_avg_per_iter
from dptb.utils.orbital_parser import parse_orbital_file
from dptb.plugins.saver import Saver
from typing import Dict, List, Optional, Any
from dptb.utils.tools import j_loader, setup_seed, j_must_have
Expand Down Expand Up @@ -90,6 +91,35 @@ def train(

jdata = j_loader(INPUT)
jdata = normalize(jdata)

# Validate and process orbital files in basis
if jdata.get("common_options") and jdata["common_options"].get("basis"):
orbital_files_content = {}
for elem, value in jdata["common_options"]["basis"].items():
if isinstance(value, str) and os.path.isfile(value):
# strict check for e3tb method
# Check if model_options exists and has prediction method
model_opts = jdata.get("model_options", {})
pred_opts = model_opts.get("prediction", {})
# normalize might handle defaults, but safely check here
if pred_opts.get("method", "e3tb") != "e3tb":
raise ValueError(f"Orbital files in 'basis' are only supported for the 'e3tb' method. Found method: {pred_opts.get('method')}")

# Also checking if we are in a mixed model which might be different,
# but usually orbital files for basis imply the main basis handling.
# transform logic
try:
parsed_basis = parse_orbital_file(value)
with open(value, 'r') as f:
orbital_files_content[elem] = f.read()

jdata["common_options"]["basis"][elem] = parsed_basis
log.info(f"Parsed orbital file for {elem}: {value} -> {parsed_basis}")
except Exception as e:
raise ValueError(f"Failed to parse orbital file {value} for element {elem}: {e}")

if orbital_files_content:
jdata["common_options"]["orbital_files_content"] = orbital_files_content
# update basis if init_model or restart
# update jdata
# this is not necessary, because if we init model from checkpoint, the build_model will load the model_options from checkpoints if not provided
Expand Down
2 changes: 2 additions & 0 deletions dptb/nn/deeptb.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ def __init__(
device=self.device,
)

if kwargs.get("orbital_files_content"):
self.orbital_files_content = kwargs["orbital_files_content"]

def forward(self, data: AtomicDataDict.Type):
if data.get(AtomicDataDict.EDGE_TYPE_KEY, None) is None:
Expand Down
182 changes: 182 additions & 0 deletions dptb/nn/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@
from dptb.data.transforms import OrbitalMapper
from dptb.data import AtomicDataDict
import logging
# Optional solver back-ends.
# Each import is guarded separately so that one missing back-end (for example
# the PARDISO wrapper) does not also disable the others. A name is set to
# None when its import fails; the solver classes check for None before use.
try:
    from dptb.utils.pardiso_wrapper import PyPardisoSolver
except ImportError:
    PyPardisoSolver = None

try:
    from dptb.utils.feast_wrapper import FeastSolver
except ImportError:
    FeastSolver = None

try:
    from scipy.sparse.linalg import eigsh, LinearOperator
except ImportError:
    eigsh = None
    LinearOperator = None

log = logging.getLogger(__name__)

class Eigenvalues(nn.Module):
Expand Down Expand Up @@ -113,6 +123,178 @@ def forward(self,
data[AtomicDataDict.KPOINT_KEY] = kpoints

return data

class PardisoEig:
    """Shift-invert eigenvalue solver that uses PARDISO for the linear solves.

    For every k-point the sampled Hamiltonian is shifted by ``sigma`` in place
    (``H - sigma * S`` when an overlap is given), the shifted matrix is
    factorized with ``PyPardisoSolver`` and ``scipy.sparse.linalg.eigsh`` is
    run with that factorization supplied as ``OPinv``.
    """

    def __init__(self, sigma: float = 0.0, neig: int = 10, mode: str = 'normal'):
        """
        Solver using Pardiso for shift-invert eigenvalue problems.

        Args:
            sigma: Shift value (target energy).
            neig: Number of eigenvalues to solve for.
            mode: Eigsh mode ('normal', 'buckling', 'cayley').

        Raises:
            ImportError: If the PARDISO wrapper or ``eigsh`` could not be imported.
        """
        if PyPardisoSolver is None or eigsh is None:
            raise ImportError("PardisoEig requires MKL (pypardiso) and scipy.sparse.linalg")

        self.sigma = sigma
        self.neig = neig
        self.mode = mode

    @staticmethod
    def _ncv_candidates(neig: int, n: int):
        """Return ``(first, retry)`` ARPACK basis sizes, each clamped to ``n``.

        ARPACK rejects ``ncv > n``. Without the clamp the first attempt fails
        for ``n < 20`` and the retry can never succeed for ``n < 50``.
        """
        first = min(n, max(2 * neig + 1, 20))
        retry = min(n, max(5 * neig, 50))
        return first, retry

    def solve(self, h_container, s_container, kpoints: Union[list, torch.Tensor, np.ndarray], return_eigenvectors: bool = False):
        """
        Solve eigenvalues for given k-points.

        Args:
            h_container: vbcsr.ImageContainer for Hamiltonian.
            s_container: vbcsr.ImageContainer for Overlap (can be None).
            kpoints: Array of k-points (Nk, 3).
            return_eigenvectors: If True, return (eigenvalues, eigenvectors). Default False.

        Returns:
            list of eigenvalues arrays (and eigenvectors arrays if return_eigenvectors=True).
            One entry per k-point; eigenvalues are shifted back by ``sigma``.
        """

        # Torch tensors are converted to numpy; lists and arrays are iterated as given.
        if isinstance(kpoints, torch.Tensor):
            kpoints = kpoints.cpu().numpy()

        eigvals_list = []
        eigvecs_list = []

        for k in kpoints:
            hk = h_container.sample_k(k, symm=True)

            # Apply the shift in place, so that eigsh below can be called with
            # sigma=0.0 and the shift is added back to the results afterwards.
            if s_container is not None:
                sk = s_container.sample_k(k, symm=True)
                hk -= self.sigma * sk
                M = sk
            else:
                # presumably shifts the diagonal by -sigma (S = identity) — TODO confirm
                hk.shift(-self.sigma)
                M = None
            A = hk.to_scipy(format="csr")

            A.sort_indices()
            A.sum_duplicates()
            N = A.shape[0]

            # NOTE: no fallback solver is implemented here; a PARDISO failure propagates.
            # NOTE(review): mtype=13 is MKL PARDISO's "complex, nonsymmetric"
            # matrix type, assuming the wrapper forwards it unchanged — confirm.
            solver = PyPardisoSolver(mtype=13)
            solver.factorize(A)

            def matvec(b):
                return solver.solve(A, b)

            Op = LinearOperator((N, N), matvec=matvec, dtype=A.dtype)

            def run_eigsh(ncv):
                return eigsh(A=hk, M=M, k=self.neig, sigma=0.0, OPinv=Op, mode=self.mode, which="LM", ncv=ncv)

            ncv_first, ncv_retry = self._ncv_candidates(self.neig, N)
            try:
                # Use larger NCV to help convergence, especially for clustered eigenvalues
                vals, vecs = run_eigsh(ncv_first)
            except Exception:
                # Retry with larger NCV if ARPACK fails (e.g. error 3: No shifts could be applied)
                # This often happens when eigenvalues are clustered near the shift
                vals, vecs = run_eigsh(ncv_retry)

            eigvals_list.append(vals + self.sigma)
            if return_eigenvectors:
                eigvecs_list.append(vecs)

        if return_eigenvectors:
            return eigvals_list, eigvecs_list
        else:
            return eigvals_list

class FEASTEig:
    """Eigenvalue solver that delegates each k-point to a ``FeastSolver`` instance."""

    def __init__(self, emin: float = -1.0, emax: float = 1.0, m0: Optional[int] = None,
                 max_refinement: int = 3, uplo: str = 'U', extract_triangular: bool = True):
        """
        Solver using FEAST algorithm for finding eigenvalues in a given interval.

        Args:
            emin, emax: Energy interval [emin, emax].
            m0: Initial subspace size estimate.
            max_refinement: Number of refinements if subspace is too small.
            uplo: 'U' (Upper) or 'L' (Lower) triangular part to use.
            extract_triangular: Whether to extract triangular part automatically.

        Raises:
            ImportError: If ``FeastSolver`` is unavailable or raises ImportError on construction.
            RuntimeError: If constructing ``FeastSolver`` fails for any other reason.
        """
        if FeastSolver is None:
            raise ImportError("FEAST solver not available")

        self.emin, self.emax = emin, emax
        self.m0 = m0
        self.max_refinement = max_refinement
        self.uplo = uplo
        self.extract_triangular = extract_triangular

        # Build the solver eagerly so availability problems surface here.
        try:
            self.solver = FeastSolver()
        except ImportError as e:
            raise ImportError(f"FEAST solver not available: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to initialize FeastSolver: {e}") from e

    @staticmethod
    def _as_csr(sampled):
        """Convert via ``to_scipy(format="csr")`` when available, else pass through."""
        return sampled.to_scipy(format="csr") if hasattr(sampled, 'to_scipy') else sampled

    def solve(self, h_container, s_container, kpoints: Union[list, torch.Tensor, np.ndarray], return_eigenvectors: bool = False):
        """
        Solve eigenvalues for given k-points using FEAST.

        Args:
            h_container: Container for Hamiltonian; ``sample_k(k, symm=True)`` is called per k-point.
            s_container: Container for Overlap (can be None).
            kpoints: Array of k-points.
            return_eigenvectors: If True, return (eigenvalues, eigenvectors). Default False.

        Returns:
            list of eigenvalues arrays (and eigenvectors arrays if return_eigenvectors=True).
        """
        if isinstance(kpoints, torch.Tensor):
            kpoints = kpoints.cpu().numpy()

        # Options forwarded unchanged to every FeastSolver.solve call.
        feast_options = dict(
            emin=self.emin,
            emax=self.emax,
            m0=self.m0,
            max_refinement=self.max_refinement,
            uplo=self.uplo,
            extract_triangular=self.extract_triangular,
        )

        spectra = []
        states = []

        for kpt in kpoints:
            ham_k = self._as_csr(h_container.sample_k(kpt, symm=True))

            ovl_k = None
            if s_container is not None:
                ovl_k = self._as_csr(s_container.sample_k(kpt, symm=True))

            energies, vectors = self.solver.solve(ham_k, M=ovl_k, **feast_options)

            spectra.append(energies)
            if return_eigenvectors:
                states.append(vectors)

        if not return_eigenvectors:
            return spectra
        return spectra, states


class Eigh(nn.Module):
def __init__(
Expand Down
Loading