diff --git a/.gitignore b/.gitignore index 01cde404..6b62d10b 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,6 @@ uv.lock # duecredit .duecredit.p + +# ignore local users potential agent files/plans +.agents/ diff --git a/examples/readme.md b/examples/readme.md index 7e78979e..06f778b3 100644 --- a/examples/readme.md +++ b/examples/readme.md @@ -35,7 +35,6 @@ If you'd like to execute the scripts or examples locally, you can run them with: curl -LsSf https://astral.sh/uv/install.sh | sh # pick any of the examples -uv run --with . examples/2_Structural_optimization/2.3_MACE_FIRE.py -uv run --with . examples/3_Dynamics/3.3_MACE_NVE_cueq.py -uv run --with . examples/4_High_level_api/4.1_high_level_api.py +uv run --with . examples/scripts/1_introduction.py +uv run --with . examples/scripts/2_structural_optimization.py ``` diff --git a/examples/scripts/9_neb.py b/examples/scripts/9_neb.py new file mode 100644 index 00000000..aa26eff7 --- /dev/null +++ b/examples/scripts/9_neb.py @@ -0,0 +1,934 @@ +"""Nudged Elastic Band (NEB) workflow. + +This script demonstrates the Nudged Elastic Band method for finding minimum energy +paths between two given atomic configurations. 
def compare_initial_paths(
    ase_start_atoms,
    ase_end_atoms,
    torch_sim_initial_state: SimState,
    torch_sim_final_state: SimState,
    neb_workflow: TorchNEB,
):
    """Print a comparison of the initial NEB path built by ASE and by torch-sim.

    Three diagnostics are printed, nothing is returned:

    1. whether the ASE start positions equal the torch-sim start positions,
    2. whether ASE's ``find_mic`` and torch-sim's ``minimum_image_displacement``
       agree on the start -> end displacement,
    3. a per-image max/mean absolute error between the two interpolated paths.

    Args:
        ase_start_atoms: ASE ``Atoms`` for the first endpoint.
        ase_end_atoms: ASE ``Atoms`` for the last endpoint.
        torch_sim_initial_state: torch-sim state of the first endpoint.
        torch_sim_final_state: torch-sim state of the last endpoint.
        neb_workflow: torch-sim NEB workflow; ``n_images``, ``device``, ``dtype``
            and the private ``_interpolate_path`` are read from it.
    """
    print("Comparing initial interpolated paths and MIC vectors...")
    n_images = neb_workflow.n_images
    n_total_images = n_images + 2  # intermediates plus both endpoints
    device = neb_workflow.device
    dtype = neb_workflow.dtype

    # --- Endpoint Check ---
    print("\nChecking consistency of starting endpoint positions:")
    ase_start_pos_direct = ase_start_atoms.get_positions()
    ts_start_pos_direct = torch_sim_initial_state.positions.cpu().numpy()
    start_close = np.allclose(
        ase_start_pos_direct, ts_start_pos_direct, rtol=1e-5, atol=1e-6
    )
    print(f" Direct Start positions close: {start_close}")
    if not start_close:
        max_diff_start = np.max(np.abs(ase_start_pos_direct - ts_start_pos_direct))
        print(f" Max absolute difference (Start): {max_diff_start:.6f}")
    print("------------------------------------")

    # --- MIC Vector Comparison ---
    print("\nComparing Minimum Image Convention (MIC) displacement vectors:")
    # Use the torch-sim states as the source of truth for positions/cell
    raw_dr_ts = torch_sim_final_state.positions - torch_sim_initial_state.positions
    cell_ts = torch_sim_initial_state.cell[0]  # Assuming single batch cell
    pbc_ts = torch_sim_initial_state.pbc

    # ASE MIC calculation (best effort: a failure is reported, not raised)
    try:
        # NOTE(review): the cell is handed to ASE as stored by torch-sim; if the
        # two libraries use different row/column conventions a transpose is
        # needed for non-symmetric cells -- confirm against SimState.
        ase_cell_np = cell_ts.cpu().numpy()
        # ASE wants one flag per axis. Accept a scalar flag as well as a
        # per-axis tensor/sequence instead of blindly repeating it three times.
        if isinstance(pbc_ts, torch.Tensor):
            pbc_raw = pbc_ts.detach().cpu().numpy()
        else:
            pbc_raw = np.asarray(pbc_ts)
        ase_pbc_np = np.broadcast_to(pbc_raw.astype(bool).reshape(-1), (3,)).copy()
        ase_mic_dr_np, _ = ase.geometry.find_mic(
            raw_dr_ts.cpu().numpy(), ase_cell_np, pbc=ase_pbc_np
        )
        print(f" ASE MIC vector calculated (shape: {ase_mic_dr_np.shape})")
    except Exception as e:
        print(f" Error calculating ASE MIC: {e}")
        ase_mic_dr_np = None

    # torch-sim MIC calculation (best effort as above)
    try:
        ts_mic_dr = ts.transforms.minimum_image_displacement(
            dr=raw_dr_ts, cell=cell_ts, pbc=pbc_ts
        )
        ts_mic_dr_np = ts_mic_dr.cpu().numpy()
        print(f" torch-sim MIC vector calculated (shape: {ts_mic_dr_np.shape})")
    except Exception as e:
        print(f" Error calculating torch-sim MIC: {e}")
        ts_mic_dr_np = None

    # Compare the MIC vectors
    if ase_mic_dr_np is not None and ts_mic_dr_np is not None:
        if ase_mic_dr_np.shape != ts_mic_dr_np.shape:
            print(" Error: Shapes of MIC vectors do not match.")
        else:
            mic_vectors_close = np.allclose(
                ase_mic_dr_np, ts_mic_dr_np, rtol=1e-5, atol=1e-6
            )
            print(f" MIC displacement vectors close: {mic_vectors_close}")
            if not mic_vectors_close:
                mic_delta = ase_mic_dr_np - ts_mic_dr_np
                max_diff_mic = np.max(np.abs(mic_delta))
                norm_diff = np.linalg.norm(mic_delta)
                print(f" Max absolute difference (MIC vectors): {max_diff_mic:.6f}")
                print(f" Norm of difference vector (MIC): {norm_diff:.6f}")
                print(" This difference likely causes the interpolation discrepancy.")
    print("------------------------------------")

    # --- Get ASE interpolated path ---
    # n_images + 1 copies of the start followed by the end gives n_images + 2
    # images; interpolate() then overwrites the intermediates.
    ase_images = [ase_start_atoms.copy() for _ in range(n_images + 1)]
    ase_images.append(ase_end_atoms.copy())
    ase_neb_calc = ASENEB(ase_images, climb=False)
    ase_neb_calc.interpolate(mic=True)
    ase_positions = np.stack([img.get_positions() for img in ase_neb_calc.images])
    print(f"\n ASE interpolated path shape: {ase_positions.shape}")

    # --- Get torch-sim interpolated path ---
    try:
        interpolated_state = neb_workflow._interpolate_path(
            torch_sim_initial_state, torch_sim_final_state
        )
        ts_start_pos = torch_sim_initial_state.positions
        n_atoms = ts_start_pos.shape[0]
        ts_interp_pos_reshaped = interpolated_state.positions.reshape(
            n_images, n_atoms, 3
        )
        # Re-attach both endpoints so the array lines up with the ASE image list.
        ts_positions = torch.cat(
            [
                ts_start_pos.unsqueeze(0).to(device, dtype),
                ts_interp_pos_reshaped.to(device, dtype),
                torch_sim_final_state.positions.unsqueeze(0).to(device, dtype),
            ],
            dim=0,
        )
        ts_positions_np = ts_positions.cpu().numpy()
        print(f" torch-sim interpolated path shape (direct): {ts_positions_np.shape}")
    except Exception as e:
        print(f" Error during torch-sim interpolation: {e}")
        import traceback

        traceback.print_exc()
        return

    # --- Compare Interpolated Paths ---
    print(
        "\n Per-image comparison of interpolated paths (Max Abs Error | Mean Abs Error):"
    )
    if ase_positions.shape != ts_positions_np.shape:
        print(" Error: Shapes of ASE and torch-sim interpolated paths do not match.")
        return

    overall_max_diff_interp = 0.0
    for i in range(n_total_images):
        diff_image_i = np.abs(ase_positions[i] - ts_positions_np[i])
        max_ae_i = np.max(diff_image_i)
        mae_i = np.mean(diff_image_i)
        print(f" Image {i}: MaxAE = {max_ae_i:.6f} | MAE = {mae_i:.6f}")
        overall_max_diff_interp = max(overall_max_diff_interp, max_ae_i)

    if np.allclose(ase_positions, ts_positions_np, rtol=1e-5, atol=1e-6):
        print(" Overall: Interpolated paths are numerically close.")
    else:
        print(" Overall: Interpolated paths differ numerically.")
        print(
            f" Overall Maximum absolute difference (Interpolated): {overall_max_diff_interp:.6f}"
        )


