Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
350 changes: 350 additions & 0 deletions torch_sim/neighbors/nbody.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,350 @@
"""Pure-PyTorch triplet and quadruplet interaction index builders.

All functions use only standard PyTorch ops (argsort, bincount, repeat_interleave,
boolean masking, etc.) so they are compatible with ``torch.jit.script``,
``torch.compile``, and AOTInductor — no ``torch_scatter`` or ``torch_sparse``.

Typical usage::

from torch_sim.neighbors import torchsim_nl
from torch_sim.neighbors.nbody import build_triplets, build_quadruplets

mapping, system_mapping, shifts_idx = torchsim_nl(
positions, cell, pbc, cutoff, system_idx
)
trip = build_triplets(mapping, n_atoms)
# trip["trip_in"], trip["trip_out"] index into edges
"""

from __future__ import annotations

import torch


def _inner_idx(sorted_idx: torch.Tensor, dim_size: int) -> torch.Tensor:
"""Local enumeration within sorted contiguous segments.

For a sorted index tensor ``[0,0,0,1,1,2,2,2,2]`` returns ``[0,1,2,0,1,0,1,2,3]``.

Args:
sorted_idx: 1-D tensor of segment ids, **must be sorted**.
dim_size: Total number of segments (>= max(sorted_idx)+1).

Returns:
1-D tensor same length as *sorted_idx* with per-segment local indices.
"""
counts = torch.bincount(sorted_idx, minlength=dim_size)
offsets = counts.cumsum(0) - counts
return (
torch.arange(sorted_idx.size(0), device=sorted_idx.device) - offsets[sorted_idx]
)


def build_triplets(
edge_index: torch.Tensor,
n_atoms: int,
) -> dict[str, torch.Tensor]:
"""Build triplet interaction indices from an edge list.

For every pair of edges ``(b→a)`` and ``(c→a)`` that share the same target
atom ``a`` with ``edge_ba ≠ edge_ca``, produces a triplet ``b→a←c``.

Uses only ops that are JIT/AOTInductor safe: ``argsort``, ``bincount``,
``repeat_interleave``, and boolean indexing.

Args:
edge_index: ``[2, n_edges]`` tensor where ``edge_index[0]`` are sources
and ``edge_index[1]`` are targets.
n_atoms: Total number of atoms (used for bincount sizing).

Returns:
Dict with keys:

- ``"trip_in"`` — edge indices of the *incoming* edge ``b→a``, shape
``[n_triplets]``.
- ``"trip_out"`` — edge indices of the *outgoing* edge ``c→a``, shape
``[n_triplets]``.
- ``"trip_out_agg"`` — per-segment local index for aggregation, shape
``[n_triplets]``.
- ``"center_atom"`` — atom index ``a`` for each triplet, shape
``[n_triplets]``.
"""
targets = edge_index[1] # target atoms
n_edges = targets.size(0)
device = targets.device

# Sort edges by target atom to get contiguous groups
order = torch.argsort(targets, stable=True)
sorted_targets = targets[order]

# Degree per atom and CSR-style offsets
deg = torch.bincount(sorted_targets, minlength=n_atoms)
offsets = torch.zeros(n_atoms + 1, dtype=torch.long, device=device)
offsets[1:] = deg.cumsum(0)

# Number of ordered triplets per atom: deg*(deg-1)
n_trip_per_atom = deg * (deg - 1)
total_triplets = int(n_trip_per_atom.sum().item())

if total_triplets == 0:
empty = torch.empty(0, dtype=torch.long, device=device)
return {
"trip_in": empty,
"trip_out": empty,
"trip_out_agg": empty,
"center_atom": empty,
}

# Atom ids that have at least 2 edges
active = deg >= 2
active_atoms = torch.where(active)[0]
active_deg = deg[active]
active_offsets = offsets[:-1][active]
active_n_trip = n_trip_per_atom[active]

# Expand: for each active atom, enumerate deg*(deg-1) triplets
atom_rep = torch.repeat_interleave(
torch.arange(active_atoms.size(0), device=device), active_n_trip
)
base_off = torch.repeat_interleave(active_offsets, active_n_trip)
d = torch.repeat_interleave(active_deg, active_n_trip)

# Local triplet index within each atom's group
local = _inner_idx(atom_rep, active_atoms.size(0))

# Map local index to (row, col) within the deg x (deg-1) grid
# row = local // (deg-1), col = local % (deg-1)
dm1 = d - 1
row = local // dm1
col = local % dm1
# Skip diagonal: if col >= row, shift col by 1
col = col + (col >= row).long()

trip_in = order[base_off + row]
trip_out = order[base_off + col]

