unixsysdev/ann-sparseattention

ann-sparseattention

Train tiny per-layer "search projections" on a frozen LLM that replicate the attention's top-K preferences in a low-dimensional space, so we can swap dense quadratic attention for an off-the-shelf ANN index (FAISS HNSW) at inference and lose almost no model quality.

Current status

Research prototype. The trained projections preserve quality in a clean 6-layer block-causal WikiText-103 pilot on Qwen/Qwen3-4B-Instruct-2507. Broad substitution now has two clean datapoints: substituting all 36 layers is feasible but not full-attention parity, while a data-driven 32-layer run that reserves edge layers [0, 1, 2, 35] substantially reduces the quality cost. The runtime is still a correctness prototype. Treat reported numbers as preliminary until confidence intervals, downstream long-context tasks, and production baselines are run.

Extremely preliminary llama.cpp runtime branch. A separate runtime-only branch contains the current llama.cpp integration snapshot and raw CPU/ROCm smoke-test logs: feat/llama-ann-runtime. These numbers are not publication-strength and should not be treated as production performance. They are early engineering diagnostics: exact learned top-K decode shows an initial long-context speed signal at 33.6K tokens, while the HNSW path is only a correctness bridge that rebuilds indices during decode and is therefore intentionally slow. The branch also includes strict password recall probes showing that 6-layer HNSW recalls 1K/2K/4K passwords, while 32-layer and 36-layer HNSW are not yet reliable on that exact-recall test.

Checkpoint artifacts and JSON eval outputs are mirrored on Hugging Face: datasysdev/ann-sparseattention. Use checkpoints_block_d128/search_step_1000.pt there for the current clean 6-layer block-causal result. Use checkpoints_all36_d128_block/protected/search_step_500_keep.pt for the best all-36 checkpoint observed so far. Use checkpoints_all32_d128_block_reserve_0_1_2_35/search_step_1000.pt for the current broad-substitution checkpoint.

What's validated:

  • 6-layer packed pilot on Qwen3-4B-Instruct-2507, layers [4, 8, 12, 16, 20, 24], 4K context, 1K training steps.
  • d_search=128 is the current recommended capacity from the packed capacity ablation: 3.93M trainable parameters, mass@K=128 of 0.503 vs 0.488 for the raw-QK exact-topK oracle, and -1.81% relative PPL gap at K=128 on the packed eval slice.
  • Block-causal packed masking is implemented. On the clean block-causal d128 rerun, exact sparse attention is near parity with full attention (K=128: +0.07% PPL gap; K=256: +0.01%). The large negative PPL gaps from packed-with-leakage do not survive as a clean-methodology headline.
  • Capacity scaling is monotonic but saturating: d64 < d128 < d256 on mass@K, while d128 and d256 are effectively tied on final PPL.
  • Clean Quest-style and paired NLL baselines are implemented. Learned retrieval captures more teacher mass than Quest at equal K, but Quest is slightly better at K=128 PPL and statistically tied at K=256.
  • Clean FAISS/HNSW per-segment indexing matches exact learned retrieval closely at K=128/K=256, validating the ANN-compatibility claim.
  • All-36 d128 block-causal substitution was run to step 800. Best eval: step 500, recall@K=0.816, PPL gap +3.23%. Step 750 regressed to +3.96%. Per-layer mass@K shows the weakest layers are L00/L01/L02, with L35 only mildly weak.
  • The all32_d128_block follow-up reserves full attention on [0, 1, 2, 35] and trains layers 3..34. It converged by step 500 and finished step 1000 at recall@K=0.825 and +1.746% PPL gap in training eval. A post-hoc exact K-sweep on the final checkpoint gives +0.590% PPL gap at K=128 and -0.062% at K=256 on a 2-batch clean block-causal slice.

Not yet validated (next iteration):

  • Confidence intervals for the block-causal result over multiple seeds and larger eval slices.
  • Long-context task quality (LongBench, RULER, needle-in-haystack).
  • A controlled coverage Pareto sweep, especially 12-, 18-, and 20-layer configurations, to locate the zero-gap broad-substitution point.
  • Wall-clock speedup vs. FlashAttention/SDPA — not measured.
  • KV-cache decode-mode integration.
  • GPU-resident ANN or fused gather-attention kernel.

Runtime caveat. The current FAISS path is a correctness prototype: it builds a CPU index per forward pass and uses dense-style tensor expansion internally for the gather step. The compute-reduction numbers below are algorithmic scoring reductions, not measured wall-clock speedups. A production runtime requires a GPU-resident topk kernel or integration with paged/block-sparse attention kernels.

d_search ablation (packed WikiText-103, K=128)

The packed ablation trains the same 6 layers for 1K steps and evaluates all variants with the same packed eval pipeline. raw_qk is exact top-K over head-mean-aggregated native post-RoPE Q/K vectors; learned is exact top-K over trained search projections. mass@K is teacher-attention probability captured by the retrieved set.

| d_search | Params | learned mass@K=128 | raw-QK oracle mass@K=128 | learned / oracle | Final PPL gap |
|---|---|---|---|---|---|
| 64 | 1.97M | 0.492 | 0.488 | 1.01x | +2.39% |
| 128 | 3.93M | 0.503 | 0.488 | 1.03x | -1.81% |
| 256 | 7.86M | 0.509 | 0.488 | 1.04x | -1.85% |

d128 is the recommended default for this pilot: it captures almost all of the d256 quality with half the trainable parameters. d256 improves mass@K slightly but does not materially improve final PPL.

PPL gap is the primary model-quality signal; mass@K is the more direct retrieval-quality signal when teacher attention is sharp. Recall@K is logged, but it is a weaker proxy because disagreement on near-zero-probability tail positions can look like low recall while preserving model output.
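The tail effect on recall@K is easy to reproduce. The following is a minimal NumPy sketch of the two metrics as defined here, assuming a single teacher attention row per query (for example, head-averaged) and using the teacher's own top-K as the recall reference. The function names are illustrative and are not the eval.py API. In the toy example, retrieval keeps the two heavy keys but picks different near-zero tail keys.

```python
import numpy as np

def mass_at_k(teacher_probs: np.ndarray, retrieved: np.ndarray) -> float:
    """Teacher attention probability captured by the retrieved key indices."""
    return float(teacher_probs[retrieved].sum())

def recall_at_k(teacher_probs: np.ndarray, retrieved: np.ndarray) -> float:
    """Fraction of the teacher's own top-K keys that appear in the retrieved set."""
    k = len(retrieved)
    teacher_topk = np.argsort(-teacher_probs)[:k]
    return len(set(teacher_topk.tolist()) & set(retrieved.tolist())) / k

# Toy query over 10 keys (assumed distribution): two keys carry 95% of the mass.
teacher = np.array([0.60, 0.35, 0.015, 0.012, 0.008, 0.006, 0.004, 0.003, 0.001, 0.001])
teacher = teacher / teacher.sum()

# K=4 retrieval: same two heavy keys as the teacher, different tail keys.
retrieved = np.array([0, 1, 8, 9])
print(f"recall@4 = {recall_at_k(teacher, retrieved):.2f}")  # 0.50
print(f"mass@4   = {mass_at_k(teacher, retrieved):.3f}")    # 0.952
```

Half the teacher's top-4 is missed, yet 95% of the attention mass is retrieved. This is why mass@K and PPL gap are treated as the primary signals.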

Per-layer mass@K=128 for d128:

| Layer | raw-QK oracle | learned d128 |
|---|---|---|
| 4 | 0.422 | 0.382 |
| 8 | 0.518 | 0.421 |
| 12 | 0.404 | 0.533 |
| 16 | 0.475 | 0.481 |
| 20 | 0.499 | 0.551 |
| 24 | 0.614 | 0.648 |

Early layers remain harder for learned retrieval; mid/late trained layers exceed raw-QK oracle mass.

K-retrieve Pareto (packed d128, leakage-confounded)

Exact top-K sweep for the recommended packed d128 checkpoint:

python k_sweep.py \
  --ckpt /tmp/checkpoints_packed_d128/search_step_1000.pt \
  --K 128,256,512 \
  --no-use-faiss

PPL_full = 224.64 on this packed eval slice.

| K | Recall@K | mass@K | PPL_ANN | PPL gap |
|---|---|---|---|---|
| 128 | 0.166 | 0.256 | 203.63 | -9.36% |
| 256 | 0.233 | 0.318 | 207.06 | -7.83% |
| 512 | 0.339 | 0.409 | 211.93 | -5.66% |

This disambiguates the earlier FAISS high-K failure on the leaked packed pipeline: exact retrieval remains strongly negative at K=256/512, so the denoising pattern is present on this packed eval slice. This should not be used as a publication-strength denoising claim because packed examples can attend across document boundaries.

A second exact sweep on the next 16 packed eval batches (--skip-batches 16) preserved the shape: K=128 -8.78%, K=256 -7.59%, K=512 -6.21%. This is still not a substitute for confidence intervals, but it reduces the chance that the large negative gap is a single-slice accident.

Block-causal packed d128 (clean masking)

Packed block-causal masking assigns each packed document a segment_id, resets position_ids at segment boundaries, and supplies a 4D additive mask so tokens can only attend causally within their own document. Retrieval, loss masking, mass@K, and recall@K use the same segment-causal eligibility mask.
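Here is a minimal NumPy sketch of that packing logic for one sequence. It assumes `segment_ids` is a 1-D array of per-token document ids. The sketch restarts positions at each segment boundary and builds an additive mask that is 0 where attention is allowed and a large negative value elsewhere. The real mask is 4D (batch and head axes), and the helper name is hypothetical.

```python
import numpy as np

def block_causal_inputs(segment_ids: np.ndarray, neg: float = -1e9):
    """Return (position_ids, additive_mask, allowed) for one packed sequence (illustrative)."""
    n = len(segment_ids)
    # Position ids restart at 0 on the first token of each packed document.
    position_ids = np.zeros(n, dtype=np.int64)
    for t in range(1, n):
        same = segment_ids[t] == segment_ids[t - 1]
        position_ids[t] = position_ids[t - 1] + 1 if same else 0
    # A key is eligible iff it is in the query's segment and not in the future.
    same_segment = segment_ids[:, None] == segment_ids[None, :]
    causal = np.tril(np.ones((n, n), dtype=bool))
    allowed = same_segment & causal
    additive_mask = np.where(allowed, 0.0, neg)
    return position_ids, additive_mask, allowed

seg = np.array([0, 0, 0, 1, 1, 2])
pos, mask, allowed = block_causal_inputs(seg)
print(pos)                  # [0 1 2 0 1 0]
print(allowed.sum(axis=1))  # [1 2 3 1 2 1]  (valid keys per query)
```

The same `allowed` matrix is the eligibility mask that retrieval, mass@K, and recall@K must respect. Its row sums show why short documents often have fewer than K valid keys.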

Clean d128 block-causal run:

python train.py --config pilot_d128_block
python k_sweep.py \
  --ckpt /tmp/checkpoints_block_d128/search_step_1000.pt \
  --K 128,256,512 \
  --no-use-faiss

PPL_full = 30.44 on the 16-batch clean eval slice.

| K | Recall@K | mass@K | PPL_ANN | PPL gap |
|---|---|---|---|---|
| 128 | 0.744 | 0.787 | 30.47 | +0.07% |
| 256 | 0.879 | 0.953 | 30.45 | +0.01% |
| 512 | n/a | n/a | 30.45 | +0.01% |

K=512 has no meaningful mass/recall average on this WikiText slice because almost no same-segment queries have 512 valid causal keys. The quality result is still useful: with filler slots masked out of the sparse-attention softmax, the block-causal exact path is effectively at full-attention parity. The clean result supports "quality-preserving sparse substitution" rather than the leaked pipeline's stronger denoising claim.

Clean block-causal per-layer compare_retrieval at K=128:

| Layer | raw-QK oracle mass | learned d128 mass |
|---|---|---|
| 4 | 0.956 | 0.950 |
| 8 | 0.977 | 0.976 |
| 12 | 0.970 | 0.977 |
| 16 | 0.964 | 0.970 |
| 20 | 0.970 | 0.983 |
| 24 | 0.978 | 0.984 |
| avg | 0.969 | 0.973 |

This changes the per-layer interpretation from the leakage-confounded pilot: with segment isolation, early trained layers are not diffuse or uniquely hard. All six trained layers have high oracle mass, and learned projections match or slightly exceed raw-QK retrieval across the set. The deployment hypothesis for the next run is therefore "substitute all tested layers" rather than "keep early layers as full attention," pending a broader all-layer run.

Quest-style page baseline (clean block-causal)

quest_sweep.py implements a Quest-style min/max page selector for comparison: page size 16, native post-RoPE Q/K, same block-causal token eligibility mask, and the same sparse-attention gather path. This is a correctness baseline, not an optimized Quest runtime.
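The Quest criterion bounds q·k for every key in a page using only that page's elementwise key minima and maxima. Pages are then ranked by this upper bound, which costs 2 × d_head multiplies per page. Below is a minimal NumPy sketch of such a selector for one query. It assumes a single head, a key count divisible by the page size, and no causal masking. quest_sweep.py additionally applies the block-causal eligibility mask. Names are illustrative.

```python
import numpy as np

def quest_select(q: np.ndarray, keys: np.ndarray, page_size: int, k_tokens: int) -> np.ndarray:
    """Token indices of the pages with the highest Quest-style upper bound on q.k (sketch)."""
    n, d = keys.shape
    pages = keys.reshape(n // page_size, page_size, d)
    kmin, kmax = pages.min(axis=1), pages.max(axis=1)    # (num_pages, d) page metadata
    # For k_i in [min_i, max_i], q_i*k_i <= max(q_i*min_i, q_i*max_i); sum over dims.
    bound = np.maximum(q * kmin, q * kmax).sum(axis=1)
    n_pages = max(1, k_tokens // page_size)              # token budget -> page budget
    top_pages = np.sort(np.argsort(-bound)[:n_pages])
    return np.concatenate([np.arange(p * page_size, (p + 1) * page_size) for p in top_pages])

rng = np.random.default_rng(0)
keys = rng.standard_normal((256, 128))
q = rng.standard_normal(128)
selected = quest_select(q, keys, page_size=16, k_tokens=64)
print(len(selected))  # 64 tokens = 4 pages of 16
```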

python quest_sweep.py \
  --ckpt /tmp/checkpoints_block_d128/search_step_1000.pt \
  --K 128,256,512 \
  --page-size 16

On the same 16-batch block-causal eval slice:

| Method | K | Recall@K | mass@K | PPL | PPL gap |
|---|---|---|---|---|---|
| learned search, exact | 128 | 0.744 | 0.787 | 30.47 | +0.07% |
| Quest-style page | 128 | 0.669 | 0.727 | 30.41 | -0.11% |
| learned search, exact | 256 | 0.879 | 0.953 | 30.45 | +0.01% |
| Quest-style page | 256 | 0.838 | 0.909 | 30.45 | +0.03% |

Both methods are effectively full-attention parity on PPL. The learned search space recovers more teacher attention mass at the same token budget, especially at K=128, while Quest remains a strong non-trained heuristic baseline. This keeps the contribution narrow: learned projections improve retrieval fidelity and support standard ANN indexing; they do not yet show a clean PPL advantage over Quest on this slice.

Paired 32-batch NLL evaluation gives a sharper comparison:

| K | full PPL | learned PPL | Quest PPL | learned - Quest NLL delta (95% bootstrap CI) | Read |
|---|---|---|---|---|---|
| 128 | 28.03 | 28.07 | 28.01 | +0.00205 [+0.00160, +0.00251] | Quest slightly better |
| 256 | 28.03 | 28.04 | 28.04 | -0.00005 [-0.00029, +0.00018] | statistical tie |

So the current clean result is: learned search has higher teacher-attention mass, but PPL is either tied with Quest (K=256) or slightly worse (K=128) on this paired WikiText slice. The paper claim should be "retrieval-fidelity and ANN-compatibility advantages," not "PPL advantage over Quest."
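The CI column is a paired bootstrap: both methods score identical eval tokens, so the analysis resamples their NLL differences rather than the two NLLs independently. Below is a minimal sketch that assumes aligned per-token NLL arrays of shape (sequences, tokens) and resamples whole sequences. The resampling unit, resample count, and function name are assumptions and may not match what the repo's eval script does. The synthetic arrays only demonstrate the mechanics.

```python
import numpy as np

def paired_bootstrap_ci(nll_a, nll_b, n_boot=10_000, alpha=0.05, seed=0):
    """Mean of (nll_a - nll_b) and a percentile bootstrap CI, resampling sequences.

    nll_a, nll_b: (num_sequences, tokens_per_sequence), scored on the same tokens.
    """
    delta = (np.asarray(nll_a) - np.asarray(nll_b)).mean(axis=1)  # per-sequence delta
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, len(delta), size=(n_boot, len(delta)))  # resample with replacement
    boot_means = delta[idx].mean(axis=1)
    lo, hi = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return float(delta.mean()), float(lo), float(hi)

# Synthetic demo: method A is ~0.002 nats/token worse than method B, plus noise.
rng = np.random.default_rng(1)
nll_b = rng.gamma(2.0, 1.6, size=(32, 4096))
nll_a = nll_b + 0.002 + rng.normal(0.0, 0.01, size=nll_b.shape)
mean, lo, hi = paired_bootstrap_ci(nll_a, nll_b)
print(f"delta = {mean:+.5f}  95% CI [{lo:+.5f}, {hi:+.5f}]")
```

If the whole CI is above zero, the first method has higher NLL. If the CI straddles zero, the result is read as a tie. These are the two readings in the Read column.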

Clean FAISS-vs-exact check

The first block-causal FAISS prototype used one global index followed by segment filtering, which produced pathological filler rates after filtering. The current FAISS path builds per-segment indexes when a 4D block-causal mask is present. With that fix, CPU FAISS/HNSW tracks exact learned search on the same 16-batch clean eval slice:

| Method | K | PPL | PPL gap | FAISS filler rate |
|---|---|---|---|---|
| learned exact | 128 | 30.47 | +0.07% | n/a |
| learned FAISS/HNSW | 128 | 30.47 | +0.09% | 0.447 |
| learned exact | 256 | 30.45 | +0.01% | n/a |
| learned FAISS/HNSW | 256 | 30.46 | +0.04% | 0.683 |

The remaining filler rate is expected for short same-segment prefixes where fewer than K valid causal keys exist; filler slots are masked out of the sparse attention softmax. This demonstrates off-the-shelf ANN compatibility in the clean block-causal setting, but not production wall-clock speedup.
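Filler follows directly from per-segment, causal retrieval: early positions in each document have fewer than K candidates, and the unused slots are padded and masked. Below is a minimal NumPy sketch of per-segment top-K with filler bookkeeping. It uses exact inner-product search as a stand-in for the per-segment FAISS HNSW index. The `-1` filler convention and the function name are assumptions.

```python
import numpy as np

def per_segment_topk(q_s, k_s, segment_ids, K):
    """Causal top-K key indices per query within its own segment (exact stand-in for HNSW).

    Returns (indices, valid), both (n, K). Filler slots hold -1 and valid=False,
    so they can be masked out of the sparse-attention softmax.
    """
    n = len(segment_ids)
    indices = np.full((n, K), -1, dtype=np.int64)
    for seg in np.unique(segment_ids):
        pos = np.flatnonzero(segment_ids == seg)        # global positions of this segment
        scores = q_s[pos] @ k_s[pos].T                  # search-space scores within segment
        for i, t in enumerate(pos):
            order = np.argsort(-scores[i, : i + 1])[:K]  # causal keys only
            indices[t, : len(order)] = pos[order]
    return indices, indices >= 0

rng = np.random.default_rng(0)
seg = np.repeat([0, 1, 2], [6, 3, 7])                   # three packed documents
q_s, k_s = rng.standard_normal((16, 8)), rng.standard_normal((16, 8))
idx, valid = per_segment_topk(q_s, k_s, seg, K=4)
print(f"filler rate = {1 - valid.mean():.3f}")  # 0.281
```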

Asymptotic scoring analysis

artifacts/scaling_analysis.md gives a deterministic operation-count proxy for the per-query candidate scoring step. This is the cost of identifying which keys to attend to, before the sparse attention softmax and value multiply over the selected keys.

Assumptions:

  • Full attention scoring: N * d_head = N * 128.
  • Quest-style page scoring: (N / page_size) * 2 * d_head = N * 16 with page_size=16.
  • Learned HNSW scoring: M * ef_search * log2(N) * d_search with M=32, ef_search=64, and d_search=128.

Candidate-scoring operations per query

| Context | Full ops/query | Quest ops/query | Learned HNSW ops/query | Quest / learned |
|---|---|---|---|---|
| 4K | 512,000 | 64,000 | 3,136,759 | 0.02x |
| 8K | 1,024,000 | 128,000 | 3,398,903 | 0.04x |
| 16K | 2,048,000 | 256,000 | 3,661,047 | 0.07x |
| 32K | 4,096,000 | 512,000 | 3,923,191 | 0.13x |
| 64K | 8,192,000 | 1,024,000 | 4,185,335 | 0.24x |
| 128K | 16,384,000 | 2,048,000 | 4,447,479 | 0.46x |
| 256K | 32,768,000 | 4,096,000 | 4,709,623 | 0.87x |
| 512K | 65,536,000 | 8,192,000 | 4,971,767 | 1.65x |
| 1M | 128,000,000 | 16,000,000 | 5,224,942 | 3.06x |
| 2M | 256,000,000 | 32,000,000 | 5,487,086 | 5.83x |
| 4M | 512,000,000 | 64,000,000 | 5,749,230 | 11.13x |

Under these conservative HNSW constants, Quest is cheaper below the few-hundred-thousand-token regime and learned-projection scoring becomes cheaper beyond roughly 300K tokens. At 1M context, the operation-count proxy is about 3x in favor of learned projections. This supports the theoretical scaling claim only; production speed claims still require GPU-resident retrieval and KV-cache/decode integration.
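The table follows from the three formulas alone. A short script re-derives the table rows and the crossover point. It assumes decimal context sizes (4K = 4,000 tokens), which is what the Full column implies (4,000 × 128 = 512,000). This is an independent check, not the script behind artifacts/scaling_analysis.md.

```python
import math

D_HEAD, PAGE_SIZE = 128, 16
HNSW_M, EF_SEARCH, D_SEARCH = 32, 64, 128

def full_ops(n: int) -> float:
    return n * D_HEAD                                     # score every key

def quest_ops(n: int) -> float:
    return (n / PAGE_SIZE) * 2 * D_HEAD                   # min + max metadata per page

def learned_hnsw_ops(n: int) -> float:
    return HNSW_M * EF_SEARCH * math.log2(n) * D_SEARCH   # graph-walk proxy

def crossover(lo: int = 1_000, hi: int = 10_000_000) -> int:
    """Smallest n where the learned HNSW proxy is cheaper than Quest (bisection)."""
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if learned_hnsw_ops(mid) < quest_ops(mid):
            hi = mid
        else:
            lo = mid
    return hi

for label, n in [("4K", 4_000), ("128K", 128_000), ("1M", 1_000_000), ("4M", 4_000_000)]:
    ratio = quest_ops(n) / learned_hnsw_ops(n)
    print(f"{label:>5} full={full_ops(n):>13,.0f} quest={quest_ops(n):>12,.0f} "
          f"learned={learned_hnsw_ops(n):>10,.0f} quest/learned={ratio:.2f}x")
print("crossover ~", crossover())
```

The crossover lands just under 300K tokens, consistent with the "roughly 300K" reading above.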

Dynamic-index proxy

The current ANN wrapper is prefill-only (use_cache=False), so a true generation-time dynamic-index benchmark still requires cache integration. As a first capability proxy, dynamic_index_proxy.py splits clean block-causal eval sequences into a prefill prefix and decode-like suffix, then compares learned retrieval mass under two masks:

  • dynamic index: suffix queries can retrieve from all same-segment prior keys;
  • static index: suffix queries can retrieve from prefill keys plus a 256-token recent local suffix window, but not older suffix keys.

On the clean d128 block-causal checkpoint, using K=128, prefill length 1024, local window 256, and 8 eval batches:

| Setting | Teacher mass captured |
|---|---|
| Dynamic proxy | 0.972 |
| Static proxy | 0.928 |
| Static teacher mass available | 0.954 |
| Dynamic - static | +0.044 |

Per-layer dynamic-minus-static mass ranges from +0.022 (L04) to +0.058 (L08). This does not establish task accuracy or decode latency, but it shows that a frozen prefill-plus-local index loses measurable teacher-attention mass on decode-like suffix queries. The raw result is in artifacts/dynamic_proxy_8b.json.
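The two proxy masks can be written down directly for one segment. Below is a minimal sketch that assumes the static index covers positions `[0, prefill_len)` and that the local window counts the query's own position. The exact boundary convention in dynamic_index_proxy.py may differ. Mass captured under each mask is then computed as before, with top-K retrieval restricted to eligible keys.

```python
import numpy as np

def proxy_masks(n: int, prefill_len: int, local_window: int):
    """Boolean (query, key) eligibility masks for one segment of length n (illustrative)."""
    q = np.arange(n)[:, None]
    k = np.arange(n)[None, :]
    causal = k <= q
    # Dynamic index: every causal key, including keys appended after prefill.
    dynamic = causal
    # Static index: prefill keys, plus a recent window of suffix keys (assumed inclusive of q).
    in_prefill = k < prefill_len
    recent = (q - k) < local_window
    static = causal & (in_prefill | recent)
    return dynamic, static

dyn, sta = proxy_masks(n=12, prefill_len=4, local_window=3)
print(dyn[11].sum(), sta[11].sum())  # 12 7  (static: 4 prefill keys + 3 recent keys)
```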

All-36 and all32 reserved-layer experiments

The first broad substitution run trained d_search=128 projections for all 36 attention layers under the clean block-causal pipeline:

python train.py --config all36_d128_block

It is feasible, but it is not yet the quality-preserving headline. Training was stopped after step 800 to inspect checkpoints and start a better targeted follow-up.

| Step | Recall@K | eval PPL gap | Read |
|---|---|---|---|
| 250 | 0.805 | +6.27% | ranking still poor |
| 500 | 0.816 | +3.23% | best checkpoint so far |
| 750 | 0.817 | +3.96% | PPL regressed despite stable recall |

The protected step-500 checkpoint is mirrored as:

checkpoints_all36_d128_block/protected/search_step_500_keep.pt

Per-layer compare_retrieval at K=128 shows high average retrieval fidelity but clear early-layer weakness:

| Layer | raw-QK mass | learned mass | delta |
|---|---|---|---|
| L00 | 0.922 | 0.780 | -0.142 |
| L01 | 0.918 | 0.851 | -0.067 |
| L02 | 0.939 | 0.899 | -0.040 |
| L03 | 0.939 | 0.924 | -0.015 |
| L04 | 0.944 | 0.933 | -0.011 |
| L05 | 0.964 | 0.947 | -0.017 |
| L06 | 0.956 | 0.936 | -0.020 |
| L07 | 0.982 | 0.982 | +0.000 |
| L08 | 0.971 | 0.970 | -0.001 |
| L09 | 0.959 | 0.976 | +0.017 |
| L10 | 0.974 | 0.970 | -0.004 |
| L11 | 0.976 | 0.975 | -0.001 |
| L12 | 0.961 | 0.969 | +0.008 |
| L13 | 0.971 | 0.971 | +0.000 |
| L14 | 0.973 | 0.973 | -0.000 |
| L15 | 0.968 | 0.972 | +0.004 |
| L16 | 0.956 | 0.962 | +0.006 |
| L17 | 0.959 | 0.966 | +0.007 |
| L18 | 0.965 | 0.972 | +0.007 |
| L19 | 0.961 | 0.968 | +0.007 |
| L20 | 0.959 | 0.975 | +0.016 |
| L21 | 0.966 | 0.979 | +0.014 |
| L22 | 0.963 | 0.970 | +0.007 |
| L23 | 0.979 | 0.984 | +0.005 |
| L24 | 0.971 | 0.978 | +0.007 |
| L25 | 0.986 | 0.988 | +0.002 |
| L26 | 0.978 | 0.985 | +0.008 |
| L27 | 0.978 | 0.983 | +0.005 |
| L28 | 0.979 | 0.985 | +0.005 |
| L29 | 0.982 | 0.987 | +0.005 |
| L30 | 0.988 | 0.986 | -0.002 |
| L31 | 0.984 | 0.984 | -0.001 |
| L32 | 0.979 | 0.979 | +0.000 |
| L33 | 0.977 | 0.970 | -0.007 |
| L34 | 0.976 | 0.960 | -0.016 |
| L35 | 0.980 | 0.967 | -0.013 |
| avg | 0.966 | 0.960 | -0.006 |

This motivated the reserved-edge follow-up run:

python train.py --config all32_d128_block

all32_d128_block reserves full attention on [0, 1, 2, 35] and trains layers 3..34. This tests whether broad substitution fails mainly because of the weak edge layers, or because small approximation errors compound across many otherwise-good layers. The run finished cleanly at step 1000 with 20.97M trained search-projection parameters:

| Step | Recall@K | eval PPL gap | Read |
|---|---|---|---|
| 250 | 0.812 | +2.283% | already better than all36 best training eval |
| 500 | 0.823 | +1.753% | converged to the final quality band |
| 750 | 0.825 | +1.943% | small eval fluctuation |
| 1000 | 0.825 | +1.746% | final checkpoint; essentially tied with step 500 |

The all32 result is the current broad-substitution headline. It is not full-attention parity at K=128 in training eval, but it cuts the all36 quality cost roughly in half while still substituting 32 of 36 layers. Post-hoc compare_retrieval on the final checkpoint shows the reserved-edge hypothesis did what it was supposed to do: on the substituted layers, learned retrieval matches raw-QK retrieval mass (K=128 learned 0.971 vs raw 0.969; K=256 learned 0.993 vs raw 0.994). The remaining PPL cost is therefore more likely compound approximation error than a single bad substituted layer.

Exact K-sweep on the final all32 checkpoint, 2-batch clean block-causal slice (PPL_full = 20.5349):

| K | mass@K | Recall@K | sparse PPL | PPL gap |
|---|---|---|---|---|
| 16 | 0.546 | 0.518 | 24.86 | +21.064% |
| 32 | 0.627 | 0.572 | 21.85 | +6.422% |
| 64 | 0.722 | 0.652 | 20.94 | +1.974% |
| 128 | 0.807 | 0.746 | 20.66 | +0.590% |
| 256 | 0.902 | 0.876 | 20.52 | -0.062% |

K=512 is intentionally omitted from this table. The current script produced a valid sparse-attention PPL line for K=512 but zero mass/recall, which is an edge-case bug in the metric path when K exceeds the number of valid causal keys for most same-segment queries. It should be rerun after fixing the metric handling; the publishable sweep for now is K <= 256.

The emerging coverage picture is more useful than a single number:

| Configuration | Layers substituted | Coverage | PPL gap | Read |
|---|---|---|---|---|
| 6-layer clean pilot | 6/36 | 17% | +0.07% at K=128 | quality-preserving pilot |
| all32 reserved-edge | 32/36 | 89% | +1.746% train eval; +0.590% exact sweep | near-parity broad substitution |
| all36 | 36/36 | 100% | +3.23% best observed | full substitution costs quality |

This suggests layer coverage is itself a Pareto knob. The current data is not enough to claim an optimal coverage ratio, but it strongly suggests the best deployment point is intermediate rather than "sparsify everything." A 12/18/20 layer coverage sweep is the next clean experiment.

Compute / quality knobs (FLOP-counted)

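Context length L = 4096. The compute reduction applies to the attention scoring step and is approximately L / K. These are FLOP estimates, not measured wall-clock: the FAISS path in this repo is a research prototype that does CPU index builds and GPU↔CPU transfers, so it is not the right thing to time. A GPU-resident top-K kernel is the natural next step. The K=128, 256, and 512 PPL gaps come from the packed, leakage-confounded exact sweep above and should not be read as clean-methodology quality numbers.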

| K | PPL gap | Attention scoring reduction |
|---|---|---|
| 512 | -5.66% (exact top-K over learned search space) | ~8x |
| 256 | -7.83% (exact top-K over learned search space) | ~16x |
| 128 | -9.36% exact; -1.81% FAISS/training eval | ~32x |
| 64 | +0.46% | ~64x |
| 32 | +0.03% | ~128x |
| 16 | +5.63% | ~256x |

Eval scope for the d_search table: 16 packed validation batches at 4K context for PPL/recall during training, and 12 packed batches for compare_retrieval mass@K. Numbers should be read as "what we observed on this slice", not population-level estimates.

Caveats / what's next

A few things the pilot does not yet establish, and that the next iteration will address:

  • Packing: the d_search ablation table is still from the packed leakage-confounded run and is best read as a capacity comparison. The clean block-causal d128 rerun removes cross-document leakage and should be used for quality claims.
  • Exact-topK oracle: the obvious follow-up is a four-way Pareto — full attention vs. exact top-K (true QK^T argmax-K, then attention) vs. search-topK (our projections, exact distance) vs. search-ANN (FAISS HNSW). That separates "denoising from any sparsity" from "denoising from learned projections."
  • Wall-clock: the compute-reduction table above is FLOP-counted. The FAISS path here is a research prototype (CPU index per forward, GPU↔CPU transfer) and is the wrong thing to time. A GPU-resident topk kernel is the next-step engineering.
  • Broad substitution: all-36 is viable but not parity (+3.23% best observed PPL gap). The all32 reserved-edge run reduces the cost to +1.746% in training eval and +0.590% at K=128 in the post-hoc exact sweep, with parity at K=256 on the small sweep.

The recall@K and mass@K numbers reported here come from small eval slices (between 2 and 32 batches depending on the experiment), not population-level estimates. Confidence intervals and downstream tasks (LongBench / RULER / needle-in-haystack) are the natural next evals.

Broad layer runs

Two broad-layer configs are now wired:

  • all36_d128_block: trains all 36 layers, clean block-causal, d128. Best observed checkpoint is step 500 at +3.23% PPL gap.
  • all32_d128_block: trains layers 3..34, reserves [0, 1, 2, 35] as full attention. Final step-1000 checkpoint: recall@K=0.825 and +1.746% PPL gap in training eval; exact sweep reaches +0.590% at K=128 and parity at K=256 on a 2-batch clean slice.

Checkpoints are mirrored at datasysdev/ann-sparseattention.

Relation to RetrievalAttention

The closest prior work is RetrievalAttention (Liu et al., 2024). They show that vanilla ANN search over the model's native Q and K vectors fails because queries and keys come from mismatched distributions: they were never trained to be each other's nearest neighbors, only to score correctly via the dot product. Their fix lives in the index: an attention-aware graph construction (RoarGraph-style) that compensates for the Q/K out-of-distribution problem at search time.

This work attacks the same problem from the opposite direction. Instead of patching the index built over mismatched vectors, we train a tiny pair of low-dimensional projections (W_Qs, W_Ks → R^128 in the recommended pilot) so that q_search and k_search are trained to live in one shared space. Off-the-shelf FAISS HNSW with default parameters is then sufficient; no attention-aware index construction is needed.

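| Approach | Search space | Index | Trainable |
|---|---|---|---|
| Raw Q/K + vanilla ANN | original Q/K | off-the-shelf | no; fails (Q/K OOD) |
| RetrievalAttention | original Q/K | attention-aware graph | no |
| This work | learned Q_s / K_s | off-the-shelf | yes (1.97M to 7.86M params in the 6-layer pilot; 20.97M to 23.59M in the all32/all36 runs) |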

The contribution claim: remove the Q/K mismatch before the index is built, by distilling teacher attention into a learned search space, instead of compensating for it inside the index at search time. The clean experiment to validate this (vanilla FAISS over raw Q/K vs. vanilla FAISS over learned Q_s/K_s vs. exact teacher top-K, all at the same K) is the next planned run. The current pilot establishes that the learned projections retrieve attention-relevant keys; the comparison run will isolate how much of that comes from the projection versus the ANN approximation.

Relation to other efficient attention methods

The paper now frames the method against four kinds of competitors:

  • Closest asymptotic ancestor: Reformer. Reformer uses untrained LSH hashing to find candidate keys, then attends within buckets. This work keeps the retrieve-then-attend shape but trains the retrieval space from teacher attention, instead of relying on random hashes over native Q/K.
  • Closest practical baseline: Quest. Quest is query-aware and strong at moderate context, but scans page metadata linearly in the number of pages. This work is weaker than Quest at K=128 PPL on the current slice, tied at K=256, but has a different long-context scaling target via ANN retrieval.
  • Linear/smooth approximations: Performer. Performer changes the attention computation with random feature approximations to softmax. This work preserves exact softmax over the retrieved set and approximates only candidate selection.
  • Fixed sparse patterns: Longformer/BigBird. These are efficient but not fully query-aware in the sense used here; remote keys are not selected because the current query needs them.

Qualitative property table from the paper:

| Method | Selection mechanism | Query-aware | Trained | Asymptotic | Exact softmax |
|---|---|---|---|---|---|
| Full attention | all keys | n/a | n/a | O(N²) | yes |
| Reformer | LSH hashing | yes | no | O(N log N) | over bucket |
| Performer | random features | n/a | no | O(N) | no |
| BigBird | window + random + global | mostly no | no | O(N) | over pattern |
| Longformer | sliding window + global | mostly no | no | O(N) | over pattern |
| NSA-style methods | block compression/selection | partial | partial | O(N²) proxy | yes |
| Quest | min/max page heuristic | yes | no | O(N) | over pages |
| This work | trained low-dim retrieval | yes | yes | O(N log N) | over retrieved set |

This is a design-positioning table, not a completed empirical win. The current clean results support the "This work" row for the six-layer pilot and show that broad substitution becomes usable when weak edge layers are reserved. All-layer quality is not yet at parity.

The method also targets a different deployment scenario than native sliding-window or state-space/hybrid architectures such as Mistral-style sliding window, Mamba, or Qwen3.6 Gated DeltaNet hybrids. Those approaches are trained from scratch with their sparse or hybrid mechanism in place. This work is post-hoc: train a base model with full attention for maximum expressivity, then add lightweight retrieval projections afterward so that per-query candidate selection becomes sub-linear in context length, without changing base weights. The intended benefit is decoupling training-time architecture from inference-time architecture.

How it works

For each full-attention layer i we train two linear projections W_Qs^i, W_Ks^i ∈ R^{d_model × d_search} (recommended pilot: d_search=128), so that for any hidden state h ∈ R^{1 × d_model} (a row vector),

q_search = h W_Qs^i        k_search = h W_Ks^i
softmax(q_search · k_search^T)  ranks the same keys as
softmax(QK^T / √d_head)         (the teacher's attention)
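With bias-free projections, each trained layer adds 2 × d_model × d_search parameters. This arithmetic reproduces every trainable-parameter count quoted in this README. The quick check below assumes Qwen3-4B's hidden size of 2560, which comes from the public model config rather than from this README. The class is a hypothetical NumPy stand-in, not the `SearchProjection` in model.py.

```python
import numpy as np

D_MODEL = 2560  # assumption: Qwen3-4B hidden size from the model config

class SearchProjectionSketch:
    """Hypothetical per-layer pair of bias-free maps h -> q_search, h -> k_search."""

    def __init__(self, d_model: int, d_search: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(d_model)
        self.W_Qs = rng.standard_normal((d_model, d_search)) * scale
        self.W_Ks = rng.standard_normal((d_model, d_search)) * scale

    def __call__(self, h: np.ndarray):
        # h: (seq_len, d_model) row vectors -> two (seq_len, d_search) arrays
        return h @ self.W_Qs, h @ self.W_Ks

    def num_params(self) -> int:
        return self.W_Qs.size + self.W_Ks.size

def trainable_params(n_layers: int, d_search: int, d_model: int = D_MODEL) -> int:
    return n_layers * 2 * d_model * d_search

for name, layers, d in [("pilot d64", 6, 64), ("pilot d128", 6, 128), ("pilot d256", 6, 256),
                        ("all36 d128", 36, 128), ("all32 d128", 32, 128)]:
    print(f"{name:<11} {trainable_params(layers, d) / 1e6:.2f}M")
# pilot d64 1.97M, pilot d128 3.93M, pilot d256 7.86M, all36 d128 23.59M, all32 d128 20.97M
```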

Two losses, summed across layers:

  • InfoNCE with teacher-derived positives (top-K_pos keys from the teacher's attention serve as positives for each query).
  • KL(teacher ‖ student) on the full attention distribution.
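The following minimal NumPy sketch computes the two loss terms for one layer and one query. It assumes a single teacher distribution per query (for example, head-averaged teacher attention) and unit temperature on the search logits. It also assumes a multi-positive InfoNCE form that maximizes the student's probability mass on the teacher's top-K_pos keys. The repo's weighting, temperature, and positive handling may differ. In training these would be computed in PyTorch with autograd.

```python
import numpy as np

def log_softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max()
    return x - np.log(np.exp(x).sum())

def distill_losses(q_s, k_s, teacher_probs, k_pos):
    """(InfoNCE, KL(teacher || student)) for one query; illustrative only.

    q_s: (d_search,); k_s: (n_keys, d_search) over causally eligible keys;
    teacher_probs: (n_keys,) teacher attention row summing to 1.
    """
    student_logp = log_softmax(k_s @ q_s)          # assumed: unit temperature
    positives = np.argsort(-teacher_probs)[:k_pos]  # teacher top-K_pos keys
    # Assumed multi-positive InfoNCE: -log of student mass on the positive set.
    info_nce = -np.log(np.exp(student_logp[positives]).sum())
    # KL(teacher || student) over the full eligible key distribution.
    t = np.clip(teacher_probs, 1e-12, None)
    kl = float((teacher_probs * (np.log(t) - student_logp)).sum())
    return float(info_nce), kl

rng = np.random.default_rng(0)
k_s = rng.standard_normal((32, 16))
teacher = np.exp(log_softmax(k_s @ rng.standard_normal(16)))
print(distill_losses(rng.standard_normal(16), k_s, teacher, k_pos=4))
```

As a sanity check, when the student's query reproduces the teacher's logits exactly, the KL term is zero.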

At inference, we monkey-patch each trained layer's attention forward to:

  1. Compute q_search, k_search from the same hidden state.
  2. Build a FAISS HNSW index over k_search with default parameters (one index per segment when a 4D block-causal mask is present).
  3. Retrieve top-K_retrieve positions (causal-respecting) per query.
  4. Run standard attention restricted to those K_retrieve keys.

The base model's parameters are never touched. The recommended d128 pilot trains 3.93M parameters total.
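Here is a minimal NumPy sketch of steps 3 and 4 for a single head and a single segment. Exact top-K over the search space stands in for the HNSW query, which returns the same kind of index list, possibly with misses. Early positions with fewer than K causal keys attend over the keys that exist, which is equivalent to masking filler slots. This is an illustration, not inference.py. When K is at least the sequence length, it reduces to dense causal attention.

```python
import numpy as np

def sparse_attention(Q, K, V, q_s, k_s, k_retrieve):
    """Causal attention restricted to each query's top-k_retrieve keys in search space.

    Q, K: (n, d_head) native post-RoPE vectors; V: (n, d_v); q_s, k_s: (n, d_search).
    """
    n, d_head = Q.shape
    out = np.zeros((n, V.shape[1]))
    search_scores = q_s @ k_s.T
    for t in range(n):
        # Step 3: retrieve among causal keys 0..t only (stand-in for the HNSW query).
        idx = np.argsort(-search_scores[t, : t + 1])[:k_retrieve]
        # Step 4: exact softmax attention over the retrieved keys with native Q/K.
        logits = K[idx] @ Q[t] / np.sqrt(d_head)
        w = np.exp(logits - logits.max())
        out[t] = (w / w.sum()) @ V[idx]
    return out

def dense_causal_attention(Q, K, V):
    n, d_head = Q.shape
    logits = Q @ K.T / np.sqrt(d_head)
    logits = np.where(np.tril(np.ones((n, n), dtype=bool)), logits, -np.inf)
    w = np.exp(logits - logits.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
n = 10
Q, K, V = (rng.standard_normal((n, 8)) for _ in range(3))
q_s, k_s = rng.standard_normal((n, 4)), rng.standard_normal((n, 4))
print(np.allclose(sparse_attention(Q, K, V, q_s, k_s, k_retrieve=n),
                  dense_causal_attention(Q, K, V)))  # True
```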

Repo layout

config.py        Run config (pilot defaults; make_headline_config() for follow-up)
model.py         SearchProjection, FrozenForwardCapture (with QK reconstruction
                 trick: capture (Q, K) post-RoPE while the forward stays in FA),
                 contrastive + KL distillation losses
data.py          Long-context dataloader (packing off by default to avoid
                 cross-segment attention leakage; pin_memory, prefetch)
inference.py     ANN-substituted attention (exact top-K for analysis;
                 CPU-FAISS HNSW prototype path — not a deployable kernel)
eval.py          recall@K curve, mass@K curve, full-vs-ANN PPL,
                 MoE router stability
train.py         Training loop, Liger setup, FA-3→FA-2→SDPA→eager fallback,
                 base-model freeze + drift check, auto-resume from latest ckpt
tests/           QK reconstruction verification + 50-step smoke test

Quick start

pip install -r requirements.txt
export WANDB_API_KEY=<key>      # set in your shell only; never check it in
export HF_TOKEN=<token>         # for faster Hub downloads

# Pre-launch checks
python -c "from transformers import AutoConfig; \
print(AutoConfig.from_pretrained('Qwen/Qwen3-4B-Instruct-2507'))"
python tests/test_qk_reconstruction.py
python tests/smoke_test.py

# Packed d_search ablation
bash scripts/run_packed_ablation.sh

# Default clean pilot (packing off; data-sparse on WikiText articles)
python train.py --config pilot_d64_clean

Configuration

The default Config is the 1-day pilot:

| Knob | Pilot | Broad clean runs |
|---|---|---|
| seq_len | 4096 | 8192 |
| batch_size | 8 | 2 with grad accumulation 4 |
| total_steps | 1000 | 1000 for current all36/all32 pilots |
| layers trained | 6 ([4,8,12,16,20,24]) | all36 or all32 ([3..34]) |
| trainable params | 1.97M at d64; 3.93M at d128 | 23.59M all36 d128; ~20.97M all32 d128 |
| d_search | 64 default; d128 recommended from ablation | d128 |
| K_retrieve_eval | 128 | 128 |

Pilot is the proof-of-concept. The broad clean runs test whether the technique scales from a six-layer subset to dense application across most or all layers.

Use make_pilot_d128_packed_config() to reproduce the current recommended packed capacity pilot, pilot_d128_block for the clean six-layer result, all36_d128_block for all-layer substitution, or all32_d128_block for the data-driven reserved-edge broad-substitution run.

Performance choices

  • attn_implementation resolves at load time as flash_attention_3 → flash_attention_2 → sdpa → eager. On B200 with no flash-attn package installed, resolution falls through to SDPA, whose built-in flash backend reaches ~80-90% of FA-2's throughput with zero build dependency.
  • Liger kernels applied via apply_liger_kernel_to_qwen3 (RMSNorm, SwiGLU, RoPE fused — typically 30-50% faster forward).
  • The QK-reconstruction trick keeps SDPA/FA fast on the trained layers: we monkey-patch them to capture (Q, K) post-RoPE, then reconstruct softmax(QK^T/√d) ourselves after the forward returns. The forward never sets output_attentions=True (which would force eager).
  • torch.compile(search_module, mode="max-autotune") on the search projections; the base model stays uncompiled (compiling it works but is flaky for novel architectures).
  • bf16 throughout; loss math cast to fp32 for numerical stability of softmax.

Verifying the QK reconstruction

The post-RoPE Q/K capture must match what the model's eager attention computes, or the distillation supervision is wrong. The test asserts top-32 agreement > 99% per layer:

python tests/test_qk_reconstruction.py --model Qwen/Qwen3-4B-Instruct-2507
# layer 0: PASS  max|Δ|=2.54e-02  top-32 agree=0.9963
# layer 1: PASS  max|Δ|=5.27e-02  top-32 agree=0.9941
# ...
# QK reconstruction verified.

The bf16 max-abs differences (~0.05) are just numerical noise; the ranking of attention positions matches.
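The agreement number measures how much the top-k key sets of the reference and reconstructed attention overlap. The following minimal sketch computes such a check for one head. It assumes "agreement" means the mean fraction of shared top-32 key indices per query, counted only over queries with at least 32 causal keys. The definition in tests/test_qk_reconstruction.py may differ. Pre-softmax logits rank keys exactly as the softmax does, so the sketch works on logits directly.

```python
import numpy as np

def topk_agreement(attn_ref: np.ndarray, attn_rec: np.ndarray, k: int = 32) -> float:
    """Mean fraction of shared top-k key indices per query row (rows with >= k causal keys)."""
    n = attn_ref.shape[0]
    scores = []
    for t in range(k - 1, n):  # query t has t + 1 causal keys
        ref = set(np.argsort(-attn_ref[t, : t + 1])[:k].tolist())
        rec = set(np.argsort(-attn_rec[t, : t + 1])[:k].tolist())
        scores.append(len(ref & rec) / k)
    return float(np.mean(scores))

rng = np.random.default_rng(0)
logits = rng.standard_normal((128, 128)) * 3
noisy = logits + rng.normal(0.0, 0.02, logits.shape)  # small numerical perturbation
print(f"top-32 agree = {topk_agreement(logits, noisy):.4f}")
```

Small perturbations only swap keys that sit near the top-32 boundary, so agreement stays close to 1 even when max|Δ| is nonzero.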

Reproducing the pilot

git clone git@github.com:unixsysdev/ann-sparseattention.git
cd ann-sparseattention
pip install -r requirements.txt
python train.py --config pilot_d128_packed

Hardware: a single H100, H200, or B200 GPU, with about 8 GB of GPU memory for the 4B model weights plus ~10 GB for activations at 4K context, batch 8.

License

MIT.

About

Make Attention Sub-Quadratic Again