def ase_neb(start_atoms, end_atoms, nimages=5, *, fmax=0.05, steps=1000):
    """Run a reference climbing-image NEB with ASE and return the NEB object.

    Args:
        start_atoms: ASE ``Atoms`` for the first endpoint (copied, not modified).
        end_atoms: ASE ``Atoms`` for the last endpoint (copied, not modified).
        nimages: Number of intermediate images.
        fmax: Force convergence threshold passed to ``FIRE.run``.
        steps: Maximum number of optimizer steps passed to ``FIRE.run``.

    Returns:
        The optimized ASE NEB object; its ``images`` hold the final path.
    """
    images = [start_atoms.copy() for _ in range(nimages + 1)]
    images.append(end_atoms.copy())

    # Every image below receives the *same* calculator instance. ASE's NEB
    # refuses that in get_forces() unless it is told explicitly that the
    # calculator is shared, so opt in here.
    neb_calc = ASENEB(
        images,
        climb=True,
        method="improvedtangent",
        allow_shared_calculator=True,
    )
    neb_calc.interpolate(mic=True)

    # Attach calculator to all images using mace_mp
    ase_dtype_str = "float64" if torch_sim_dtype == torch.float64 else "float32"
    print(f"Attaching ASE calculator with dtype: {ase_dtype_str} to all images")
    ase_calc = mace_mp(
        model=MaceUrls.mace_mpa_medium,
        device=str(torch_sim_device),  # same device selection as the rest of the script
        default_dtype=ase_dtype_str,
        dispersion=False,
    )
    for image in neb_calc.images:
        image.calc = ase_calc

    opt = FIRE(neb_calc)

    # Run the ASE optimization (essential)
    print("Running ASE NEB optimization...")
    opt.run(fmax=fmax, steps=steps)
    print("Finished ASE NEB optimization.")

    return neb_calc  # Only return the final NEB object


def relax_atoms(
    atoms,
    fmax=0.05,
    steps=1000,
    device=torch_sim_device,
    dtype=torch_sim_dtype,
):
    """Relax a copy of *atoms* with ASE FIRE and a MACE calculator.

    Args:
        atoms: ASE ``Atoms`` to relax; the input object is left untouched.
        fmax: Force convergence threshold passed to ``FIRE.run``.
        steps: Maximum number of optimizer steps.
        device: Device for the MACE calculator (stringified before use).
        dtype: ``torch.float64`` selects "float64"; anything else "float32".

    Returns:
        The relaxed copy, with the calculator still attached.
    """
    new_atoms = atoms.copy()
    ase_dtype_str = "float64" if dtype == torch.float64 else "float32"
    new_atoms.calc = mace_mp(
        model=MaceUrls.mace_mpa_medium,
        device=str(device),
        default_dtype=ase_dtype_str,
        dispersion=False,
    )
    opt = FIRE(new_atoms)
    opt.run(fmax=fmax, steps=steps)
    return new_atoms
def _dump_debug_json(debug_data, output_filename):
    """Write *debug_data* to *output_filename* as JSON via ``MontyEncoder``."""
    with open(output_filename, "w") as f:
        json.dump(debug_data, f, indent=2, cls=MontyEncoder)  # Use MontyEncoder


# --- Add Function for Manual ASE Force Calculation ---
def calculate_ase_neb_force_step0(
    ase_neb_calc: "ASENEB",
    image_index: int,
    neb_workflow: "TorchNEB",
    output_filename="ase_step0_debug.json",
):
    """Recompute ASE's NEB force terms for one image at step 0 and dump them.

    The tangent (ImprovedTangent method), perpendicular true force, parallel
    spring force and, if applicable, the climbing-image force are evaluated by
    hand for ``image_index`` right after the initial interpolation. All inputs
    and intermediate results are written to ``output_filename`` as JSON.

    Args:
        ase_neb_calc: ASE NEB object whose images already carry calculators.
        image_index: Index into the full image list; must be an intermediate
            image (neither endpoint).
        neb_workflow: torch-sim workflow; ``spring_constant`` and
            ``use_climbing_image`` are read so both codes use the same settings.
        output_filename: Destination JSON file (overwritten).

    Returns:
        None. Errors are recorded under the ``"error"`` key of the JSON file.
    """
    import traceback

    print(f"--- Calculating ASE NEB Debug Info (Step 0, Image Index {image_index}) ---")
    debug_data = {
        "step": 0,
        "image_index_intermediate": image_index - 1,  # 0-based index among intermediates
        "image_index_absolute": image_index,  # 0-based index in full list
        "inputs": {},
        "outputs": {},
        "error": None,
    }

    n_images = ase_neb_calc.nimages  # Total number of images including endpoints
    if not (0 < image_index < n_images - 1):
        error_msg = f"Error: image_index {image_index} is not an intermediate image."
        print(error_msg)
        debug_data["error"] = error_msg
        _dump_debug_json(debug_data, output_filename)
        return

    images = ase_neb_calc.images

    # 1. Get initial energies and forces after interpolation + calculator attachment
    try:
        initial_energies_np = np.array([img.get_potential_energy() for img in images])
        initial_forces_np = np.stack([img.get_forces() for img in images])

        # numpy values are stored as-is; MontyEncoder serializes them
        inputs = debug_data["inputs"]
        inputs["energies_all"] = initial_energies_np
        inputs["true_forces_image"] = initial_forces_np[image_index]
        inputs["positions_image_minus_1"] = images[image_index - 1].get_positions()
        inputs["positions_image"] = images[image_index].get_positions()
        inputs["positions_image_plus_1"] = images[image_index + 1].get_positions()
        inputs["cell"] = images[image_index].get_cell().tolist()
        inputs["pbc"] = images[image_index].get_pbc()
    except Exception as e:
        error_msg = f"Error getting initial energies/forces from ASE images: {e}"
        print(error_msg)
        debug_data["error"] = error_msg
        debug_data["traceback"] = traceback.format_exc()
        _dump_debug_json(debug_data, output_filename)
        return

    # 2. Setup NEB state and method objects
    ase_neb_obj_for_state = ASENEB(
        images,
        k=neb_workflow.spring_constant,
        climb=neb_workflow.use_climbing_image,
        method="improvedtangent",
    )
    neb_state = NEBState(ase_neb_obj_for_state, images, initial_energies_np)
    tangent_method = ImprovedTangentMethod(ase_neb_obj_for_state)

    # 3. Calculate components for the target image_index
    try:
        outputs = debug_data["outputs"]
        spring1 = neb_state.spring(image_index - 1)
        spring2 = neb_state.spring(image_index)
        outputs["mic_displacement_1"] = spring1.t
        outputs["mic_displacement_2"] = spring2.t

        # Calculate tangent
        tangent_ase = tangent_method.get_tangent(neb_state, spring1, spring2, image_index)
        tangent_norm_ase = np.linalg.norm(tangent_ase)
        if tangent_norm_ase > 1e-15:
            tangent_ase_normalized = tangent_ase / tangent_norm_ase
        else:
            tangent_ase_normalized = tangent_ase  # Keep as zero vector
        outputs["tangent_vector"] = tangent_ase_normalized
        outputs["tangent_norm"] = np.linalg.norm(tangent_ase_normalized)

        # Calculate perpendicular force
        true_force_img = initial_forces_np[image_index]
        f_true_dot_tau_ase = np.vdot(true_force_img, tangent_ase_normalized)
        f_perp_ase = true_force_img - f_true_dot_tau_ase * tangent_ase_normalized
        f_perp_norm = np.linalg.norm(f_perp_ase)
        outputs["f_true_dot_tau"] = f_true_dot_tau_ase
        outputs["f_perp_vector"] = f_perp_ase
        outputs["f_perp_norm"] = f_perp_norm

        # Calculate parallel spring force
        segment_lengths_all = [neb_state.spring(i).nt for i in range(n_images - 1)]
        spring_mag_term = spring2.nt * spring2.k - spring1.nt * spring1.k
        f_spring_par_ase = spring_mag_term * tangent_ase_normalized
        f_spring_par_norm = np.linalg.norm(f_spring_par_ase)
        outputs["segment_lengths"] = segment_lengths_all
        outputs["spring_force_magnitude_term"] = spring_mag_term
        outputs["f_spring_par_vector"] = f_spring_par_ase
        outputs["f_spring_par_norm"] = f_spring_par_norm

        # Calculate total NEB force (before potential climbing modification)
        neb_force_ase = f_perp_ase + f_spring_par_ase
        neb_force_norm = np.linalg.norm(neb_force_ase)
        outputs["neb_force_before_climb_vector"] = np.array(neb_force_ase)
        outputs["neb_force_before_climb_norm"] = neb_force_norm

        # --- Direct Debug Prints for Step 0 ---
        print("\n --- DIRECT DEBUG PRINT (ASE STEP 0) ---")
        print(f" f_perp_norm: {f_perp_norm}")
        print(f" f_perp_vec[0]: {f_perp_ase[0]}")
        print(f" spring1_length (R[{image_index}]-R[{image_index - 1}]): {spring1.nt}")
        print(f" spring2_length (R[{image_index + 1}]-R[{image_index}]): {spring2.nt}")
        print(f" Length Diff (spring2.nt - spring1.nt): {spring2.nt - spring1.nt}")
        print(f" f_spring_par_norm: {f_spring_par_norm}")
        print(f" f_spring_par_vec[0]: {f_spring_par_ase[0]}")
        print(f" neb_force_before_climb_norm: {neb_force_norm}")
        print(" ------------------------------------")

        # Handle climbing image modification
        # bool()/int() keep the flags plain Python values in the JSON output.
        is_climbing = bool(
            ase_neb_obj_for_state.climb and image_index == neb_state.imax
        )
        outputs["is_climbing_image"] = is_climbing
        outputs["imax"] = int(neb_state.imax)  # Ensure imax is JSON serializable

        if is_climbing:
            # Climbing image: invert the true-force component along the tangent.
            climbing_force_ase = (
                true_force_img - 2 * f_true_dot_tau_ase * tangent_ase_normalized
            )
            outputs["climbing_force_vector"] = climbing_force_ase
            outputs["climbing_force_norm"] = np.linalg.norm(climbing_force_ase)
            final_force_ase = climbing_force_ase
        else:
            final_force_ase = neb_force_ase

        outputs["final_neb_force_vector"] = final_force_ase
        outputs["final_neb_force_norm"] = np.linalg.norm(final_force_ase)

    except Exception as e:
        error_msg = (
            f"Error during manual ASE force calculation for image {image_index}: {e}"
        )
        print(error_msg)
        debug_data["error"] = error_msg
        debug_data["traceback"] = traceback.format_exc()

    # Write data to JSON (partial results are still written after an error above)
    try:
        _dump_debug_json(debug_data, output_filename)
        print(f"--- ASE NEB Debug Info saved to {output_filename} ---")
    except Exception as e:
        print(f"Error writing ASE debug info to JSON: {e}")