# Center atom for each triplet
center = torch.repeat_interleave(active_atoms, active_n_trip)

# Aggregation index: local enumeration by trip_out
trip_out_agg = _inner_idx(trip_out, n_edges) if total_triplets > 0 else trip_out

return {
"trip_in": trip_in,
"trip_out": trip_out,
"trip_out_agg": trip_out_agg,
"center_atom": center,
}


def build_mixed_triplets(
edge_index_in: torch.Tensor,
edge_index_out: torch.Tensor,
n_atoms: int,
*,
to_outedge: bool = False,
cell_offsets_in: torch.Tensor | None = None,
cell_offsets_out: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Build triplet indices across two different edge sets sharing the same atoms.

For each edge in ``edge_index_out``, finds all edges in ``edge_index_in``
that share the same atom (target or source depending on *to_outedge*),
filtering self-loops via cell offsets when provided.

This is the pure-PyTorch equivalent of GemNet-OC ``get_mixed_triplets``.

Args:
edge_index_in: ``[2, n_edges_in]`` — input graph edges.
edge_index_out: ``[2, n_edges_out]`` — output graph edges.
n_atoms: Total number of atoms.
to_outedge: If True, match on the *source* atom of out-edges (``a→c``
style); otherwise match on the *target* atom (``c→a`` style).
cell_offsets_in: ``[n_edges_in, 3]`` periodic offsets for input graph.
cell_offsets_out: ``[n_edges_out, 3]`` periodic offsets for output graph.

Returns:
Dict with keys ``"trip_in"``, ``"trip_out"``, ``"trip_out_agg"``.
"""
src_in, tgt_in = edge_index_in[0], edge_index_in[1]
src_out, tgt_out = edge_index_out[0], edge_index_out[1]
n_edges_in = src_in.size(0)
n_edges_out = src_out.size(0)
device = src_in.device

# Build CSR of input edges grouped by target atom
order_in = torch.argsort(tgt_in, stable=True)
sorted_tgt_in = tgt_in[order_in]
deg_in = torch.bincount(sorted_tgt_in, minlength=n_atoms)
csr_in = torch.zeros(n_atoms + 1, dtype=torch.long, device=device)
csr_in[1:] = deg_in.cumsum(0)

# For each output edge, pick the shared atom
shared_atom = src_out if to_outedge else tgt_out

# Degree of each output edge's shared atom in the input graph
deg_per_out = deg_in[shared_atom] # [n_edges_out]

# Expand: repeat each output edge index by degree of its shared atom
trip_out = torch.repeat_interleave(
torch.arange(n_edges_out, device=device), deg_per_out
)
# For each expanded entry, the corresponding input edge
base_off = csr_in[shared_atom] # start offset into sorted input edges
base_off_exp = torch.repeat_interleave(base_off, deg_per_out)
deg_exp = torch.repeat_interleave(deg_per_out, deg_per_out)

# Local index within the group
local = _inner_idx(trip_out, n_edges_out)
trip_in = order_in[base_off_exp + local]

# Filter self-loops: atom-index check + cell offset check
if to_outedge:
idx_atom_in = src_in[trip_in]
idx_atom_out = tgt_out[trip_out]
else:
idx_atom_in = src_in[trip_in]
idx_atom_out = src_out[trip_out]

mask = idx_atom_in != idx_atom_out
if cell_offsets_in is not None and cell_offsets_out is not None:
if to_outedge:
cell_sum = cell_offsets_out[trip_out] + cell_offsets_in[trip_in]
else:
cell_sum = cell_offsets_out[trip_out] - cell_offsets_in[trip_in]
mask = mask | torch.any(cell_sum != 0, dim=-1)

trip_in = trip_in[mask]
trip_out = trip_out[mask]

trip_out_agg = _inner_idx(trip_out, n_edges_out)

return {
"trip_in": trip_in,
"trip_out": trip_out,
"trip_out_agg": trip_out_agg,
}


def build_quadruplets(
main_edge_index: torch.Tensor,
qint_edge_index: torch.Tensor,
n_atoms: int,
main_cell_offsets: torch.Tensor,
qint_cell_offsets: torch.Tensor,
) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
"""Build quadruplet interaction indices ``d→b→a←c`` from two edge sets.

Combines an input triplet set ``d→b→a`` (edges from ``main_graph``
arriving at ``b`` for each ``qint_graph`` edge ``b→a``) with an output
triplet set ``c→a←b`` (``qint_graph`` edges arriving at ``a`` for each
``main_graph`` edge ``c→a``), then takes their Cartesian product over the
shared intermediate edge ``b→a``.

Pure-PyTorch equivalent of GemNet-OC ``get_quadruplets``.

Args:
main_edge_index: ``[2, n_main_edges]`` — main graph edges.
qint_edge_index: ``[2, n_qint_edges]`` — quadruplet interaction graph edges.
n_atoms: Total number of atoms.
main_cell_offsets: ``[n_main_edges, 3]`` cell offsets for main graph.
qint_cell_offsets: ``[n_qint_edges, 3]`` cell offsets for qint graph.

Returns:
Dict with keys:

- ``"triplet_in"`` — dict from :func:`build_mixed_triplets` for ``d→b→a``.
- ``"triplet_out"`` — dict from :func:`build_mixed_triplets` for ``c→a←b``.
- ``"quad_out"`` — main-edge indices of output edge ``c→a``, shape
``[n_quads]``.
- ``"trip_in_to_quad"`` — maps input-triplet index → quadruplet, shape
``[n_quads]``.
- ``"trip_out_to_quad"`` — maps output-triplet index → quadruplet, shape
``[n_quads]``.
- ``"quad_out_agg"`` — per-segment local index for aggregation, shape
``[n_quads]``.
"""
src_main = main_edge_index[0]
n_main_edges = src_main.size(0)
n_qint_edges = qint_edge_index.size(1)
device = src_main.device

# Input triplets: d→b→a (input=main, output=qint, to_outedge=True)
triplet_in = build_mixed_triplets(
main_edge_index,
qint_edge_index,
n_atoms,
to_outedge=True,
cell_offsets_in=main_cell_offsets,
cell_offsets_out=qint_cell_offsets,
)

# Output triplets: c→a←b (input=qint, output=main, to_outedge=False)
triplet_out = build_mixed_triplets(
qint_edge_index,
main_edge_index,
n_atoms,
to_outedge=False,
cell_offsets_in=qint_cell_offsets,
cell_offsets_out=main_cell_offsets,
)

# Count input triplets per intermediate (qint) edge
ones_in = torch.ones_like(triplet_in["trip_out"])
n_trip_in_per_inter = torch.zeros(n_qint_edges, dtype=torch.long, device=device)
n_trip_in_per_inter.index_add_(0, triplet_in["trip_out"], ones_in)

# For each output triplet, how many input triplets share the same
# intermediate edge → gives the repeat count for the cross-product
n_trip_out_per_inter = n_trip_in_per_inter[triplet_out["trip_in"]]

# Expand output-triplet side
quad_out = torch.repeat_interleave(triplet_out["trip_out"], n_trip_out_per_inter)
inter_edge = torch.repeat_interleave(triplet_out["trip_in"], n_trip_out_per_inter)
trip_out_to_quad = torch.repeat_interleave(
torch.arange(triplet_out["trip_out"].size(0), device=device),
n_trip_out_per_inter,
)

# Expand input-triplet side: for each intermediate edge in `inter_edge`,
# gather all input triplets belonging to that edge.
# Build CSR of input triplets grouped by their intermediate (qint) edge
ti_out = triplet_in["trip_out"] # sorted by construction
csr_ti = torch.zeros(n_qint_edges + 1, dtype=torch.long, device=device)
csr_ti[1:] = n_trip_in_per_inter.cumsum(0)

base = csr_ti[inter_edge]
local = _inner_idx(inter_edge, n_qint_edges)
# `local` cycles 0..count-1 within each repeated intermediate edge
# but we need to handle the Cartesian product properly:
# For each output triplet repeated `count` times, local goes 0..count-1
trip_in_to_quad = base + local

# Gather the actual input edge index
idx_in = triplet_in["trip_in"][trip_in_to_quad]

# Filter: c ≠ d (with cell offsets)
idx_atom_c = src_main[quad_out]
idx_atom_d = src_main[idx_in]

cell_offset_cd = (
main_cell_offsets[idx_in]
+ qint_cell_offsets[inter_edge]
- main_cell_offsets[quad_out]
)
mask = (idx_atom_c != idx_atom_d) | torch.any(cell_offset_cd != 0, dim=-1)

quad_out = quad_out[mask]
trip_out_to_quad = trip_out_to_quad[mask]
trip_in_to_quad = trip_in_to_quad[mask]

quad_out_agg = _inner_idx(quad_out, n_main_edges)

return {
"triplet_in": triplet_in,
"triplet_out": triplet_out,
"quad_out": quad_out,
"trip_in_to_quad": trip_in_to_quad,
"trip_out_to_quad": trip_out_to_quad,
"quad_out_agg": quad_out_agg,
}
Loading