From 094a3e55060b18aa58ad6c3fe442a3f69ba458db Mon Sep 17 00:00:00 2001 From: Franklalalala <42742342+Franklalalala@users.noreply.github.com> Date: Sat, 6 Sep 2025 16:09:32 +0800 Subject: [PATCH 1/7] chore: universal model compatible 1. update neighbor search to new ase API 2. add scale type choice: no_scale, scale_wo_back_grad, scale_w_back_grad 3. add norm_eps arg 4. add HamilLossAbsMAE for the ham abs mae 5. add adamW optimizer --- dptb/data/AtomicData.py | 46 ++++++++++++++++----- dptb/nn/deeptb.py | 11 +++-- dptb/nn/embedding/lem.py | 20 ++++++--- dptb/nnops/loss.py | 89 ++++++++++++++++++++++++++++++++++++++++ dptb/utils/argcheck.py | 12 ++++-- dptb/utils/tools.py | 6 +-- 6 files changed, 158 insertions(+), 26 deletions(-) diff --git a/dptb/data/AtomicData.py b/dptb/data/AtomicData.py index 46b5dd375..a90a85bc8 100644 --- a/dptb/data/AtomicData.py +++ b/dptb/data/AtomicData.py @@ -15,6 +15,9 @@ from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator from ase.calculators.calculator import all_properties as ase_all_properties from ase.stress import voigt_6_to_full_3x3_stress, full_3x3_to_voigt_6_stress +from ase.neighborlist import NewPrimitiveNeighborList +from ase.data import chemical_symbols +import itertools import torch import e3nn.o3 @@ -971,17 +974,40 @@ def neighbor_list_and_relative_vec( # ASE dependent part temp_cell = ase.geometry.complete_cell(temp_cell) - - first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list( - "ijS", - pbc, - temp_cell, - temp_pos, - cutoff=float(_r_max), - self_interaction=self_interaction, # we want edges from atom to itself in different periodic images! - use_scaled_positions=False, +################################################################################## +###################################### 新代码 ######################################## + elements = np.unique(atomic_numbers).tolist() + pair_cutoffs = {} + for elem1, elem2 in itertools.combinations_with_replacement(elements, 2): + pair_cutoffs[(elem1, elem2)] = max(r_max[chemical_symbols[elem1]], r_max[chemical_symbols[elem2]]) + + nl = NewPrimitiveNeighborList( + cutoffs=10, + skin=0.0, + self_interaction=self_interaction, + bothways=True, + use_scaled_positions=False ) - + nl.cutoffs = pair_cutoffs + nl.update(pbc, temp_cell, temp_pos, atomic_numbers) + first_idex, second_idex, shifts = nl.pair_first, nl.pair_second, nl.offset_vec + mask_r = False + _r_max = max(pair_cutoffs.values()) + + ################################################################################## + ##################################### 老代码 ######################################### + # first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list( + # "ijS", + # pbc, + # temp_cell, + # temp_pos, + # cutoff=float(_r_max), + # self_interaction=self_interaction, # we want edges from atom to itself in different periodic images! + # use_scaled_positions=False, + # ) + + ################################################################################## + ################################################################################## # Eliminate true self-edges that don't cross periodic boundaries # if not self_interaction: # bad_edge = first_idex == second_idex diff --git a/dptb/nn/deeptb.py b/dptb/nn/deeptb.py index 13fbf9846..7bdd481f8 100644 --- a/dptb/nn/deeptb.py +++ b/dptb/nn/deeptb.py @@ -68,6 +68,7 @@ def __init__( dtype: Union[str, torch.dtype] = torch.float32, device: Union[str, torch.device] = torch.device("cpu"), transform: bool = True, + scale_type: str = 'scale_w_back_grad', **kwargs, ): @@ -103,7 +104,7 @@ def __init__( self.device = device self.model_options = {"embedding": embedding.copy(), "prediction": prediction.copy()} self.transform = transform - + self.scale_type = scale_type self.method = prediction.get("method", "e3tb") # self.soc = prediction.get("soc", False) @@ -298,9 +299,11 @@ def forward(self, data: AtomicDataDict.Type): data = self.embedding(data) if hasattr(self, "overlap") and self.method == "sktb": data[AtomicDataDict.EDGE_OVERLAP_KEY] = data[AtomicDataDict.EDGE_FEATURES_KEY] - - data = self.node_prediction_h(data) - data = self.edge_prediction_h(data) + + if self.scale_type != 'no_scale': + data = self.node_prediction_h(data) + data = self.edge_prediction_h(data) + if hasattr(self, "overlap"): data = self.edge_prediction_s(data) data[AtomicDataDict.NODE_OVERLAP_KEY] = self.overlaponsite_param[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()] diff --git a/dptb/nn/embedding/lem.py b/dptb/nn/embedding/lem.py index 2987e71cd..abdc1b067 100644 --- a/dptb/nn/embedding/lem.py +++ b/dptb/nn/embedding/lem.py @@ -39,6 +39,7 @@ def __init__( avg_num_neighbors: Optional[float] = None, # cutoffs r_start_cos_ratio: float = 0.8, + norm_eps: float = 1e-8, PolynomialCutoff_p: float = 6, cutoff_type: str = "polynomial", # general hyperparameters: @@ -131,6 +132,7 @@ def __init__( cutoff_type=cutoff_type, device=device, dtype=dtype, + norm_eps=norm_eps ) self.layers = torch.nn.ModuleList() @@ -235,6 +237,7 @@ def __init__( latent_dim: int=128, # cutoffs r_start_cos_ratio: float = 0.8, + norm_eps: float = 1e-8, PolynomialCutoff_p: float = 6, cutoff_type: str = "polynomial", device: Union[str, torch.device] = torch.device("cpu"), @@ -290,7 +293,7 @@ def __init__( self.sln_n = SeperableLayerNorm( irreps=self.irreps_out, - eps=1e-5, + eps=norm_eps, affine=True, normalization='component', std_balance_degrees=True, @@ -300,7 +303,7 @@ def __init__( self.sln_e = SeperableLayerNorm( irreps=self.irreps_out, - eps=1e-5, + eps=norm_eps, affine=True, normalization='component', std_balance_degrees=True, @@ -438,6 +441,7 @@ def __init__( irreps_in: o3.Irreps, irreps_out: o3.Irreps, latent_dim: int, + norm_eps: float = 1e-8, radial_emb: bool=False, radial_channels: list=[128, 128], res_update: bool = True, @@ -470,7 +474,7 @@ def __init__( self.sln = SeperableLayerNorm( irreps=self.irreps_in, - eps=1e-5, + eps=norm_eps, affine=True, normalization='component', std_balance_degrees=True, @@ -480,7 +484,7 @@ def __init__( self.sln_e = SeperableLayerNorm( irreps=self.edge_irreps_in, - eps=1e-5, + eps=norm_eps, affine=True, normalization='component', std_balance_degrees=True, @@ -614,6 +618,7 @@ def __init__( irreps_in: o3.Irreps, irreps_out: o3.Irreps, latent_dim: int, + norm_eps: float = 1e-8, latent_channels: list=[128, 128], radial_emb: bool=False, radial_channels: list=[128, 128], @@ -675,7 +680,7 @@ def __init__( self.sln_e = SeperableLayerNorm( irreps=self.irreps_in, - eps=1e-5, + eps=norm_eps, affine=True, normalization='component', std_balance_degrees=True, @@ -685,7 +690,7 @@ def __init__( self.sln_n = SeperableLayerNorm( irreps=self.irreps_in, - eps=1e-5, + eps=norm_eps, affine=True, normalization='component', std_balance_degrees=True, @@ -806,6 +811,7 @@ def __init__( tp_radial_emb: bool=False, tp_radial_channels: list=[128, 128], # MLP parameters: + norm_eps: float = 1e-8, latent_channels: list=[128, 128], latent_dim: int=128, res_update: bool = True, @@ -842,6 +848,7 @@ def __init__( res_update_ratios_learnable=res_update_ratios_learnable, dtype=dtype, device=device, + norm_eps=norm_eps ) self.node_update = UpdateNode( @@ -857,6 +864,7 @@ def __init__( avg_num_neighbors=avg_num_neighbors, dtype=dtype, device=device, + norm_eps=norm_eps ) def forward(self, latents, node_features, edge_features, node_onehot, edge_index, edge_vector, atom_type, cutoff_coeffs, active_edges): diff --git a/dptb/nnops/loss.py b/dptb/nnops/loss.py index 4e8d6fcd6..055987ee2 100644 --- a/dptb/nnops/loss.py +++ b/dptb/nnops/loss.py @@ -643,6 +643,95 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict): return 0.5 * (onsite_loss + hopping_loss) +@Loss.register("hamil_abs_mae") +class HamilLossAbsMAE(nn.Module): + def __init__( + self, + basis: Dict[str, Union[str, list]] = None, + idp: Union[OrbitalMapper, None] = None, + overlap: bool = False, + onsite_shift: bool = False, + dtype: Union[str, torch.dtype] = torch.float32, + device: Union[str, torch.device] = torch.device("cpu"), + **kwargs, + ): + + super(HamilLossAbsMAE, self).__init__() + self.loss1 = nn.L1Loss() + self.loss2 = nn.MSELoss() + self.overlap = overlap + self.device = device + self.onsite_shift = onsite_shift + + if basis is not None: + self.idp = OrbitalMapper(basis, method="e3tb", device=self.device) + if idp is not None: + assert idp == self.idp, "The basis of idp and basis should be the same." + else: + assert idp is not None, "Either basis or idp should be provided." + self.idp = idp + + def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict): + # mask the data + + # data[AtomicDataDict.NODE_FEATURES_KEY].masked_fill(~self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY]], 0.) + # data[AtomicDataDict.EDGE_FEATURES_KEY].masked_fill(~self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY]], 0.) + + if self.onsite_shift: + batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0])) + # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1." + mu = data[AtomicDataDict.NODE_FEATURES_KEY][ + self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \ + ref_data[AtomicDataDict.NODE_FEATURES_KEY][ + self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] + if batch.max() == 0: # when batchsize is zero + mu = mu.mean().detach() + ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[ + AtomicDataDict.NODE_OVERLAP_KEY] + ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[ + AtomicDataDict.EDGE_OVERLAP_KEY] + elif batch.max() >= 1: + slices = [data["__slices__"]["pos"][i] - data["__slices__"]["pos"][i - 1] for i in + range(1, len(data["__slices__"]["pos"]))] + slices = [0] + slices + ndiag_batch = torch.stack([i.sum() for i in + self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split( + slices)]) + ndiag_batch = torch.cumsum(ndiag_batch, dim=0) + mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i + 1]].mean() for i in range(len(ndiag_batch) - 1)]) + mu = mu.detach() + ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[ + batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY] + edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, + device=self.device) + for i in range(1, batch.max().item() + 1): + edge_mu_index[data["__slices__"]["edge_index"][i]:data["__slices__"]["edge_index"][i + 1]] += i + ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu[ + edge_mu_index, None] * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY] + + # onsite loss + pre_onsite = data[AtomicDataDict.NODE_FEATURES_KEY][ + self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()] + ] + tgt_onsite = ref_data[AtomicDataDict.NODE_FEATURES_KEY][ + self.idp.mask_to_nrme[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()] + ] + + # hopping loss + pre_hopping = data[AtomicDataDict.EDGE_FEATURES_KEY][ + self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()] + ] + tgt_hopping = ref_data[AtomicDataDict.EDGE_FEATURES_KEY][ + self.idp.mask_to_erme[ref_data[AtomicDataDict.EDGE_TYPE_KEY].flatten()] + ] + + pre = torch.cat([pre_onsite, pre_hopping], dim=0) + tgt = torch.cat([tgt_onsite, tgt_hopping], dim=0) + + total_loss = self.loss1(pre, tgt) + return total_loss + + @Loss.register("hamil_wt") class HamilLossWT(nn.Module): def __init__( diff --git a/dptb/utils/argcheck.py b/dptb/utils/argcheck.py index 8598156f3..7fa3974ae 100644 --- a/dptb/utils/argcheck.py +++ b/dptb/utils/argcheck.py @@ -115,9 +115,10 @@ def train_options(): doc_sliding_win_size = "Sliding window size for the average of the latest iterations' loss. Used for the reduce on plateau learning rate scheduler in case of the pairing of large dataset and small batch size. Default: `50`" doc_optimizer = "\ - The optimizer setting for selecting the gradient optimizer of model training. Optimizer supported includes `Adam`, `SGD` and `LBFGS` \n\n\ + The optimizer setting for selecting the gradient optimizer of model training. Optimizer supported includes `Adam`, `AdamW`, `SGD` and `LBFGS` \n\n\ For more information about these optmization algorithm, we refer to:\n\n\ - `Adam`: [Adam: A Method for Stochastic Optimization.](https://arxiv.org/abs/1412.6980)\n\n\ + - `AdamW`: [AdamW: Decoupled Weight Decay Regularization.](https://arxiv.org/abs/1711.05101)\n\n\ - `SGD`: [Stochastic Gradient Descent.](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html)\n\n\ - `LBFGS`: [On the limited memory BFGS method for large scale optimization.](http://users.iems.northwestern.edu/~nocedal/PDFfiles/limited-memory.pdf) \n\n\ " @@ -231,10 +232,11 @@ def LBFGS(): ] def optimizer(): - doc_type = "select type of optimizer, support type includes: `Adam`, `SGD` and `LBFGS`. Default: `Adam`" + doc_type = "select type of optimizer, support type includes: `Adam`, `AdamW`, `SGD` and `LBFGS`. Default: `Adam`" return Variant("type", [ Argument("Adam", dict, Adam()), + Argument("AdamW", dict, Adam()), Argument("SGD", dict, SGD()), Argument("RMSprop", dict, RMSprop()), Argument("LBFGS", dict, LBFGS()), @@ -635,6 +637,7 @@ def slem(): Argument("res_update_ratios", float, optional=True, default=0.5, doc="The ratios of residual update, should in (0,1)."), Argument("res_update_ratios_learnable", bool, optional=True, default=False, doc="Whether to make the ratios of residual update learnable."), Argument("universal", bool, optional=True, default=False, doc=doc_universal), + Argument("norm_eps", float, optional=True, default=1e-8, doc="eps in SeperableLayerNorm."), ] @@ -667,12 +670,14 @@ def e3tb_prediction(): doc_neurons = "neurons in the neural network." doc_activation = "activation function." doc_if_batch_normalized = "if to turn on batch normalization" + doc_scale_type = "Which scale method to use. Can be no_scale, scale_wo_back_grad, scale_w_back_grad" nn = [ Argument("scales_trainable", bool, optional=True, default=False, doc=doc_scales_trainable), Argument("shifts_trainable", bool, optional=True, default=False, doc=doc_shifts_trainable), Argument("neurons", list, optional=True, default=None, doc=doc_neurons), Argument("activation", str, optional=True, default="tanh", doc=doc_activation), + Argument("scale_type", str, optional=True, default="no_scale", doc=doc_scale_type), Argument("if_batch_normalized", bool, optional=True, default=False, doc=doc_if_batch_normalized), ] @@ -1756,9 +1761,10 @@ def normalize_skf2nnsk(data): doc_lr_scheduler = "The learning rate scheduler tools settings, the lr scheduler is used to scales down the learning rate during the training process. Proper setting can make the training more stable and efficient. The supported lr schedular includes: `Exponential Decaying (exp)`, `Linear multiplication (linear)`" doc_optimizer = "\ - The optimizer setting for selecting the gradient optimizer of model training. Optimizer supported includes `Adam`, `SGD` and `LBFGS` \n\n\ + The optimizer setting for selecting the gradient optimizer of model training. Optimizer supported includes `Adam`, `AdamW`, `SGD` and `LBFGS` \n\n\ For more information about these optmization algorithm, we refer to:\n\n\ - `Adam`: [Adam: A Method for Stochastic Optimization.](https://arxiv.org/abs/1412.6980)\n\n\ + - `AdamW`: [AdamW: Decoupled Weight Decay Regularization.](https://arxiv.org/abs/1711.05101)\n\n\ - `SGD`: [Stochastic Gradient Descent.](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html)\n\n\ - `LBFGS`: [On the limited memory BFGS method for large scale optimization.](http://users.iems.northwestern.edu/~nocedal/PDFfiles/limited-memory.pdf) \n\n\ " diff --git a/dptb/utils/tools.py b/dptb/utils/tools.py index 8c8e5f1b5..d276be9f7 100644 --- a/dptb/utils/tools.py +++ b/dptb/utils/tools.py @@ -125,8 +125,6 @@ def update_dict_with_warning(dict_input, update_list, update_value): return reconstruct_dict(flatten_input_dict) - - def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -138,6 +136,8 @@ def setup_seed(seed): def get_optimizer(type: str, model_param, lr: float, **options: dict): if type == 'Adam': optimizer = optim.Adam(params=model_param, lr=lr, **options) + elif type == 'AdamW': + optimizer = optim.AdamW(params=model_param, lr=lr, **options) elif type == 'SGD': optimizer = optim.SGD(params=model_param, lr=lr, **options) elif type == 'RMSprop': @@ -145,7 +145,7 @@ def get_optimizer(type: str, model_param, lr: float, **options: dict): elif type == 'LBFGS': optimizer = optim.LBFGS(params=model_param, lr=lr, **options) else: - raise RuntimeError("Optimizer should be Adam/SGD/RMSprop, not {}".format(type)) + raise RuntimeError("Optimizer should be Adam/AdamW/SGD/RMSprop, not {}".format(type)) return optimizer def get_lr_scheduler(type: str, optimizer: optim.Optimizer, **sch_options): From d1283fc14eaf71cbfc72efca1e07ab1f5a9f7e39 Mon Sep 17 00:00:00 2001 From: Franklalalala <42742342+Franklalalala@users.noreply.github.com> Date: Sat, 6 Sep 2025 16:43:03 +0800 Subject: [PATCH 2/7] 0.5.3 e3nn to be compatible with 2.0.0 torch --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0e455e5f7..6dc62de64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ pyyaml = "*" future = "*" dargs = "0.4.4" xitorch = "0.3.0" -e3nn = ">=0.5.1" +e3nn = ">=0.5.1,<=0.5.3" torch-runstats = "0.2.0" torch_scatter = "2.1.2" torch_geometric = ">=2.4.0" From 69ca22fc709d4c86c9905da11aa4684f3f11a4dd Mon Sep 17 00:00:00 2001 From: Franklalalala <42742342+Franklalalala@users.noreply.github.com> Date: Fri, 12 Sep 2025 11:14:00 +0800 Subject: [PATCH 3/7] feat: add ABACUS CSR file output feature --- dptb/postprocess/__init__.py | 4 +- dptb/postprocess/write_abacus_csr_file.py | 207 ++++++++++++++++++++++ 2 files changed, 209 insertions(+), 2 deletions(-) create mode 100644 dptb/postprocess/write_abacus_csr_file.py diff --git a/dptb/postprocess/__init__.py b/dptb/postprocess/__init__.py index be849ddfc..6f51a6525 100644 --- a/dptb/postprocess/__init__.py +++ b/dptb/postprocess/__init__.py @@ -1,11 +1,11 @@ from .bandstructure import Band from .totbplas import TBPLaS from .write_block import write_block - +from .write_abacus_csr_file import write_blocks_to_abacus_csr __all__ = [ Band, TBPLaS, write_block, - + write_blocks_to_abacus_csr ] \ No newline at end of file diff --git a/dptb/postprocess/write_abacus_csr_file.py b/dptb/postprocess/write_abacus_csr_file.py new file mode 100644 index 000000000..d8c8aae44 --- /dev/null +++ b/dptb/postprocess/write_abacus_csr_file.py @@ -0,0 +1,207 @@ +import os +import lmdb +import pickle +import re +import numpy as np +from scipy.sparse import csr_matrix, coo_matrix +from collections import defaultdict +import ase.data +from scipy.linalg import block_diag +from dftio.constants import ABACUS2DFTIO + +# DFTIO -> ABACUS +DFTIO2ABACUS = {l: M.T.astype(np.float32) for l, M in ABACUS2DFTIO.items()} + +ORBITAL_MAP = {'s': 0, 'p': 1, 'd': 2, 'f': 3, 'g': 4, 'h': 5} +KEY_RE = re.compile(r'^\s*(-?\d+)[ _](-?\d+)[ _](-?\d+)[ _](-?\d+)[ _](-?\d+)\s*$') +H_FACTOR = 13.605698 # Ryd -> eV factor for Hamiltonian + + +def parse_basis_to_l_list(basis_str): + """'2s2p1d' or 'spd' -> [0,0,1,1,2].""" + if basis_str is None: + return [] + s = str(basis_str).strip().lower() + if s == "": + return [] + tokens = re.findall(r'(\d*)([spdfgh])', s) + lst = [] + for num, ch in tokens: + cnt = int(num) if num else 1 + if ch not in ORBITAL_MAP: + raise ValueError(f"Unsupported orbital '{ch}' in '{basis_str}'") + lst.extend([ORBITAL_MAP[ch]] * cnt) + return lst + + +def find_basis_for_Z_or_symbol(basis_dict, Z): + """Find basis string for atomic number Z (multiple key forms).""" + if Z in basis_dict: + return basis_dict[Z] + sym = ase.data.chemical_symbols[Z] + for key_try in (sym, sym.capitalize(), sym.upper(), str(Z)): + if key_try in basis_dict: + return basis_dict[key_try] + for k, v in basis_dict.items(): + if isinstance(k, str) and k.lower() == sym.lower(): + return v + return None + + +def transform_2_ABACUS(mat, l_lefts, l_rights): + """Transform block from DFTIO ordering to ABACUS ordering.""" + if max(*(list(l_lefts) + list(l_rights))) > 5: + raise NotImplementedError("Only support l = s..h.") + left_mats = [DFTIO2ABACUS[l] for l in l_lefts] + right_mats = [DFTIO2ABACUS[l] for l in l_rights] + left = block_diag(*left_mats) if left_mats else np.eye(0, dtype=np.float32) + right = block_diag(*right_mats) if right_mats else np.eye(0, dtype=np.float32) + return left @ mat @ right.T + + +def write_abacus_csr_format(matrix_dict, matrix_symbol, output_path, step=0): + """Write mapping 'Rx_Ry_Rz' -> csr_matrix into ABACUS text CSR.""" + if not matrix_dict: + print(f"Warning: empty matrix_dict for {matrix_symbol}") + return + first = next(iter(matrix_dict)) + norbits = matrix_dict[first].shape[0] + num_blocks = len(matrix_dict) + with open(output_path, 'w') as f: + f.write(f"STEP: {step}\n") + f.write(f"Matrix Dimension of {matrix_symbol}(R): {norbits}\n") + f.write(f"Matrix number of {matrix_symbol}(R): {num_blocks}\n") + for r_key, sparse_mat in matrix_dict.items(): + r_vector_str = r_key.replace('_', ' ') + nnz = int(sparse_mat.nnz) + f.write(f"{r_vector_str} {nnz}\n") + if nnz > 0: + np.savetxt(f, sparse_mat.data.reshape(1, -1), fmt='%.8e') + np.savetxt(f, sparse_mat.indices.reshape(1, -1), fmt='%d') + np.savetxt(f, sparse_mat.indptr.reshape(1, -1), fmt='%d') + else: + f.write("\n\n\n") + # print(f"Wrote {num_blocks} blocks to {output_path}") + + +def write_blocks_to_abacus_csr(atomic_numbers, basis_dict, blocks_dict, matrix_symbol, output_path, step=0): + """ + Entry function: + atomic_numbers: per-site Z array-like + basis_dict: parse_orbital_files result + blocks_dict: mapping 'i_j_Rx_Ry_Rz' -> small block (DFTIO ordering) + matrix_symbol: 'H'/'S'/'D' + """ + atomic_numbers = np.asarray(atomic_numbers, dtype=int) + if atomic_numbers.size == 0: + raise ValueError("empty atomic_numbers") + + # choose factor + factor = H_FACTOR if str(matrix_symbol).upper() == 'H' else 1.0 + + # element -> l-list + element_l_lists = {} + for Z in np.unique(atomic_numbers): + basis_str = find_basis_for_Z_or_symbol(basis_dict, int(Z)) + if basis_str is None: + element_l_lists[int(Z)] = [0] + else: + ll = parse_basis_to_l_list(basis_str) + element_l_lists[int(Z)] = ll if ll else [0] + + # site norbits + site_norbits = np.array([sum(2 * l + 1 for l in element_l_lists[int(Z)]) for Z in atomic_numbers], dtype=int) + site_norbits_cumsum = np.cumsum(site_norbits) + norbits = int(site_norbits_cumsum[-1]) + + # aggregate COO data per R + r_vector_coo = defaultdict(lambda: {'data': [], 'rows': [], 'cols': []}) + + for raw_key, small_block in blocks_dict.items(): + key = raw_key.decode() if isinstance(raw_key, (bytes, bytearray)) else str(raw_key) + m = KEY_RE.match(key) + if not m: + # skip unparseable keys + continue + i_site = int(m.group(1)); j_site = int(m.group(2)) + Rx = int(m.group(3)); Ry = int(m.group(4)); Rz = int(m.group(5)) + r_str = f"{Rx}_{Ry}_{Rz}" + + # l-lists + l_lefts = element_l_lists[int(atomic_numbers[i_site])] + l_rights = element_l_lists[int(atomic_numbers[j_site])] + + # get ndarray (support sparse objects) + if hasattr(small_block, "toarray"): + block_arr = small_block.toarray() + elif "torch" in str(type(small_block)): + if small_block.is_cuda: + block_arr = small_block.detach().cpu().numpy() + else: + block_arr = small_block.detach().numpy() + else: + block_arr = np.asarray(small_block) + if block_arr.size == 0: + continue + + # transform DFTIO -> ABACUS + transformed = transform_2_ABACUS(block_arr.astype(np.float32), l_lefts, l_rights) + + # offsets + row_offset = int(site_norbits_cumsum[i_site] - site_norbits[i_site]) + col_offset = int(site_norbits_cumsum[j_site] - site_norbits[j_site]) + + coo = coo_matrix(transformed) + if coo.nnz == 0: + continue + + # apply factor (H vs others) + r_vector_coo[r_str]['data'].append((coo.data.astype(np.float32) / factor)) + r_vector_coo[r_str]['rows'].append((coo.row + row_offset).astype(int)) + r_vector_coo[r_str]['cols'].append((coo.col + col_offset).astype(int)) + + # build final CSR dict + reassembled = {} + for r_str, parts in r_vector_coo.items(): + if not parts['data']: + full = csr_matrix((norbits, norbits), dtype=np.float32) + else: + data = np.concatenate(parts['data']).astype(np.float32) + rows = np.concatenate(parts['rows']).astype(int) + cols = np.concatenate(parts['cols']).astype(int) + full = csr_matrix((data, (rows, cols)), shape=(norbits, norbits)) + reassembled[r_str] = full + + write_abacus_csr_format(reassembled, matrix_symbol, output_path, step=step) + return reassembled, norbits + + +# demo main +if __name__ == "__main__": + LMDB_PATH = r'E:\deeptb\large_DeepTB\0909\0910_lmdb\train\data.28400.lmdb' + ORBITAL_PATH = r'E:\deeptb\basis_set_test\production_use_dzp\orb_upf\public' + + from dprep.dptb_dpdispatcher import parse_orbital_files + _, basis_dict = parse_orbital_files(ORBITAL_PATH) + + env = lmdb.open(LMDB_PATH, readonly=True, lock=False) + with env.begin() as txn: + rec = txn.get((0).to_bytes(length=4, byteorder='big')) + if rec is None: + raise RuntimeError("No record at index 0") + data = pickle.loads(rec) + env.close() + + atomic_numbers = np.array(data['atomic_numbers'], dtype=int) + + if 'hamiltonian' in data and data['hamiltonian']: + write_blocks_to_abacus_csr( + atomic_numbers=atomic_numbers, + basis_dict=basis_dict, + blocks_dict=data['hamiltonian'], + matrix_symbol='H', + output_path='data-HR-sparse_SPIN0.csr', + step=0 + ) + else: + print("No hamiltonian in record 0.") From a24366ce6a9111c4dc488b604e72372ed7a47e16 Mon Sep 17 00:00:00 2001 From: Franklalalala <42742342+Franklalalala@users.noreply.github.com> Date: Wed, 17 Sep 2025 22:39:59 +0800 Subject: [PATCH 4/7] feat: add MACE-H no-back-grad scale doc_scale_type = ("Which scale method to use. Can be no_scale, " "scale_wo_back_grad (the scale parameter will not engage the back grad computation graph), " "scale_w_back_grad (the scale parameter will engage the back grad computation graph)") --- dptb/nn/rescale.py | 35 +++++++++++++++++++++++++++++------ dptb/utils/argcheck.py | 10 ++++++---- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/dptb/nn/rescale.py b/dptb/nn/rescale.py index 8aca26aa9..8d9351404 100644 --- a/dptb/nn/rescale.py +++ b/dptb/nn/rescale.py @@ -220,6 +220,7 @@ def __init__( shifts_trainable: bool = False, dtype: Union[str, torch.dtype] = torch.float32, device: Union[str, torch.device] = torch.device("cpu"), + scale_type: str = 'scale_w_back_grad', **kwargs, ): """Sum edges into nodes.""" @@ -233,6 +234,8 @@ def __init__( self.dtype = dtype self.shift_index = [] self.scale_index = [] + self.scale_type = scale_type + self.scales_trainable = scales_trainable start = 0 start_scalar = 0 @@ -293,7 +296,6 @@ def set_scale_shift(self, scales: torch.Tensor=None, shifts: torch.Tensor=None): self.register_buffer("shifts", shifts) - def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: if not (self.has_scales or self.has_shifts): @@ -305,22 +307,31 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: in_field = data[self.field][mask] species_idx = data[AtomicDataDict.EDGE_TYPE_KEY].flatten()[mask] - - assert len(in_field) == len( edge_center[mask] ), "in_field doesnt seem to have correct per-edge shape" if self.has_scales: - in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field + scales = self.scales[species_idx][:, self.scale_index].view(-1, self.irreps_in.dim) + if self.scale_type == 'scale_w_back_grad': + in_field = scales * in_field + elif self.scale_type == 'scale_wo_back_grad': + if self.scales_trainable: + in_field = in_field + in_field.detach() * (scales - 1.0) + else: + in_field = in_field + (in_field * (scales - 1.0)).detach() + else: + raise NotImplementedError + if self.has_shifts: shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar) in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0] - + data[self.out_field][mask] = in_field return data + class E3PerSpeciesScaleShift(torch.nn.Module): """Scale and/or shift a predicted per-atom property based on (learnable) per-species/type parameters. @@ -358,6 +369,7 @@ def __init__( shifts_trainable: bool = False, dtype: Union[str, torch.dtype] = torch.float32, device: Union[str, torch.device] = torch.device("cpu"), + scale_type: str = 'scale_w_back_grad', **kwargs, ): super().__init__() @@ -370,6 +382,8 @@ def __init__( self.scale_index = [] self.dtype = dtype self.device = device + self.scale_type = scale_type + self.scales_trainable = scales_trainable start = 0 start_scalar = 0 @@ -442,7 +456,16 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: species_idx ), "in_field doesnt seem to have correct per-atom shape" if self.has_scales: - in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field + scales = self.scales[species_idx][:, self.scale_index].view(-1, self.irreps_in.dim) + if self.scale_type == 'scale_w_back_grad': + in_field = scales * in_field + elif self.scale_type == 'scale_wo_back_grad': + if self.scales_trainable: + in_field = in_field + in_field.detach() * (scales - 1.0) + else: + in_field = in_field + (in_field * (scales - 1.0)).detach() + else: + raise NotImplementedError if self.has_shifts: shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar) in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0] diff --git a/dptb/utils/argcheck.py b/dptb/utils/argcheck.py index 7fa3974ae..884c585a0 100644 --- a/dptb/utils/argcheck.py +++ b/dptb/utils/argcheck.py @@ -665,19 +665,21 @@ def sktb_prediction(): def e3tb_prediction(): - doc_scales_trainable = "whether to scale the trianing target." - doc_shifts_trainable = "whether to shift the training target." + doc_scales_trainable = "The scale parameter is from the statistics. Whether to train this parameter." + doc_shifts_trainable = "The scale parameter is from the statistics. Whether to train this parameter." doc_neurons = "neurons in the neural network." doc_activation = "activation function." doc_if_batch_normalized = "if to turn on batch normalization" - doc_scale_type = "Which scale method to use. Can be no_scale, scale_wo_back_grad, scale_w_back_grad" + doc_scale_type = ("Which scale method to use. Can be no_scale, " + "scale_wo_back_grad (the scale parameter will not engage the back grad computation graph), " + "scale_w_back_grad (the scale parameter will engage the back grad computation graph)") nn = [ Argument("scales_trainable", bool, optional=True, default=False, doc=doc_scales_trainable), Argument("shifts_trainable", bool, optional=True, default=False, doc=doc_shifts_trainable), Argument("neurons", list, optional=True, default=None, doc=doc_neurons), Argument("activation", str, optional=True, default="tanh", doc=doc_activation), - Argument("scale_type", str, optional=True, default="no_scale", doc=doc_scale_type), + Argument("scale_type", str, optional=True, default="scale_w_back_grad", doc=doc_scale_type), Argument("if_batch_normalized", bool, optional=True, default=False, doc=doc_if_batch_normalized), ] From 6bd8f62486efe16c782312d6b975aff43ac7e594 Mon Sep 17 00:00:00 2001 From: Franklalalala <42742342+Franklalalala@users.noreply.github.com> Date: Fri, 6 Mar 2026 18:25:28 +0800 Subject: [PATCH 5/7] clean PR with AI suggestion 1. refactored neighbor_list_and_relative_vec to better use ASE API. Local test show the number of bond has not been changed compared with main branch DPTB logic. Show safe implementation of ASE API. (the mask_r logic has been completely removed) 2. adopt AI suggestion for other file changes --- dptb/data/AtomicData.py | 302 ++++++++-------------- dptb/nn/deeptb.py | 2 +- dptb/nnops/loss.py | 26 ++ dptb/postprocess/write_abacus_csr_file.py | 25 +- dptb/utils/tools.py | 2 +- pyproject.toml | 32 +-- 6 files changed, 159 insertions(+), 230 deletions(-) diff --git a/dptb/data/AtomicData.py b/dptb/data/AtomicData.py index 5eed91311..f8034514d 100644 --- a/dptb/data/AtomicData.py +++ b/dptb/data/AtomicData.py @@ -15,7 +15,6 @@ from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator from ase.calculators.calculator import all_properties as ase_all_properties from ase.stress import voigt_6_to_full_3x3_stress, full_3x3_to_voigt_6_stress -from ase.neighborlist import NewPrimitiveNeighborList from ase.data import chemical_symbols import itertools @@ -885,14 +884,15 @@ def without_nodes(self, which_nodes): assert _ERROR_ON_NO_EDGES in ("true", "false"), "NEQUIP_ERROR_ON_NO_EDGES must be 'true' or 'false'" _ERROR_ON_NO_EDGES = _ERROR_ON_NO_EDGES == "true" + def neighbor_list_and_relative_vec( - pos, - r_max, - self_interaction=False, - reduce=True, - atomic_numbers=None, - cell=None, - pbc=False, + pos, + r_max, + self_interaction=False, + reduce=True, + atomic_numbers=None, + cell=None, + pbc=False, ): """Create neighbor list and neighbor vectors based on radial cutoff. @@ -906,45 +906,62 @@ def neighbor_list_and_relative_vec( Thus, ``edge_index`` has the same convention as the relative vectors: :math:`\\vec{r}_{source, target}` - If the input positions are a tensor with ``requires_grad == True``, - the output displacement vectors will be correctly attached to the inputs - for autograd. - - All outputs are Tensors on the same device as ``pos``; this allows future - optimization of the neighbor list on the GPU. - Args: pos (shape [N, 3]): Positional coordinate; Tensor or numpy array. If Tensor, must be on CPU. - r_max (float): Radial cutoff distance for neighbor finding. + r_max (float, dict): Radial cutoff distance. Can be a global float, or a dictionary. cell (numpy shape [3, 3]): Cell for periodic boundary conditions. Ignored if ``pbc == False``. - pbc (bool or 3-tuple of bool): Whether the system is periodic in each of the three cell dimensions. - self_interaction (bool): Whether or not to include same periodic image self-edges in the neighbor list. - strict_self_interaction (bool): Whether to include *any* self interaction edges in the graph, even if the two - instances of the atom are in different periodic images. Defaults to True, should be True for most applications. + pbc (bool or 3-tuple of bool): Periodic boundary conditions. + self_interaction (bool): Whether to include same periodic image self-edges. + reduce (bool): If True, returns an undirected graph (half edges). If False, returns a full directed graph. + atomic_numbers (array-like): Atomic numbers of the atoms, required if r_max is a dict. Returns: - edge_index (torch.tensor shape [2, num_edges]): List of edges. - edge_cell_shift (torch.tensor shape [num_edges, 3]): Relative cell shift - vectors. Returned only if cell is not None. - cell (torch.Tensor [3, 3]): the cell as a tensor on the correct device. - Returned only if cell is not None. + edge_index (torch.tensor shape [2, num_edges]) + shifts (torch.tensor shape [num_edges, 3]) + cell (torch.Tensor [3, 3]) """ if isinstance(pbc, bool): pbc = (pbc,) * 3 - mask_r = False + atomic_numbers_np = np.asarray(atomic_numbers) if atomic_numbers is not None else None + + # ------------------------------------------------------------------------- + # 1. Parse r_max to ASE-compatible format for native pair-cutoff filtering + # ------------------------------------------------------------------------- if isinstance(r_max, dict): - _r_max = max(r_max.values()) - if _r_max - min(r_max.values()) > 1e-5: - mask_r = True - - if len(r_max) < len(set(atomic_numbers)): + if atomic_numbers_np is None: + raise ValueError("atomic_numbers must be provided when r_max is a dict.") + + if len(r_max) < len(set(atomic_numbers_np)): raise ValueError("The number of r_max is less than the number of required atom species.") + + first_key = next(iter(r_max.keys())) + key_parts = str(first_key).split("-") + + if len(key_parts) == 1: + # Atom-wise cutoffs: ASE naturally handles array input as R[i] + R[j] < cutoff + r_map = get_r_map(r_max, atomic_numbers) + r_map_np = r_map.detach().cpu().numpy() if isinstance(r_map, torch.Tensor) else np.asarray(r_map) + user_cutoff = 0.5 * r_map_np[atomic_numbers_np - 1] + + elif len(key_parts) == 2: + # Pair-wise cutoffs: Convert user string keys to ASE tuple keys + r_map = get_r_map_bondwise(r_max, atomic_numbers) + r_map_np = r_map.detach().cpu().numpy() if isinstance(r_map, torch.Tensor) else np.asarray(r_map) + user_cutoff = {} + unique_nums = np.unique(atomic_numbers_np) + for z1 in unique_nums: + for z2 in unique_nums: + user_cutoff[(int(z1), int(z2))] = float(r_map_np[int(z1) - 1, int(z2) - 1]) + else: + raise ValueError("The r_max keys should be either atomic number or atomic number pair.") else: - _r_max = r_max assert isinstance(r_max, (float, int)) + user_cutoff = float(r_max) - # Either the position or the cell may be on the GPU as tensors + # ------------------------------------------------------------------------- + # 2. Setup Device, Tensors, and Geometry + # ------------------------------------------------------------------------- if isinstance(pos, torch.Tensor): temp_pos = pos.detach().cpu().numpy() out_device = pos.device @@ -954,13 +971,11 @@ def neighbor_list_and_relative_vec( out_device = torch.device("cpu") out_dtype = torch.get_default_dtype() - # Right now, GPU tensors require a round trip if out_device.type != "cpu": warnings.warn( "Currently, neighborlists require a round trip to the CPU. Please pass CPU tensors if possible." ) - # Get a cell on the CPU no matter what if isinstance(cell, torch.Tensor): temp_cell = cell.detach().cpu().numpy() cell_tensor = cell.to(device=out_device, dtype=out_dtype) @@ -968,179 +983,74 @@ def neighbor_list_and_relative_vec( temp_cell = np.asarray(cell) cell_tensor = torch.as_tensor(temp_cell, device=out_device, dtype=out_dtype) else: - # ASE will "complete" this correctly. temp_cell = np.zeros((3, 3), dtype=temp_pos.dtype) cell_tensor = torch.as_tensor(temp_cell, device=out_device, dtype=out_dtype) - # ASE dependent part temp_cell = ase.geometry.complete_cell(temp_cell) -################################################################################## -###################################### 新代码 ######################################## - elements = np.unique(atomic_numbers).tolist() - pair_cutoffs = {} - for elem1, elem2 in itertools.combinations_with_replacement(elements, 2): - pair_cutoffs[(elem1, elem2)] = max(r_max[chemical_symbols[elem1]], r_max[chemical_symbols[elem2]]) - - nl = NewPrimitiveNeighborList( - cutoffs=10, - skin=0.0, + + # ------------------------------------------------------------------------- + # 3. Call core O(N) neighbor search algorithm + # ------------------------------------------------------------------------- + # By default, primitive_neighbor_list returns a fully directed graph representing + # both (i, j, S) and (j, i, -S). It also automatically removes self-edges (i=i, S=0) + # if self_interaction=False. + first_idx, second_idx, shifts = ase.neighborlist.primitive_neighbor_list( + "ijS", + pbc, + temp_cell, + temp_pos, + cutoff=user_cutoff, + numbers=atomic_numbers_np, self_interaction=self_interaction, - bothways=True, - use_scaled_positions=False + use_scaled_positions=False, ) - nl.cutoffs = pair_cutoffs - nl.update(pbc, temp_cell, temp_pos, atomic_numbers) - first_idex, second_idex, shifts = nl.pair_first, nl.pair_second, nl.offset_vec - mask_r = False - _r_max = max(pair_cutoffs.values()) - - ################################################################################## - ##################################### 老代码 ######################################### - # first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list( - # "ijS", - # pbc, - # temp_cell, - # temp_pos, - # cutoff=float(_r_max), - # self_interaction=self_interaction, # we want edges from atom to itself in different periodic images! - # use_scaled_positions=False, - # ) - - ################################################################################## - ################################################################################## - # Eliminate true self-edges that don't cross periodic boundaries - # if not self_interaction: - # bad_edge = first_idex == second_idex - # bad_edge &= np.all(shifts == 0, axis=1) - # keep_edge = ~bad_edge - # if _ERROR_ON_NO_EDGES and (not np.any(keep_edge)): - # raise ValueError( - # f"Every single atom has no neighbors within the cutoff r_max={r_max} (after eliminating self edges, no edges remain in this system)" - # ) - # first_idex = first_idex[keep_edge] - # second_idex = second_idex[keep_edge] - # shifts = shifts[keep_edge] + # ------------------------------------------------------------------------- + # 4. Handle graph reduction state + # ------------------------------------------------------------------------- + if reduce: + # Convert full directed graph to undirected half-graph + mask_lt = first_idx < second_idx + mask_eq = first_idx == second_idx - """ - bond list is: i, j, shift; but i j shift and j i -shift are the same bond. so we need to remove the duplicate bonds.s - first for i != j; we only keep i < j; then the j i -shift will be removed. - then, for i == j; we only keep i i shift and remove i i -shift. - """ - # 1. for i != j, keep i < j - assert atomic_numbers is not None - atomic_numbers = torch.as_tensor(atomic_numbers, dtype=torch.long) - mask = first_idex <= second_idex - first_idex = first_idex[mask] - second_idex = second_idex[mask] - shifts = shifts[mask] - - # 2. for i == j - - mask = torch.ones(len(first_idex), dtype=torch.bool) - mask[first_idex == second_idex] = False - # get index bool type ~mask for i == j. - # Convert mask to numpy for consistent indexing behavior - mask_np = mask.cpu().numpy() - o_first_idex = first_idex[~mask_np] - o_second_idex = second_idex[~mask_np] - o_shift = shifts[~mask_np] - o_mask = mask[~mask] # this is all False, with length being the number all the bonds with i == j. - - # Ensure arrays are proper numpy arrays (not scalars) for isolated systems - o_first_idex = np.atleast_1d(o_first_idex) - o_second_idex = np.atleast_1d(o_second_idex) - o_shift = np.atleast_2d(o_shift) - - # using the dict key to remove the duplicate bonds, because it is O(1) to check if a key is in the dict. - rev_dict = {} - for i in range(len(o_first_idex)): - key = str(o_first_idex[i])+str(o_shift[i]) - key_rev = str(o_first_idex[i])+str(-o_shift[i]) - rev_dict[key] = True - # key_rev is the reverse key of key, if key_rev is in the dict, then the bond is duplicate. - # so, only when key_rev is not in the dict, we keep the bond. that is when rev_dict.get(key_rev, False) is False, we set o_mast = True. - if not (rev_dict.get(key_rev, False) and rev_dict.get(key, False)): - o_mask[i] = True - - if self_interaction: - log.warning("self_interaction is True, but usually we do not want the self-interaction, please check if it is correct.") - # for self-interaction, the above will remove the self-interaction, i.e. i == j, shift == [0, 0, 0]. since -0 = 0. - if (o_shift[i] == np.array([0, 0, 0])).all(): - o_mask[i] = True - - del rev_dict - del o_first_idex - del o_second_idex - del o_shift - mask[~mask] = o_mask - del o_mask - - # Convert mask to numpy for indexing numpy arrays (avoids torch/numpy compatibility issues) - mask_np = mask.cpu().numpy() - first_idex = torch.as_tensor(first_idex[mask_np], dtype=torch.long, device=out_device) - second_idex = torch.as_tensor(second_idex[mask_np], dtype=torch.long, device=out_device) - shifts = torch.as_tensor(shifts[mask_np], dtype=out_dtype, device=out_device) - - if not reduce: - assert self_interaction == False, "for self_interaction = True, i i 0 0 0 will be duplicated." - first_idex, second_idex = torch.cat((first_idex, second_idex), dim=0), torch.cat((second_idex, first_idex), dim=0) - shifts = torch.cat((shifts, -shifts), dim=0) - - # Build output: - edge_index = torch.vstack( - (torch.LongTensor(first_idex), torch.LongTensor(second_idex)) - ) + # Deduplicate mirrored periodic boundaries for i == j + eq_first = first_idx[mask_eq] + eq_shifts = shifts[mask_eq] + eq_keep = np.zeros(len(eq_first), dtype=bool) - # TODO: mask the edges that is larger than r_max - if mask_r: - edge_vec = pos[edge_index[1]] - pos[edge_index[0]] - if cell is not None : - edge_vec = edge_vec + torch.einsum( - "ni,ij->nj", - shifts, - cell_tensor.reshape(3,3), # remove batch dimension - ) + rev_dict = {} + for i in range(len(eq_first)): + key = f"{eq_first[i]}_{eq_shifts[i]}" + key_rev = f"{eq_first[i]}_{-eq_shifts[i]}" + rev_dict[key] = True - edge_length = torch.linalg.norm(edge_vec, dim=-1) + if not (rev_dict.get(key_rev, False) and rev_dict.get(key, False)): + eq_keep[i] = True - # atom_species_num = [atomic_num_dict[k] for k in r_max.keys()] - # for i in set(atomic_numbers): - # assert i in atom_species_num - # r_map = torch.zeros(max(atom_species_num)) - # for k, v in r_max.items(): - # r_map[atomic_num_dict[k]-1] = v + if self_interaction and (eq_shifts[i] == 0).all(): + eq_keep[i] = True - first_key = next(iter(r_max.keys())) - key_parts = first_key.split("-") - - if len(key_parts)==1: - r_map = get_r_map(r_max, atomic_numbers) - edge_length_max = 0.5 * (r_map[atomic_numbers[edge_index[0]]-1] + r_map[atomic_numbers[edge_index[1]]-1]) - - elif len(key_parts)==2: - r_map = get_r_map_bondwise(r_max, atomic_numbers) - edge_length_max = r_map[atomic_numbers[edge_index[0]]-1,atomic_numbers[edge_index[1]]-1] - else: - raise ValueError("The r_max keys should be either atomic number or atomic number pair.") - - r_mask = edge_length <= edge_length_max - if any(~r_mask): - edge_index = edge_index[:, r_mask] - shifts = shifts[r_mask] - # 收集不同类型的边及其对应的最大截断半径 - #edge_types = {} - #for i in range(edge_index.shape[1]): - # atom_type_pair = (atomic_numbers[edge_index[0, i]], atomic_numbers[edge_index[1, i]]) - # if atom_type_pair not in edge_types: - # edge_types[atom_type_pair] = edge_length_max[i].item() - - del edge_length - del edge_vec - del r_map - del edge_length_max - del r_mask + # Combine reduction masks + final_mask = mask_lt.copy() + final_mask[mask_eq] = eq_keep + + first_idx = first_idx[final_mask] + second_idx = second_idx[final_mask] + shifts = shifts[final_mask] + + # Note: If `reduce=False`, the output of primitive_neighbor_list is exactly the + # full bidirectional graph structure required, so no post-processing is needed. + + # ------------------------------------------------------------------------- + # 5. Build output tensors + # ------------------------------------------------------------------------- + first_idx_t = torch.as_tensor(first_idx, dtype=torch.long, device=out_device) + second_idx_t = torch.as_tensor(second_idx, dtype=torch.long, device=out_device) + shifts_t = torch.as_tensor(shifts, dtype=out_dtype, device=out_device) + + edge_index = torch.vstack((first_idx_t, second_idx_t)) + + return edge_index, shifts_t, cell_tensor - return edge_index, shifts, cell_tensor def get_r_map(r_max: dict, atomic_numbers=None): """ diff --git a/dptb/nn/deeptb.py b/dptb/nn/deeptb.py index 7bdd481f8..a17278459 100644 --- a/dptb/nn/deeptb.py +++ b/dptb/nn/deeptb.py @@ -300,7 +300,7 @@ def forward(self, data: AtomicDataDict.Type): if hasattr(self, "overlap") and self.method == "sktb": data[AtomicDataDict.EDGE_OVERLAP_KEY] = data[AtomicDataDict.EDGE_FEATURES_KEY] - if self.scale_type != 'no_scale': + if self.method != "e3tb" or self.scale_type != "no_scale": data = self.node_prediction_h(data) data = self.edge_prediction_h(data) diff --git a/dptb/nnops/loss.py b/dptb/nnops/loss.py index 055987ee2..7c0b92a4f 100644 --- a/dptb/nnops/loss.py +++ b/dptb/nnops/loss.py @@ -728,6 +728,32 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict): pre = torch.cat([pre_onsite, pre_hopping], dim=0) tgt = torch.cat([tgt_onsite, tgt_hopping], dim=0) + # ================= 新增:overlap loss 逻辑 ================= + if self.overlap: + # onsite overlap + pre_onsite_ovlp = data[AtomicDataDict.NODE_OVERLAP_KEY][ + self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()] + ] + tgt_onsite_ovlp = ref_data[AtomicDataDict.NODE_OVERLAP_KEY][ + self.idp.mask_to_nrme[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()] + ] + + # hopping overlap + pre_hopping_ovlp = data[AtomicDataDict.EDGE_OVERLAP_KEY][ + self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()] + ] + tgt_hopping_ovlp = ref_data[AtomicDataDict.EDGE_OVERLAP_KEY][ + self.idp.mask_to_erme[ref_data[AtomicDataDict.EDGE_TYPE_KEY].flatten()] + ] + + pre_ovlp = torch.cat([pre_onsite_ovlp, pre_hopping_ovlp], dim=0) + tgt_ovlp = torch.cat([tgt_onsite_ovlp, tgt_hopping_ovlp], dim=0) + + # 将 overlap 特征拼接到 pre/tgt 中一同计算 MAE + pre = torch.cat([pre, pre_ovlp], dim=0) + tgt = torch.cat([tgt, tgt_ovlp], dim=0) + # ========================================================== + total_loss = self.loss1(pre, tgt) return total_loss diff --git a/dptb/postprocess/write_abacus_csr_file.py b/dptb/postprocess/write_abacus_csr_file.py index d8c8aae44..43171b4a3 100644 --- a/dptb/postprocess/write_abacus_csr_file.py +++ b/dptb/postprocess/write_abacus_csr_file.py @@ -103,11 +103,24 @@ def write_blocks_to_abacus_csr(atomic_numbers, basis_dict, blocks_dict, matrix_s element_l_lists = {} for Z in np.unique(atomic_numbers): basis_str = find_basis_for_Z_or_symbol(basis_dict, int(Z)) + + # 1. 检测基组是否缺失 if basis_str is None: - element_l_lists[int(Z)] = [0] + raise ValueError( + f"Matrix '{matrix_symbol}': find_basis_for_Z_or_symbol() could not find a basis for Z={Z}. " + f"Available keys in basis_dict: {list(basis_dict.keys())}. " + f"Aborting to prevent silent downstream dimension errors in element_l_lists." + ) else: ll = parse_basis_to_l_list(basis_str) - element_l_lists[int(Z)] = ll if ll else [0] + # 2. 检测基组字符串是否解析为空 + if not ll: + raise ValueError( + f"Matrix '{matrix_symbol}': parse_basis_to_l_list() returned an empty list " + f"for basis string '{basis_str}' (Z={Z}). " + f"Aborting to prevent silent downstream dimension errors in element_l_lists." + ) + element_l_lists[int(Z)] = ll # site norbits site_norbits = np.array([sum(2 * l + 1 for l in element_l_lists[int(Z)]) for Z in atomic_numbers], dtype=int) @@ -123,8 +136,11 @@ def write_blocks_to_abacus_csr(atomic_numbers, basis_dict, blocks_dict, matrix_s if not m: # skip unparseable keys continue - i_site = int(m.group(1)); j_site = int(m.group(2)) - Rx = int(m.group(3)); Ry = int(m.group(4)); Rz = int(m.group(5)) + i_site = int(m.group(1)) + j_site = int(m.group(2)) + Rx = int(m.group(3)) + Ry = int(m.group(4)) + Rz = int(m.group(5)) r_str = f"{Rx}_{Ry}_{Rz}" # l-lists @@ -175,7 +191,6 @@ def write_blocks_to_abacus_csr(atomic_numbers, basis_dict, blocks_dict, matrix_s write_abacus_csr_format(reassembled, matrix_symbol, output_path, step=step) return reassembled, norbits - # demo main if __name__ == "__main__": LMDB_PATH = r'E:\deeptb\large_DeepTB\0909\0910_lmdb\train\data.28400.lmdb' diff --git a/dptb/utils/tools.py b/dptb/utils/tools.py index d276be9f7..2d6caeadd 100644 --- a/dptb/utils/tools.py +++ b/dptb/utils/tools.py @@ -145,7 +145,7 @@ def get_optimizer(type: str, model_param, lr: float, **options: dict): elif type == 'LBFGS': optimizer = optim.LBFGS(params=model_param, lr=lr, **options) else: - raise RuntimeError("Optimizer should be Adam/AdamW/SGD/RMSprop, not {}".format(type)) + raise RuntimeError("Optimizer should be Adam/AdamW/SGD/RMSprop/LBFGS, not {}".format(type)) return optimizer def get_lr_scheduler(type: str, optimizer: optim.Optimizer, **sch_options): diff --git a/pyproject.toml b/pyproject.toml index 876c62b56..927458365 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,14 +28,10 @@ dependencies = [ "scipy>=1.11,<1.13", "spglib", "matplotlib", - "torch>=2.0.0,<=2.5.1", "ase", "pyyaml", "dargs==0.4.4", "e3nn>=0.5.1", - "torch-runstats==0.2.0", - "torch_scatter==2.1.2", - "torch_geometric>=2.4.0", "opt-einsum==3.3.0", "h5py>=3.7.0,<3.12,!=3.10.0", "lmdb==1.4.1", @@ -45,29 +41,11 @@ dependencies = [ "rich>=13.0.0", ] -[tool.poetry.group.dev.dependencies] -pytest = ">=7.2.0" -pytest-order = "1.2.0" -numpy = "*" -scipy = ">=1.11.*,<=1.12.*" -spglib = "*" -matplotlib = "*" -torch = ">=2.0.0,<=2.5.1" -ase = "*" -pyyaml = "*" -future = "*" -dargs = "0.4.4" -xitorch = "0.3.0" -e3nn = ">=0.5.1,<=0.5.3" -torch-runstats = "0.2.0" -torch_scatter = "2.1.2" -torch_geometric = ">=2.4.0" -opt-einsum = "3.3.0" -h5py = ">=3.7.0,<=3.11.0,!=3.10.0" -lmdb = "1.4.1" -pyfiglet = "1.0.2" -tensorboard = "*" -seekpath = "*" +[project.optional-dependencies] +3Dfermi = ["ifermi", "pymatgen"] +tbtrans_init = ["sisl"] +pybinding = ["pybinding"] +pythtb = ["pythtb"] [project.scripts] dptb = "dptb.__main__:main" From dbd42afc5e94c798b984eabe1a36a125b20e0a18 Mon Sep 17 00:00:00 2001 From: Franklalalala <42742342+Franklalalala@users.noreply.github.com> Date: Fri, 6 Mar 2026 18:34:55 +0800 Subject: [PATCH 6/7] Update pyproject.toml reverse debug setup --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 927458365..4a3827a55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,10 +28,14 @@ dependencies = [ "scipy>=1.11,<1.13", "spglib", "matplotlib", + "torch>=2.0.0,<=2.5.1", "ase", "pyyaml", "dargs==0.4.4", "e3nn>=0.5.1", + "torch-runstats==0.2.0", + "torch_scatter==2.1.2", + "torch_geometric>=2.4.0", "opt-einsum==3.3.0", "h5py>=3.7.0,<3.12,!=3.10.0", "lmdb==1.4.1", From 619d06697145bd9b17b5892144ad9a9984108cc2 Mon Sep 17 00:00:00 2001 From: Franklalalala <42742342+Franklalalala@users.noreply.github.com> Date: Fri, 6 Mar 2026 21:04:50 +0800 Subject: [PATCH 7/7] fix: adapt AI suggestion --- dptb/data/AtomicData.py | 21 ++++++++---- dptb/nnops/loss.py | 41 ++++------------------- dptb/postprocess/write_abacus_csr_file.py | 4 ++- dptb/utils/argcheck.py | 2 ++ 4 files changed, 26 insertions(+), 42 deletions(-) diff --git a/dptb/data/AtomicData.py b/dptb/data/AtomicData.py index f8034514d..4d0a390f3 100644 --- a/dptb/data/AtomicData.py +++ b/dptb/data/AtomicData.py @@ -923,7 +923,14 @@ def neighbor_list_and_relative_vec( if isinstance(pbc, bool): pbc = (pbc,) * 3 - atomic_numbers_np = np.asarray(atomic_numbers) if atomic_numbers is not None else None + # 【优化】:确保如果输入的是 GPU Tensor,能够安全转换 + if atomic_numbers is not None: + if isinstance(atomic_numbers, torch.Tensor): + atomic_numbers_np = atomic_numbers.detach().cpu().numpy() + else: + atomic_numbers_np = np.asarray(atomic_numbers) + else: + atomic_numbers_np = None # ------------------------------------------------------------------------- # 1. Parse r_max to ASE-compatible format for native pair-cutoff filtering @@ -940,13 +947,13 @@ def neighbor_list_and_relative_vec( if len(key_parts) == 1: # Atom-wise cutoffs: ASE naturally handles array input as R[i] + R[j] < cutoff - r_map = get_r_map(r_max, atomic_numbers) + r_map = get_r_map(r_max, atomic_numbers_np) r_map_np = r_map.detach().cpu().numpy() if isinstance(r_map, torch.Tensor) else np.asarray(r_map) user_cutoff = 0.5 * r_map_np[atomic_numbers_np - 1] elif len(key_parts) == 2: # Pair-wise cutoffs: Convert user string keys to ASE tuple keys - r_map = get_r_map_bondwise(r_max, atomic_numbers) + r_map = get_r_map_bondwise(r_max, atomic_numbers_np) r_map_np = r_map.detach().cpu().numpy() if isinstance(r_map, torch.Tensor) else np.asarray(r_map) user_cutoff = {} unique_nums = np.unique(atomic_numbers_np) @@ -976,18 +983,20 @@ def neighbor_list_and_relative_vec( "Currently, neighborlists require a round trip to the CPU. Please pass CPU tensors if possible." ) + # 获取初始 cell 数据 if isinstance(cell, torch.Tensor): temp_cell = cell.detach().cpu().numpy() - cell_tensor = cell.to(device=out_device, dtype=out_dtype) elif cell is not None: temp_cell = np.asarray(cell) - cell_tensor = torch.as_tensor(temp_cell, device=out_device, dtype=out_dtype) else: temp_cell = np.zeros((3, 3), dtype=temp_pos.dtype) - cell_tensor = torch.as_tensor(temp_cell, device=out_device, dtype=out_dtype) + # ASE 补全缺失的晶格向量 temp_cell = ase.geometry.complete_cell(temp_cell) + # 【修复2】:在此处(补全后)生成 cell_tensor,保证输出与 shifts 处于同一坐标系参考标准下 + cell_tensor = torch.as_tensor(temp_cell, device=out_device, dtype=out_dtype) + # ------------------------------------------------------------------------- # 3. Call core O(N) neighbor search algorithm # ------------------------------------------------------------------------- diff --git a/dptb/nnops/loss.py b/dptb/nnops/loss.py index 7c0b92a4f..6b3daec77 100644 --- a/dptb/nnops/loss.py +++ b/dptb/nnops/loss.py @@ -672,42 +672,13 @@ def __init__( self.idp = idp def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict): - # mask the data - - # data[AtomicDataDict.NODE_FEATURES_KEY].masked_fill(~self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY]], 0.) - # data[AtomicDataDict.EDGE_FEATURES_KEY].masked_fill(~self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY]], 0.) + # ================= 修复 CodeRabbit 审查意见 ================= if self.onsite_shift: - batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0])) - # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1." - mu = data[AtomicDataDict.NODE_FEATURES_KEY][ - self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \ - ref_data[AtomicDataDict.NODE_FEATURES_KEY][ - self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - if batch.max() == 0: # when batchsize is zero - mu = mu.mean().detach() - ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[ - AtomicDataDict.NODE_OVERLAP_KEY] - ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[ - AtomicDataDict.EDGE_OVERLAP_KEY] - elif batch.max() >= 1: - slices = [data["__slices__"]["pos"][i] - data["__slices__"]["pos"][i - 1] for i in - range(1, len(data["__slices__"]["pos"]))] - slices = [0] + slices - ndiag_batch = torch.stack([i.sum() for i in - self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split( - slices)]) - ndiag_batch = torch.cumsum(ndiag_batch, dim=0) - mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i + 1]].mean() for i in range(len(ndiag_batch) - 1)]) - mu = mu.detach() - ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[ - batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY] - edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, - device=self.device) - for i in range(1, batch.max().item() + 1): - edge_mu_index[data["__slices__"]["edge_index"][i]:data["__slices__"]["edge_index"][i + 1]] += i - ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu[ - edge_mu_index, None] * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY] + # 直接复用统一的 shift_mu 函数,保持与 HamilLossAbs/EigHamLoss 语义对齐 + # 内部会同时利用 node 和 edge 的 overlap 贡献综合推导并应用 mu + shift_mu(data, ref_data, self.idp) + # ============================================================ # onsite loss pre_onsite = data[AtomicDataDict.NODE_FEATURES_KEY][ @@ -728,7 +699,7 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict): pre = torch.cat([pre_onsite, pre_hopping], dim=0) tgt = torch.cat([tgt_onsite, tgt_hopping], dim=0) - # ================= 新增:overlap loss 逻辑 ================= + # ================= 保留 overlap loss 逻辑 ================= if self.overlap: # onsite overlap pre_onsite_ovlp = data[AtomicDataDict.NODE_OVERLAP_KEY][ diff --git a/dptb/postprocess/write_abacus_csr_file.py b/dptb/postprocess/write_abacus_csr_file.py index 43171b4a3..ff83d3521 100644 --- a/dptb/postprocess/write_abacus_csr_file.py +++ b/dptb/postprocess/write_abacus_csr_file.py @@ -21,9 +21,11 @@ def parse_basis_to_l_list(basis_str): """'2s2p1d' or 'spd' -> [0,0,1,1,2].""" if basis_str is None: return [] - s = str(basis_str).strip().lower() + s = re.sub(r"\s+", "", str(basis_str).lower()) if s == "": return [] + if not re.fullmatch(r"(?:\d*[spdfgh])+", s): + raise ValueError(f"Invalid basis string '{basis_str}'") tokens = re.findall(r'(\d*)([spdfgh])', s) lst = [] for num, ch in tokens: diff --git a/dptb/utils/argcheck.py b/dptb/utils/argcheck.py index 66f247e7d..53280b74c 100644 --- a/dptb/utils/argcheck.py +++ b/dptb/utils/argcheck.py @@ -837,6 +837,7 @@ def loss_options(): - `eigvals`: The mse loss predicted and labeled eigenvalues and Delta eigenvalues between different k. - `hamil`: - `hamil_abs`: + - `hamil_abs_mae`: - `hamil_blas`: """ doc_train = "Loss options for training." @@ -874,6 +875,7 @@ def loss_options(): Argument("eigvals", dict, sub_fields=eigvals), Argument("skints", dict, sub_fields=skints), Argument("hamil_abs", dict, sub_fields=hamil), + Argument("hamil_abs_mae", dict, sub_fields=hamil), Argument("hamil_blas", dict, sub_fields=hamil), Argument("hamil_wt", dict, sub_fields=hamil+wt), Argument("eig_ham", dict, sub_fields=hamil+eigvals+eig_ham),