def _values_match(val_ase, val_ts, rtol, atol):
    """Compare one ASE value with one torch-sim value.

    Returns:
        ``(match, detail)`` where ``detail`` is an empty string on a match and a
        short human-readable reason otherwise.
    """
    # Try numerical/array comparison first
    try:
        arr_ase = np.array(val_ase)
        arr_ts = np.array(val_ts)

        if arr_ase.shape != arr_ts.shape:
            return False, f"Shapes differ: ASE={arr_ase.shape}, TS={arr_ts.shape}"
        if np.issubdtype(arr_ase.dtype, np.number) and np.issubdtype(
            arr_ts.dtype, np.number
        ):
            if np.allclose(arr_ase, arr_ts, rtol=rtol, atol=atol):
                return True, ""
            return False, f"Max abs diff: {np.max(np.abs(arr_ase - arr_ts)):.6e}"
        if np.array_equal(arr_ase, arr_ts):
            return True, ""
        if arr_ase.dtype == np.bool_ and arr_ts.dtype == np.bool_:
            return False, f"Boolean values differ: ASE={arr_ase}, TS={arr_ts}"
        return False, "Non-numerical array values differ"
    except (TypeError, ValueError):
        pass  # not array-comparable; fall through to the scalar/object path

    # Fallback to direct comparison for non-array types or incompatible arrays
    try:
        if isinstance(val_ase, (float, int)) and isinstance(val_ts, (float, int)):
            if np.isclose(val_ase, val_ts, rtol=rtol, atol=atol):
                return True, ""
            return False, f"Diff: {abs(val_ase - val_ts):.6e}"
        if type(val_ase) is type(val_ts):
            if val_ase == val_ts:
                return True, ""
            return False, f"Values differ: ASE='{val_ase}', TS='{val_ts}'"
        return (
            False,
            f"Types differ after conversion: ASE={type(val_ase)}, TS={type(val_ts)}",
        )
    except Exception as exc:
        # Report why the comparison itself failed instead of a bare "DIFFER".
        return False, f"Comparison failed: {exc}"


# --- Add Function for Comparing JSON/Pickle Outputs to debug the tangent force calculation ---
def compare_step0_outputs(
    file_ase="ase_step0_debug.json",
    file_ts="torchsim_step0_debug.pkl",
    rtol=1e-5,
    atol=1e-6,
):
    """Compare the step-0 debug dump from ASE (JSON) with torch-sim's (pickle).

    Every key under ``"outputs"`` in either file is compared and a
    ``Match``/``DIFFER`` line is printed per key, followed by a summary.

    Args:
        file_ase: JSON file written by ``calculate_ase_neb_force_step0``.
        file_ts: Pickle file holding the torch-sim counterpart.
        rtol: Relative tolerance for numerical comparisons.
        atol: Absolute tolerance for numerical comparisons.

    Returns:
        None. Loading problems are printed and end the comparison early.
    """
    print("\n--- Comparing Step 0 Debug Outputs (ASE JSON vs TorchSim Pickle) --- ")
    try:
        # Load ASE data from JSON
        with open(file_ase) as f:
            data_ase = json.load(f, cls=MontyDecoder)
        # Load TorchSim data from Pickle
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only point this at debug dumps produced locally.
        with open(file_ts, "rb") as f:  # Use 'rb' for pickle
            data_ts = pickle.load(f)
    except FileNotFoundError as e:
        print(f"Error: Could not find file {e.filename}")
        return
    except Exception as e:
        print(f"Error loading JSON/Pickle files: {e}")
        return

    # Basic checks
    if data_ase.get("error") or data_ts.get("error"):
        print("Comparison aborted due to error during data generation.")
        print(f" ASE Error: {data_ase.get('error')}")
        print(f" TS Error: {data_ts.get('error')}")
        return

    if data_ase.get("step") != 0 or data_ts.get("step") != 0:
        print("Warning: One or both files do not contain step 0 data.")
        # Continue comparison anyway

    if data_ase.get("image_index_intermediate") != data_ts.get(
        "image_index_intermediate"
    ):
        print("Warning: JSON files are for different intermediate image indices.")
        # Continue comparison anyway

    outputs_ase = data_ase.get("outputs", {})
    outputs_ts = data_ts.get("outputs", {})
    mismatches = 0

    print(
        f"Comparing fields for intermediate image index: {data_ase.get('image_index_intermediate', 'N/A')}"
    )

    for key in sorted(set(outputs_ase) | set(outputs_ts)):
        if key not in outputs_ts:
            print(f" - Key '{key}': Present in ASE, Missing in TorchSim")
            mismatches += 1
            continue
        if key not in outputs_ase:
            print(f" - Key '{key}': Missing in ASE, Present in TorchSim")
            mismatches += 1
            continue

        val_ase_comp = outputs_ase[key]
        val_ts_comp = outputs_ts[key]

        # Convert torch tensor from pickle to numpy/scalar for comparison
        if isinstance(val_ts_comp, torch.Tensor):
            if val_ts_comp.ndim == 0:  # Scalar tensor
                val_ts_comp = val_ts_comp.item()
            else:
                val_ts_comp = val_ts_comp.detach().cpu().numpy()  # Use detach()

        # --- Debug Print for Specific Key ---
        if key == "neb_force_before_climb_vector":
            try:
                print(
                    f" DEBUG compare [{key}]: ASE[0]={np.array(val_ase_comp)[0]}, TS[0]={np.array(val_ts_comp)[0]}"
                )
            except (IndexError, TypeError, ValueError):
                pass  # purely informational; never abort the comparison for it

        if key == "imax":
            # ASE imax is index in full list (1 to n_images-1)
            # TS imax is index in intermediates (0 to n_images-2)
            # Compare ASE imax with TS imax + 1
            try:
                ase_imax = int(val_ase_comp)
                ts_imax_plus_1 = int(val_ts_comp) + 1
            except (TypeError, ValueError) as exc:
                match = False
                difference_info = f"imax not comparable: {exc}"
            else:
                match = ase_imax == ts_imax_plus_1
                # Always (re)bind the detail text so a matching imax never
                # prints the previous key's message.
                difference_info = (
                    "" if match else f"ASE imax={ase_imax}, TS imax(adj)={ts_imax_plus_1}"
                )
        else:
            match, difference_info = _values_match(val_ase_comp, val_ts_comp, rtol, atol)

        status = "Match" if match else "DIFFER"
        print(f" - Key '{key:<30}': {status} {difference_info}")
        if not match:
            mismatches += 1

    if mismatches == 0:
        print("\nAll compared output fields match.")
    else:
        print(f"\nFound {mismatches} mismatch(es) in output fields.")
    print("--- End Comparison --- ")


# -------------------------------------------------


def _describe_entry(value):
    """Return the ``Type=..., Shape=..., Dtype=...`` summary for *value*."""
    # shape/dtype exist on arrays and tensors; everything else reports "N/A"
    val_shape = getattr(value, "shape", "N/A")
    val_dtype = getattr(value, "dtype", "N/A")
    return f"Type={type(value)}, Shape={val_shape}, Dtype={val_dtype}"


# --- Add Function to Print Pickle Structure ---
def print_pickle_structure(filename="torchsim_step0_debug.pkl"):
    """Print the keys of a pickled dict with type/shape/dtype of each value.

    Nested dictionaries are expanded one level deep.

    Args:
        filename: Path of the pickle file to inspect.

    Returns:
        None. A missing or unreadable file is reported on stdout.
    """
    print(f"\n--- Structure of Pickle File: {filename} --- ")
    try:
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only use on locally generated debug dumps.
        with open(filename, "rb") as f:
            data = pickle.load(f)
    except FileNotFoundError:
        print(f"Error: File not found: {filename}")
        return
    except Exception as e:
        print(f"Error loading pickle file: {e}")
        return

    if not isinstance(data, dict):
        print(f"Loaded data is not a dictionary (Type: {type(data)})")
        return

    print(f"Keys: {list(data.keys())}")
    for key, value in data.items():
        if isinstance(value, dict):
            print(f" {key}:")
            for subkey, subvalue in value.items():
                print(f" - {subkey:<30}: {_describe_entry(subvalue)}")
        else:
            print(f" {key:<32}: {_describe_entry(value)}")
    print("--- End Pickle Structure --- ")
+ max_steps=100, # Keep increased steps for now + fmax=0.05, +) +print("Finished torch-sim NEB optimization.") + +# Check if it converged and plot results +results = ts_mace_model( + dict( + positions=final_path_gd.positions, + cell=final_path_gd.cell, + atomic_numbers=final_path_gd.atomic_numbers, + system_idx=final_path_gd.system_idx, + pbc=True, + ) +) + +energies = results["energy"].tolist() + +# Including the energies from the ASE NEB calculation for comparison +# ase_energies = [0.0, 0.154541015625, 0.6151123046875, 0.8592529296875, 0.8148193359375, 0.5965576171875, 0.47705078125] + +ase_neb_calc = ase_neb(relaxed_start_atoms, relaxed_end_atoms, nimages=5) +ase_energies = [image.get_potential_energy() for image in ase_neb_calc.images] +scaled_ase_energies = [e - ase_energies[0] for e in ase_energies] + + +scaled_energies = [e - energies[0] for e in energies] + +print(scaled_energies) +torch_sim_barrier = max(scaled_energies) - scaled_energies[0] +ase_barrier = max(scaled_ase_energies) - scaled_ase_energies[0] + +# Create normalized reaction coordinates (0 to 1) for both datasets +torch_sim_coords = np.linspace(0, 1, len(scaled_energies)) +ase_coords = np.linspace(0, 1, len(scaled_ase_energies)) + +# Create a common x-axis with 100 points for smoother plotting +common_coords = np.linspace(0, 1, 100) + +# Interpolate both energy profiles to the common coordinate system +torch_sim_interp = np.interp(common_coords, torch_sim_coords, scaled_energies) +ase_interp = np.interp(common_coords, ase_coords, scaled_ase_energies) + +# --- Print Pickle Structure to Verify --- +# print_pickle_structure() +# ------------------------------------- + +# --- Compare Step 0 Debug Outputs for compute_tangent at step 0--- +# compare_step0_outputs() # Use the updated function name +# ------------------------------------ + + +# --- Plot the energy profiles --- +plt.plot(common_coords, torch_sim_interp, label="torch-sim") +plt.plot(common_coords, ase_interp, label="ASE") 
# --- Function to Inspect HDF5 File Structure ---
def inspect_hdf5(filename):
    """Print every group and dataset in an HDF5 file with its attributes.

    Args:
        filename: Path of the HDF5 file to walk.

    Returns:
        None. A missing or unreadable file is reported on stdout.
    """
    print(f"\n--- Inspecting HDF5 File: {filename} ---")

    def print_attrs(name, obj):
        # Must return None: a non-None return value stops h5py's visititems.
        print(f" Path: /{name}")
        if isinstance(obj, h5py.Dataset):
            print(" Type: Dataset")
            print(f" Shape: {obj.shape}")
            print(f" Dtype: {obj.dtype}")
        elif isinstance(obj, h5py.Group):
            print(" Type: Group")
        print(f" Attributes: {dict(obj.attrs)}")

    try:
        with h5py.File(filename, "r") as f:
            f.visititems(print_attrs)
    except FileNotFoundError:
        print(f"Error: File not found: {filename}")
    except Exception as e:
        print(f"Error inspecting HDF5 file: {e}")
    print("--- End HDF5 Inspection ---")


# ----------------------------------------------


# --- Analyze Optimizer Convergence ---
def analyze_convergence(ts_traj_file, ase_fmax_csv_file):
    """Plot max force component per step for torch-sim against ASE's fmax log.

    Args:
        ts_traj_file: torch-sim HDF5 trajectory; ``/data/neb_forces`` and
            ``/data/image_indices`` are read from it.
        ase_fmax_csv_file: Tab-delimited text file with one header row; the
            second column is read as ASE's fmax per step.

    Returns:
        None. Either source may be missing or unreadable; the problem is
        printed and whatever data could be read is still plotted.
    """
    print("\n--- Analyzing Optimizer Convergence ---")
    max_force_ts = []
    max_force_ase = []

    # Analyze torch-sim trajectory
    try:
        with h5py.File(ts_traj_file, "r") as f:
            if "data/neb_forces" not in f or "data/image_indices" not in f:
                raise ValueError(
                    "HDF5 file missing '/data/neb_forces' or '/data/image_indices' datasets."
                )

            # Data is under /data group, steps are the first dimension
            neb_forces_dset = f["/data/neb_forces"]
            image_indices_dset = f["/data/image_indices"]

            n_steps = neb_forces_dset.shape[0]
            # Read static image indices (take the first slice)
            image_indices = image_indices_dset[0, :]

            # Infer dimensions
            n_images_total = len(np.unique(image_indices))
            n_atoms_total = len(image_indices)
            if neb_forces_dset.shape[1] != n_atoms_total:
                raise ValueError(
                    f"Mismatch between image_indices length ({n_atoms_total}) and neb_forces second dimension ({neb_forces_dset.shape[1]})"
                )

            n_atoms_per_image = n_atoms_total // n_images_total
            print(
                f"TorchSim Traj: {n_steps} steps, {n_images_total} total images, {n_atoms_per_image} atoms/image."
            )

            # Atoms of intermediate images only (index 1 to n_images_total - 2).
            # The indices are static, so the mask is built once, not per step.
            intermediate_mask = (image_indices > 0) & (
                image_indices < n_images_total - 1
            )

            for step in range(n_steps):
                forces_intermediate = neb_forces_dset[step, :, :][intermediate_mask]
                if forces_intermediate.size > 0:
                    max_force_ts.append(float(np.max(np.abs(forces_intermediate))))
                else:
                    max_force_ts.append(0.0)  # no intermediate atoms in this file

    except Exception as e:
        print(f"Error reading torch-sim trajectory {ts_traj_file}: {e}")

    # Read ASE fmax data from CSV
    try:
        # Read the 2nd column (index 1), tab delimited, skipping the header row
        ase_data = np.loadtxt(ase_fmax_csv_file, delimiter="\t", skiprows=1, usecols=(1,))
        # loadtxt returns a 0-d array for a single data row; atleast_1d keeps
        # the result a list so len()/plotting below keep working.
        max_force_ase = np.atleast_1d(ase_data).tolist()
        print(f"Read {len(max_force_ase)} fmax values from {ase_fmax_csv_file}")
    except Exception as e:
        print(f"Error reading ASE fmax CSV file {ase_fmax_csv_file}: {e}")

    # Plotting
    if not (max_force_ts or max_force_ase):
        print("No force data extracted to plot convergence.")
        return

    plt.figure()
    if max_force_ts:
        plt.plot(
            range(len(max_force_ts)),
            max_force_ts,
            label="torch-sim (ase_fire)",
            marker=".",
        )
    if max_force_ase:
        plt.plot(
            range(len(max_force_ase)), max_force_ase, label="ASE (FIRE)", marker="."
        )
    plt.xlabel("Optimization Step")
    plt.ylabel("Max Abs Force Component (eV/Ang)")
    plt.title("Optimizer Convergence Comparison")
    plt.legend()
    plt.grid(True)
    plt.yscale("log")  # Log scale often helpful for forces
    plt.show()
+""" + +import inspect +import logging +import os # Import os for fsync +import pickle # Import pickle +from collections.abc import Callable +from contextlib import nullcontext +from dataclasses import dataclass, field +from typing import Any, Literal + +import torch + +from torch_sim.models.interface import ModelInterface +from torch_sim.optimizers import ( + CellFireState, + FireState, + OptimState, + fire_init, + fire_step, + gradient_descent_init, + gradient_descent_step, +) +from torch_sim.optimizers.cell_filters import CellFilter +from torch_sim.state import ( + SimState, + concatenate_states, + initialize_state, +) +from torch_sim.trajectory import TorchSimTrajectory +from torch_sim.transforms import minimum_image_displacement +from torch_sim.typing import StateLike + + +logger = logging.getLogger(__name__) + +# Add epsilon for numerical stability +_EPS = torch.finfo(torch.float64).eps + + +def _extract_kwargs_from_params( + params: dict[str, Any], func: Callable[..., Any], exclude: set[str] | None = None +) -> dict[str, Any]: + """Extract kwargs from params dict that match function signature. 
+ + Args: + params: Dictionary of parameters to filter + func: Function to extract parameters for + exclude: Set of parameter names to exclude (e.g., 'state', 'model') + + Returns: + Dictionary of parameters that match the function signature + """ + exclude = exclude or {"state", "model"} + sig = inspect.signature(func) + return { + k: v + for k, v in params.items() + if k in sig.parameters and k not in exclude + } + + +@dataclass +class _OptimizerConfig: + """Configuration for an optimizer type.""" + + init_fn: Callable[..., Any] + step_fn: Callable[..., Any] + state_type: type + init_kwargs_modifier: Callable[[dict[str, Any]], dict[str, Any]] | None = None + step_kwargs_modifier: Callable[[dict[str, Any]], dict[str, Any]] | None = None + + +# Registry of optimizer configurations +_OPTIMIZER_REGISTRY: dict[str, _OptimizerConfig] = { + "fire": _OptimizerConfig( + init_fn=fire_init, + step_fn=fire_step, + state_type=FireState, + ), + "frechet_cell_fire": _OptimizerConfig( + init_fn=fire_init, + step_fn=fire_step, + state_type=CellFireState, + init_kwargs_modifier=lambda kwargs: {**kwargs, "cell_filter": CellFilter.frechet}, + ), + "gd": _OptimizerConfig( + init_fn=gradient_descent_init, + step_fn=gradient_descent_step, + state_type=OptimState, + step_kwargs_modifier=lambda kwargs: ( + kwargs if "pos_lr" in kwargs else {**kwargs, "pos_lr": kwargs.get("lr", 0.01)} + ), + ), + "ase_fire": _OptimizerConfig( + init_fn=fire_init, + step_fn=fire_step, + state_type=FireState, + init_kwargs_modifier=lambda kwargs: ( + kwargs if "fire_flavor" in kwargs else {**kwargs, "fire_flavor": "ase_fire"} + ), + step_kwargs_modifier=lambda kwargs: ( + kwargs if "fire_flavor" in kwargs else {**kwargs, "fire_flavor": "ase_fire"} + ), + ), +} + + +@dataclass +class NEB: + """Nudged Elastic Band (NEB) optimizer. + + Finds the minimum energy path (MEP) between an initial and final state using + the NEB algorithm. 
def __post_init__(self) -> None:
    """Fill in defaults and bind the optimizer named by ``optimizer_type``.

    ``device`` and ``dtype`` fall back to the model's values when left as
    ``None``. The optimizer's init/step callables and state type are looked up
    in ``_OPTIMIZER_REGISTRY``. ``optimizer_params`` is then filtered down to
    the keyword arguments each callable accepts (never ``state`` or ``model``),
    and the registry entry's optional modifiers are applied to the result.

    Raises:
        ValueError: If ``optimizer_type`` has no entry in the registry.
    """
    if self.device is None:
        self.device = self.model.device
    if self.dtype is None:
        self.dtype = self.model.dtype

    # Holder for the step-0 debug payload; empty until something is stored.
    self._step0_debug_output = None

    config = _OPTIMIZER_REGISTRY.get(self.optimizer_type)
    if config is None:
        raise ValueError(
            f"Unsupported optimizer_type: {self.optimizer_type}. "
            f"Supported types: {list(_OPTIMIZER_REGISTRY)}"
        )

    self._init_fn, self._step_fn = config.init_fn, config.step_fn
    self._OptimizerStateType = config.state_type

    # `state` and `model` are supplied positionally at call time, so they must
    # never be forwarded from the user's parameter dict.
    positional = {"state", "model"}
    init_kwargs = _extract_kwargs_from_params(
        self.optimizer_params, config.init_fn, exclude=positional
    )
    step_kwargs = _extract_kwargs_from_params(
        self.optimizer_params, config.step_fn, exclude=positional
    )

    # Registry entries may adjust the filtered kwargs (e.g. inject defaults).
    adjust_init = config.init_kwargs_modifier
    adjust_step = config.step_kwargs_modifier
    if adjust_init:
        init_kwargs = adjust_init(init_kwargs)
    if adjust_step:
        step_kwargs = adjust_step(step_kwargs)

    self._init_kwargs, self._step_kwargs = init_kwargs, step_kwargs
+ """ + # --- Input Validation --- + if initial_state.n_systems != 1 or final_state.n_systems != 1: + raise ValueError("Initial and final states must be single-system SimStates.") + if initial_state.n_atoms != final_state.n_atoms: + raise ValueError( + f"Initial ({initial_state.n_atoms}) and final ({final_state.n_atoms}) " + "states must have the same number of atoms." + ) + if not torch.equal(initial_state.atomic_numbers, final_state.atomic_numbers): + # Comparing floats might be tricky, but atomic numbers should be exact + raise ValueError("Initial and final states must have the same atom types.") + # Compare PBC values properly (can be bool, list, or tensor) + pbc_match = False + if isinstance(initial_state.pbc, torch.Tensor) and isinstance(final_state.pbc, torch.Tensor): + pbc_match = torch.equal(initial_state.pbc, final_state.pbc) + elif isinstance(initial_state.pbc, torch.Tensor) or isinstance(final_state.pbc, torch.Tensor): + # One is tensor, one is not - convert both to tensors for comparison + initial_pbc_tensor = ( + initial_state.pbc + if isinstance(initial_state.pbc, torch.Tensor) + else torch.tensor(initial_state.pbc, device=initial_state.device) + ) + final_pbc_tensor = ( + final_state.pbc + if isinstance(final_state.pbc, torch.Tensor) + else torch.tensor(final_state.pbc, device=final_state.device) + ) + pbc_match = torch.equal(initial_pbc_tensor, final_pbc_tensor) + else: + # Both are bools or lists + pbc_match = initial_state.pbc == final_state.pbc + if not pbc_match: + # TODO: Could potentially support different PBCs, but complex for NEB. + raise ValueError("Initial and final states must have the same PBC setting.") + # For fixed-cell NEB, cells should ideally be identical. Warn if not? 
+ # if not torch.allclose(initial_state.cell, final_state.cell): + + n_atoms_per_image = initial_state.n_atoms + + # --- Interpolation --- + initial_pos = initial_state.positions + final_pos = final_state.positions + + # Calculate displacement using Minimum Image Convention + displacement = minimum_image_displacement( + dr=final_pos - initial_pos, + cell=initial_state.cell[0], # Use cell from initial state + pbc=initial_state.pbc, + ) + # Ensure shape is correct [n_atoms, 3] + displacement = displacement.reshape(n_atoms_per_image, 3) + + # Generate interpolation factors (e.g., for n_images=3: 0.25, 0.5, 0.75) + factors = torch.linspace( + 0.0, 1.0, steps=self.n_images + 2, device=self.device, dtype=self.dtype + )[1:-1] # Exclude 0.0 and 1.0 # Ensure dtype + factors = factors.view(-1, 1, 1) # Shape: [n_images, 1, 1] + + # Calculate interpolated positions: initial + factor * displacement + # Broadcasting: [N_atoms, 3] + [N_images, 1, 1] * [N_atoms, 3] -> [N_images, N_atoms, 3] + interpolated_pos = initial_pos.unsqueeze(0) + factors * displacement.unsqueeze(0) + + # Reshape to [n_images * n_atoms_per_image, 3] + all_positions = interpolated_pos.reshape(-1, 3) + + # --- Create Batched State --- + # Repeat other attributes for each image + all_atomic_numbers = initial_state.atomic_numbers.repeat(self.n_images) + all_masses = initial_state.masses.repeat(self.n_images) + # Use initial state's cell, repeated for each image + all_cells = initial_state.cell.repeat( + self.n_images, 1, 1 + ) # Shape: [n_images, 3, 3] + + # Create system_idx tensor: [0, 0, ..., 1, 1, ..., n_images-1, ...] 
def _compute_tangents(
    self,
    all_pos: torch.Tensor,  # Shape: [n_total_images, n_atoms, 3]
    all_energies: torch.Tensor,  # Shape: [n_total_images]
    cell: torch.Tensor,  # Shape: [3, 3]
    *,
    pbc: bool,
) -> torch.Tensor:
    """Return unit tangents for the intermediate images of the band.

    Uses the energy-based selection of the improved tangent estimate
    (Henkelman & Jónsson, 2000). With ``d+ = R[i+1] - R[i]`` and
    ``d- = R[i] - R[i-1]``, both passed through
    ``minimum_image_displacement``:

    * energy rising on both sides (``E[i+1] > E[i] > E[i-1]``): tangent is ``d+``
    * energy falling on both sides (``E[i+1] < E[i] < E[i-1]``): tangent is ``d-``
    * otherwise: a mix of ``d+`` and ``d-`` weighted by the larger and smaller
      absolute neighbouring energy difference. ``d+`` gets the larger weight
      when ``E[i+1] > E[i-1]``, ``d-`` otherwise.

    Args:
        all_pos: Positions of every image (endpoints included), shape
            [n_total_images, n_atoms, 3].
        all_energies: Energy of every image, shape [n_total_images].
        cell: Cell passed to ``minimum_image_displacement`` for all segments.
        pbc: Periodic-boundary flag passed to ``minimum_image_displacement``.

    Returns:
        Tensor of shape [n_total_images - 2, n_atoms, 3] with dtype
        ``self.dtype``. Each image's tangent is scaled to unit norm over all
        atoms. If the raw tangent's norm does not exceed ``_EPS`` the entry is
        left at zero.
    """
    n_total, n_atoms, _ = all_pos.shape
    n_inner = n_total - 2

    unit_tangents = torch.zeros(
        (n_inner, n_atoms, 3), device=all_pos.device, dtype=self.dtype
    )

    # segment[k] = R[k+1] - R[k] under the minimum image convention
    segment = minimum_image_displacement(
        dr=all_pos[1:] - all_pos[:-1], cell=cell, pbc=pbc
    ).reshape(n_total - 1, n_atoms, 3)
    # rise[k] = E[k+1] - E[k]
    rise = all_energies[1:] - all_energies[:-1]

    # Pair each inner image with the segment/energy step behind and ahead of it.
    neighbours = zip(segment[:-1], segment[1:], rise[:-1], rise[1:])
    for slot, (step_back, step_fwd, rise_back, rise_fwd) in enumerate(neighbours):
        uphill = bool(rise_fwd > 0) and bool(rise_back > 0)
        downhill = bool(rise_fwd < 0) and bool(rise_back < 0)

        if uphill:
            raw = step_fwd
        elif downhill:
            raw = step_back
        else:
            # Local extremum (or flat step): blend both neighbours, weighting
            # by the absolute energy differences.
            mag_fwd, mag_back = rise_fwd.abs(), rise_back.abs()
            heavy = torch.maximum(mag_fwd, mag_back)
            light = torch.minimum(mag_fwd, mag_back)
            if (rise_fwd + rise_back) > 0:  # E[i+1] > E[i-1]
                raw = step_fwd * heavy + step_back * light
            else:  # E[i+1] <= E[i-1]
                raw = step_fwd * light + step_back * heavy

        length = torch.linalg.norm(raw)
        if length > _EPS:
            unit_tangents[slot] = raw / length
        # otherwise the slot keeps its zero initialisation

    return unit_tangents
+ initial_energy (torch.Tensor): Potential energy of the initial state + (scalar tensor). + final_energy (torch.Tensor): Potential energy of the final state + (scalar tensor). + step (int): Current optimization step number (used for climbing image delay). + + Returns: + torch.Tensor: Calculated NEB forces for the intermediate images, ready to + be passed to the optimizer, shape [n_movable_atoms, 3]. + """ + n_total_images = path_state.n_systems + n_intermediate_images = n_total_images - 2 + assert n_intermediate_images == self.n_images + n_atoms_per_image = path_state.n_atoms // n_total_images + + # --- Reshape inputs --- + # Positions for all images: [n_total_images, n_atoms, 3] + all_pos = path_state.positions.reshape(n_total_images, n_atoms_per_image, 3) + # True forces for intermediate images: [n_images, n_atoms, 3] + true_forces_reshaped = true_forces.reshape( + n_intermediate_images, n_atoms_per_image, 3 + ) + # Cell vectors (assuming fixed cell for now, take from first batch) + cell = path_state.cell[0] # Shape [3, 3] + # Convert pbc to bool if it's a tensor (for _compute_tangents) + if isinstance(path_state.pbc, torch.Tensor): + pbc_bool: bool = bool(path_state.pbc.any().item()) # True if any dimension has PBC + elif isinstance(path_state.pbc, bool): + pbc_bool = path_state.pbc + elif isinstance(path_state.pbc, list): + pbc_bool = bool(any(path_state.pbc)) + else: + pbc_bool = True + pbc = path_state.pbc # Keep original for minimum_image_displacement + + # --- Get Energies for Tangent Calculation --- + all_energies = torch.cat( + [ + initial_energy.unsqueeze(0), + true_energies, + final_energy.unsqueeze(0), + ] + ) + + # --- Setup for Debugging Step 0 --- + log_step_0 = step == 0 + debug_img_idx = ( + n_intermediate_images // 2 + ) # Index within intermediates (0 to n_images-1) + debug_img_idx_all = debug_img_idx + 1 # Index within all_pos (0 to n_images+1) + debug_data_ts = {} # Initialize debug dict + + if log_step_0: + debug_data_ts = { + "step": 0, + 
"image_index_intermediate": debug_img_idx, + "image_index_absolute": debug_img_idx_all, + "inputs": {}, + "outputs": {}, + "error": None, + } + debug_data_ts["inputs"]["energies_all"] = all_energies # Monty handles tensor + debug_data_ts["inputs"]["cell"] = cell + debug_data_ts["inputs"]["pbc"] = pbc_bool # Store Python bool + debug_data_ts["inputs"]["positions_image_minus_1"] = all_pos[ + debug_img_idx_all - 1 + ] + debug_data_ts["inputs"]["positions_image"] = all_pos[debug_img_idx_all] + debug_data_ts["inputs"]["positions_image_plus_1"] = all_pos[ + debug_img_idx_all + 1 + ] + debug_data_ts["inputs"]["true_forces_image"] = true_forces_reshaped[ + debug_img_idx + ] + + # --- Calculate Tangents (tau) using the improved method --- + # tangents shape: [n_images, n_atoms, 3] + tangents = self._compute_tangents(all_pos, all_energies, cell, pbc=pbc_bool) + logger.debug( + f" Step {step}: Tangent norms per image: {torch.linalg.norm(tangents, dim=(-1, -2))}" + ) + if log_step_0: + # Note: ASE tangent might not be normalized if norm is ~0, TS tangent should be. + tangent_img = tangents[debug_img_idx] + tangent_norm_img = torch.linalg.norm(tangent_img) + debug_data_ts["outputs"]["tangent_vector"] = tangent_img + debug_data_ts["outputs"]["tangent_norm"] = tangent_norm_img + + # --- Calculate Displacements for Spring Force --- + # Recalculate here or reuse from _compute_tangents if efficient + displacements = minimum_image_displacement( + dr=all_pos[1:] - all_pos[:-1], cell=cell, pbc=pbc + ) + displacements = displacements.reshape(n_total_images - 1, n_atoms_per_image, 3) + if log_step_0: + # Save displacements relevant to the middle image's tangent/spring calculation + debug_data_ts["outputs"]["mic_displacement_1"] = displacements[ + debug_img_idx_all - 1 + ] # R(i) - R(i-1) + debug_data_ts["outputs"]["mic_displacement_2"] = displacements[ + debug_img_idx_all + ] # R(i+1) - R(i) + + # --- Calculate NEB Force Components --- + + # 1. 
Perpendicular component of true force + # F_perp = F_true - (F_true . tau) * tau + # Dot product (sum over atoms and dims): [n_images] + F_true_dot_tau = (true_forces_reshaped * tangents).sum(dim=(-1, -2), keepdim=True) + F_perp = true_forces_reshaped - F_true_dot_tau * tangents + logger.debug( + f" Step {step}: F_perp norms per image: {torch.linalg.norm(F_perp, dim=(-1, -2))}" + ) + if log_step_0: + f_perp_img = F_perp[debug_img_idx] + f_perp_norm_img = torch.linalg.norm(f_perp_img) + debug_data_ts["outputs"]["f_true_dot_tau"] = F_true_dot_tau[ + debug_img_idx + ].item() # scalar + debug_data_ts["outputs"]["f_perp_vector"] = f_perp_img + debug_data_ts["outputs"]["f_perp_norm"] = f_perp_norm_img + + # 2. Parallel component of spring force + # F_spring_par = k * (|R_{i+1}-R_i| - |R_i-R_{i-1}|) * tau_i + # Segment lengths (scalar magnitude per segment): [n_images+1] + segment_lengths = torch.linalg.norm( + displacements, dim=(-1, -2) + ) # Cleaner way [n_total_images-1] + # Spring force magnitude (scalar per intermediate image): [n_images] + F_spring_mag = self.spring_constant * (segment_lengths[1:] - segment_lengths[:-1]) + # Project onto tangent: [n_images, 1, 1] -> [n_images, n_atoms, 3] + F_spring_par = F_spring_mag.view(-1, 1, 1) * tangents + logger.debug( + f" Step {step}: F_spring_par norms per image: {torch.linalg.norm(F_spring_par, dim=(-1, -2))}" + ) + if log_step_0: + f_spring_par_img = F_spring_par[debug_img_idx] + f_spring_par_norm_img = torch.linalg.norm(f_spring_par_img) + debug_data_ts["outputs"]["segment_lengths"] = segment_lengths # Full list + debug_data_ts["outputs"]["spring_force_magnitude_term"] = F_spring_mag[ + debug_img_idx + ].item() # scalar + debug_data_ts["outputs"]["f_spring_par_vector"] = f_spring_par_img + debug_data_ts["outputs"]["f_spring_par_norm"] = f_spring_par_norm_img + + # --- Combine Components for NEB Force --- + neb_forces = F_perp + F_spring_par + if log_step_0: + # --- Direct Debug Logs for Step 0 --- + f_perp_img = 
F_perp[debug_img_idx] + f_spring_par_img = F_spring_par[debug_img_idx] + neb_force_img = neb_forces[debug_img_idx] + logger.debug(" --- DIRECT DEBUG LOG (TORCH-SIM STEP 0) ---") + logger.debug(f" f_perp_norm: {torch.linalg.norm(f_perp_img)}") + logger.debug(f" f_perp_vec[0]: {f_perp_img[0]}") + # segment_lengths shape: [n_total_images - 1] + # segment_lengths[debug_img_idx] corresponds to spring2 length + # segment_lengths[debug_img_idx-1] corresponds to spring1 length + len1 = segment_lengths[debug_img_idx - 1] + len2 = segment_lengths[debug_img_idx] + len_diff = len2 - len1 + logger.debug( + f" spring1_length (R[{debug_img_idx_all}]-R[{debug_img_idx_all - 1}]): {len1}" + ) + logger.debug( + f" spring2_length (R[{debug_img_idx_all + 1}]-R[{debug_img_idx_all}]): {len2}" + ) + logger.debug(f" Length Diff (len2 - len1): {len_diff}") + logger.debug(f" f_spring_par_norm: {torch.linalg.norm(f_spring_par_img)}") + logger.debug(f" f_spring_par_vec[0]: {f_spring_par_img[0]}") + logger.debug( + f" neb_force_before_climb_norm: {torch.linalg.norm(neb_force_img)}" + ) + # -------------------------------------- + # Store a *copy* detached from the graph to prevent modification by climbing image logic + debug_data_ts["outputs"]["neb_force_before_climb_vector"] = ( + neb_forces[debug_img_idx].clone().detach() + ) + debug_data_ts["outputs"]["neb_force_before_climb_norm"] = torch.linalg.norm( + neb_forces[debug_img_idx] + ) # Norm calculation is fine + + # --- Log the vector right before it would be saved --- + logger.debug( + f" Value assigned to debug_data[neb_force_before_climb_vector][0]: {neb_forces[debug_img_idx][0]}" + ) + # ----------------------------------------------------- + + # --- Handle Climbing Image --- + climbing_delay_steps = 10 # Example value + if ( + self.use_climbing_image and n_intermediate_images > 0 + ): # and step >= climbing_delay_steps: # Check step number - REMOVED DELAY + # Find index of highest energy image among intermediates + climbing_image_idx = 
torch.argmax( + true_energies + ).item() # Index from 0 to n_images-1 + # Calculate the climbing force: F_climb = F_true - 2 * (F_true . tau) * tau + F_climb = true_forces_reshaped[climbing_image_idx] - ( + 2 * F_true_dot_tau[climbing_image_idx] * tangents[climbing_image_idx] + ) + # Replace the NEB force for the climbing image with F_climb + # This overwrites the spring force component for this image, as required. + neb_forces[climbing_image_idx] = F_climb + logger.debug( + f" Step {step}: Climbing image index: {climbing_image_idx}, " + f"Climbing Force Norm: {torch.linalg.norm(F_climb)}" + ) + if log_step_0: + debug_data_ts["outputs"]["is_climbing_image"] = ( + climbing_image_idx == debug_img_idx + ) + debug_data_ts["outputs"]["imax"] = climbing_image_idx + debug_data_ts["outputs"]["climbing_force_vector"] = neb_forces[ + climbing_image_idx + ] + debug_data_ts["outputs"]["climbing_force_norm"] = torch.linalg.norm( + neb_forces[climbing_image_idx] + ) + + # --- Logging (Optional) --- + # logger.debug( + # " Max True Force Mag: " + # f"{torch.linalg.norm(true_forces_reshaped, dim=(-1,-2)).max().item():.4f}" + # ) + # logger.debug( + # " Max F_perp Mag: " + # f"{torch.linalg.norm(F_perp, dim=(-1,-2)).max().item():.4f}" + # ) + # logger.debug( + # " Max F_spring_par Mag: " + # f"{torch.linalg.norm(F_spring_par, dim=(-1,-2)).max().item():.4f}" + # ) + # logger.debug( + # " Max NEB Force Mag: " + # f"{torch.linalg.norm(neb_forces, dim=(-1,-2)).max().item():.4f}" + # ) + logger.debug( + f" Step {step}: NEB force norms per image: {torch.linalg.norm(neb_forces, dim=(-1, -2))}" + ) + logger.debug(f" Step {step}: Intermediate energies: {true_energies}") + if log_step_0 and not ( + self.use_climbing_image and climbing_image_idx == debug_img_idx + ): # Avoid logging twice if climbing image was logged + # If not the climbing image, the final force is the one before modification + pass # Already stored neb_force_before_climb + + if log_step_0: + 
def run(
    self,
    initial_system: StateLike,
    final_system: StateLike,
    max_steps: int = 100,
    fmax: float = 0.05,
    # TODO: add convergence criteria, batching options, output frequency etc.
) -> SimState:
    """Run the Nudged Elastic Band optimization.

    The endpoints are evaluated once with the model, an interpolated band is
    built by ``_interpolate_path`` and handed to the configured optimizer.
    On every step the optimizer state's forces are replaced by the NEB forces
    from ``_calculate_neb_forces`` before the optimizer step is taken.

    Convergence is tested after each optimizer step, using the NEB forces
    computed before that step: the loop stops once the largest per-atom NEB
    force norm over the intermediate images is below ``fmax``.

    Side effects:
        * If ``trajectory_filename`` is a non-empty string, positions, NEB
          forces, true forces and energies of the full band are written each
          step (static arrays only on step 0).
        * If step-0 debug data was produced, it is pickled to
          ``torchsim_step0_debug.pkl`` in the current working directory.
          NOTE(review): this is debug output emitted by library code; consider
          making it opt-in.

    Args:
        initial_system (StateLike): Starting configuration, converted with
            ``initialize_state``.
        final_system (StateLike): Ending configuration, converted the same way.
        max_steps (int): Maximum number of optimizer steps.
        fmax (float): Threshold on the largest per-atom NEB force norm.

    Returns:
        SimState: Initial, intermediate and final images concatenated into a
        single state.

    Raises:
        ValueError: If either endpoint is not a single-system state.
    """
    logger.info("Starting NEB optimization")

    # Reset step 0 debug output storage for this run
    self._step0_debug_output = None

    # 1. Initialize initial and final states
    initial_state = initialize_state(initial_system, self.device, self.dtype)
    final_state = initialize_state(final_system, self.device, self.dtype)
    if initial_state.n_systems != 1:
        raise ValueError("Initial state must be a single-system SimState")
    if final_state.n_systems != 1:
        raise ValueError("Final state must be a single-system SimState")

    # 1b. Endpoint energies feed the tangent calculation; the endpoint forces
    # are only used for the trajectory output.
    logger.info("Calculating endpoint energies...")
    endpoint_states = concatenate_states([initial_state, final_state])
    endpoint_output = self.model(endpoint_states)
    initial_energy = endpoint_output["energy"][0]
    final_energy = endpoint_output["energy"][1]
    logger.info(
        f"Initial Energy: {initial_energy:.4f}, Final Energy: {final_energy:.4f}"
    )

    # 2. Create initial interpolated path (movable images only)
    interpolated_images = self._interpolate_path(initial_state, final_state)

    # 3. Initialize optimizer state for the movable images
    opt_state = self._init_fn(interpolated_images, self.model, **self._init_kwargs)

    # 4. Optimization loop
    logger.info(f"Running NEB for max {max_steps} steps or fmax < {fmax} eV/Ang.")

    # nullcontext() yields None, so `traj` is None when no file was requested.
    traj_context = (
        TorchSimTrajectory(self.trajectory_filename, mode="w")
        if self.trajectory_filename
        else nullcontext()
    )

    with traj_context as traj:
        for step in range(max_steps):
            # a. Current true forces and energies of the movable images
            true_forces = opt_state.forces
            true_energies = opt_state.energy

            # b. NEB forces are computed on the full band (endpoints included)
            full_path_state_calc = concatenate_states(
                [initial_state, opt_state, final_state]
            )
            # Keep a copy of the true forces before they are overwritten below
            true_forces_for_traj = opt_state.forces.clone()

            neb_forces, step0_debug_data = self._calculate_neb_forces(
                full_path_state_calc,
                true_forces,
                true_energies,
                initial_energy,
                final_energy,
                step=step,
            )

            # c. The optimizer must follow the NEB forces, not the true ones
            opt_state.forces = neb_forces
            neb_forces_for_traj = neb_forces.clone()

            # d. Perform optimization step
            opt_state = self._step_fn(opt_state, self.model, **self._step_kwargs)

            if step == 0 and step0_debug_data:
                logger.info("Storing Step 0 TorchSim debug data.")
                self._step0_debug_output = step0_debug_data

            # e. Write to trajectory (if enabled). Test the context value
            # itself so the guard agrees with how traj_context was chosen:
            # an empty-string filename selects nullcontext() and must not
            # reach traj.write_arrays.
            if traj is not None:
                current_full_path = concatenate_states(
                    [initial_state, opt_state, final_state]
                )
                data_to_write = {
                    "positions": current_full_path.positions,
                    # Endpoints carry no NEB force; pad them with zeros.
                    "neb_forces": torch.cat(
                        [
                            torch.zeros_like(initial_state.positions),
                            neb_forces_for_traj,
                            torch.zeros_like(final_state.positions),
                        ],
                        dim=0,
                    ),
                    # Endpoint forces come from the one-off endpoint evaluation.
                    "true_forces": torch.cat(
                        [
                            endpoint_output["forces"][: initial_state.n_atoms],
                            true_forces_for_traj,
                            endpoint_output["forces"][initial_state.n_atoms :],
                        ],
                        dim=0,
                    ),
                    "energies": torch.cat(
                        [
                            initial_energy.unsqueeze(0),
                            opt_state.energy,  # Energies *after* the step
                            final_energy.unsqueeze(0),
                        ],
                        dim=0,
                    ),
                }
                if step == 0:  # Write static data only on the first step
                    data_to_write["cell"] = current_full_path.cell
                    data_to_write["atomic_numbers"] = current_full_path.atomic_numbers
                    data_to_write["masses"] = current_full_path.masses
                    # as_tensor accepts a Python value or an existing tensor
                    # without triggering torch's copy-construct warning.
                    data_to_write["pbc"] = torch.as_tensor(current_full_path.pbc)
                    # Maps every atom to the image it belongs to
                    data_to_write["image_indices"] = current_full_path.system_idx

                traj.write_arrays(data_to_write, steps=step)

            # f. Check convergence on the per-atom NEB force norms
            max_force_magnitude = torch.sqrt((neb_forces**2).sum(dim=-1)).max()
            max_intermediate_energy = opt_state.energy.max()
            logger.info(
                f"Step {step + 1:4d}: Max Force = {max_force_magnitude:.4f} Max Energy = {max_intermediate_energy:.4f}"
            )
            if max_force_magnitude < fmax:
                logger.info("NEB optimization converged.")
                break
        else:  # Loop finished without break
            logger.warning("NEB optimization did not converge within max_steps.")

    # 5. Dump the step-0 debug dictionary (best effort), then return the path
    if self._step0_debug_output:
        output_filename_ts = "torchsim_step0_debug.pkl"
        logger.info(
            f"Attempting to write final Step 0 TorchSim debug data to {output_filename_ts}"
        )
        try:
            with open(output_filename_ts, "wb") as f:
                pickle.dump(self._step0_debug_output, f)
                f.flush()
                os.fsync(f.fileno())
            logger.info(
                f"--- TorchSim NEB Debug Info (Step 0) saved to {output_filename_ts} ---"
            )
        except Exception as e:
            # Debug output must never abort an otherwise finished run.
            logger.error(
                f"ERROR WRITING FINAL TORCHSIM STEP 0 DEBUG PICKLE: {e}",
                exc_info=True,
            )
    else:
        logger.warning("No Step 0 TorchSim debug data was stored to write.")

    return concatenate_states([initial_state, opt_state, final_state])