Train tiny per-layer "search projections" on a frozen LLM that replicate the attention's top-K preferences in a low-dimensional space, so we can swap dense quadratic attention for an off-the-shelf ANN index (FAISS HNSW) at inference and lose almost no model quality.
Research prototype. The trained projections preserve quality in a clean
6-layer block-causal WikiText-103 pilot on
Qwen/Qwen3-4B-Instruct-2507. Broad substitution now has two clean datapoints:
substituting all 36 layers is feasible but not full-attention parity, while a
data-driven 32-layer run that reserves edge layers [0, 1, 2, 35] substantially
reduces the quality cost. The runtime is still a correctness prototype. Treat
reported numbers as preliminary until confidence intervals, downstream
long-context tasks, and production baselines are run.
Extremely preliminary llama.cpp runtime branch. A separate runtime-only
branch contains the current llama.cpp integration snapshot and raw CPU/ROCm
smoke-test logs:
feat/llama-ann-runtime.
These numbers are not publication-strength and should not be treated as
production performance. They are early engineering diagnostics: exact learned
top-K decode shows an initial long-context speed signal at 33.6K tokens, while
the HNSW path is only a correctness bridge that rebuilds indices during decode
and is therefore intentionally slow. The branch also includes strict password
recall probes showing that 6-layer HNSW recalls 1K/2K/4K passwords, while
32-layer and 36-layer HNSW are not yet reliable on that exact-recall test.
Checkpoint artifacts and JSON eval outputs are mirrored on Hugging Face:
datasysdev/ann-sparseattention.
Use checkpoints_block_d128/search_step_1000.pt there for the current clean
6-layer block-causal result. Use
checkpoints_all36_d128_block/protected/search_step_500_keep.pt for the best
all-36 checkpoint observed so far. Use
checkpoints_all32_d128_block_reserve_0_1_2_35/search_step_1000.pt for the
current broad-substitution checkpoint.
What's validated:
- 6-layer packed pilot on Qwen3-4B-Instruct-2507, layers [4, 8, 12, 16, 20, 24], 4K context, 1K training steps. d_search=128 is the current recommended capacity from the packed capacity ablation: 3.93M trainable parameters, mass@K=128 of 0.503 vs 0.488 for the raw-QK exact-topK oracle, and -1.81% relative PPL gap at K=128 on the packed eval slice.
- Block-causal packed masking is implemented. On the clean block-causal d128 rerun, exact sparse attention is near parity with full attention (K=128: +0.07% PPL gap; K=256: +0.01%). The large negative PPL gaps from packed-with-leakage do not survive as a clean-methodology headline.
- Capacity scaling is monotonic but saturating: d64 < d128 < d256 on mass@K, while d128 and d256 are effectively tied on final PPL.
- Clean Quest-style and paired NLL baselines are implemented. Learned retrieval captures more teacher mass than Quest at equal K, but Quest is slightly better at K=128 PPL and statistically tied at K=256.
- Clean FAISS/HNSW per-segment indexing matches exact learned retrieval closely at K=128/K=256, validating the ANN-compatibility claim.
- All-36 d128 block-causal substitution was run to step 800. Best eval: step 500, recall@K=0.816, PPL gap +3.23%. Step 750 regressed to +3.96%. Per-layer mass@K shows the weakest layers are L00/L01/L02, with L35 only mildly weak.
- The all32_d128_block follow-up reserves full attention on [0, 1, 2, 35] and trains layers 3..34. It converged by step 500 and finished step 1000 at recall@K=0.825 and +1.746% PPL gap in training eval. A post-hoc exact K-sweep on the final checkpoint gives +0.590% PPL gap at K=128 and -0.062% at K=256 on a 2-batch clean block-causal slice.
Not yet validated (next iteration):
- Confidence intervals for the block-causal result over multiple seeds and larger eval slices.
- Long-context task quality (LongBench, RULER, needle-in-haystack).
- A controlled coverage Pareto sweep, especially 12-, 18-, and 20-layer configurations, to locate the zero-gap broad-substitution point.
- Wall-clock speedup vs. FlashAttention/SDPA — not measured.
- KV-cache decode-mode integration.
- GPU-resident ANN or fused gather-attention kernel.
Runtime caveat. The current FAISS path is a correctness prototype: it builds a CPU index per forward pass and uses dense-style tensor expansion internally for the gather step. The compute-reduction numbers below are algorithmic scoring reductions, not measured wall-clock speedups. A production runtime requires a GPU-resident topk kernel or integration with paged/block-sparse attention kernels.
The packed ablation trains the same 6 layers for 1K steps and evaluates all
variants with the same packed eval pipeline. raw_qk is exact top-K over
head-mean-aggregated native post-RoPE Q/K vectors; learned is exact top-K
over trained search projections. mass@K is teacher-attention probability
captured by the retrieved set.
| d_search | Params | learned mass@K=128 | raw-QK oracle | learned / oracle | Final PPL gap |
|---|---|---|---|---|---|
| 64 | 1.97M | 0.492 | 0.488 | 1.01x | +2.39% |
| 128 | 3.93M | 0.503 | 0.488 | 1.03x | -1.81% |
| 256 | 7.86M | 0.509 | 0.488 | 1.04x | -1.85% |
d128 is the recommended default for this pilot: it captures almost all of the d256 quality with half the trainable parameters. d256 improves mass@K slightly but does not materially improve final PPL.
PPL gap is the primary model-quality signal; mass@K is the more direct retrieval-quality signal when teacher attention is sharp. Recall@K is logged, but it is a weaker proxy because disagreement on near-zero-probability tail positions can look like low recall while preserving model output.
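The difference between the two retrieval metrics can be made concrete with a small NumPy sketch. The function names `mass_at_k` and `recall_at_k` and the toy teacher distribution are illustrative assumptions, not the API of the repo's `eval.py`.

```python
import numpy as np

def mass_at_k(teacher_probs, retrieved_idx):
    """Teacher-attention probability captured by the retrieved keys, per query."""
    return np.take_along_axis(teacher_probs, retrieved_idx, axis=-1).sum(axis=-1)

def recall_at_k(teacher_probs, retrieved_idx):
    """Fraction of the teacher's own top-K keys that appear in the retrieved set."""
    k = retrieved_idx.shape[-1]
    teacher_topk = np.argsort(-teacher_probs, axis=-1)[:, :k]
    hits = [len(set(t) & set(r))
            for t, r in zip(teacher_topk.tolist(), retrieved_idx.tolist())]
    return np.array(hits) / k

# One query with a sharp teacher distribution over 8 keys (hypothetical values).
teacher = np.array([[0.90, 0.05, 0.02, 0.012, 0.008, 0.005, 0.003, 0.002]])
# The retrieved set keeps the dominant key but disagrees on the low-probability tail.
retrieved = np.array([[0, 5, 6, 7]])

mass = mass_at_k(teacher, retrieved)      # 0.90 + 0.005 + 0.003 + 0.002
recall = recall_at_k(teacher, retrieved)  # only key 0 is in the teacher's top-4
```

In this toy case the retrieved set captures 0.91 of the teacher mass while recall is only 0.25, which is the situation the paragraph above describes.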
Per-layer mass@K=128 for d128:
| Layer | raw-QK oracle | learned d128 |
|---|---|---|
| 4 | 0.422 | 0.382 |
| 8 | 0.518 | 0.421 |
| 12 | 0.404 | 0.533 |
| 16 | 0.475 | 0.481 |
| 20 | 0.499 | 0.551 |
| 24 | 0.614 | 0.648 |
Early layers remain harder for learned retrieval; mid/late trained layers exceed raw-QK oracle mass.
Exact top-K sweep for the recommended packed d128 checkpoint:
python k_sweep.py \
--ckpt /tmp/checkpoints_packed_d128/search_step_1000.pt \
--K 128,256,512 \
--no-use-faiss

PPL_full = 224.64 on this packed eval slice.
| K | Recall@K | mass@K | PPL_ANN | PPL gap |
|---|---|---|---|---|
| 128 | 0.166 | 0.256 | 203.63 | -9.36% |
| 256 | 0.233 | 0.318 | 207.06 | -7.83% |
| 512 | 0.339 | 0.409 | 211.93 | -5.66% |
This disambiguates the earlier FAISS high-K failure on the leaked packed pipeline: exact retrieval remains strongly negative at K=256/512, so the denoising pattern is present on this packed eval slice. This should not be used as a publication-strength denoising claim because packed examples can attend across document boundaries.
A second exact sweep on the next 16 packed eval batches (--skip-batches 16)
preserved the shape: K=128 -8.78%, K=256 -7.59%, K=512 -6.21%. This is still
not a substitute for confidence intervals, but it reduces the chance that the
large negative gap is a single-slice accident.
Packed block-causal masking assigns each packed document a segment_id, resets
position_ids at segment boundaries, and supplies a 4D additive mask so tokens
can only attend causally within their own document. Retrieval, loss masking,
mass@K, and recall@K use the same segment-causal eligibility mask.
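A minimal sketch of how such inputs could be derived from per-token segment IDs is shown below. The helper name `block_causal_inputs` is hypothetical, and the sketch assumes each packed document has a unique segment ID; the repo's actual implementation may differ.

```python
import numpy as np

def block_causal_inputs(segment_ids):
    """Build position_ids, a boolean eligibility mask, and a 4D additive mask."""
    seg = np.asarray(segment_ids)
    T = seg.shape[0]
    idx = np.arange(T)
    # A token starts a segment if it is first or its segment ID changes.
    is_start = np.concatenate(([True], seg[1:] != seg[:-1]))
    seg_start = np.maximum.accumulate(np.where(is_start, idx, 0))
    position_ids = idx - seg_start            # positions reset at each boundary
    same_segment = seg[:, None] == seg[None, :]
    causal = idx[:, None] >= idx[None, :]
    eligible = same_segment & causal          # [T, T], query rows x key columns
    additive_mask = np.where(eligible, 0.0, -np.inf)[None, None, :, :]
    return position_ids, eligible, additive_mask

position_ids, eligible, mask = block_causal_inputs([0, 0, 0, 1, 1])
```

The same boolean `eligible` matrix can then gate retrieval candidates and metric computation, so that all of them see identical segment-causal eligibility.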
Clean d128 block-causal run:
python train.py --config pilot_d128_block
python k_sweep.py \
--ckpt /tmp/checkpoints_block_d128/search_step_1000.pt \
--K 128,256,512 \
--no-use-faiss

PPL_full = 30.44 on the 16-batch clean eval slice.
| K | Recall@K | mass@K | PPL_ANN | PPL gap |
|---|---|---|---|---|
| 128 | 0.744 | 0.787 | 30.47 | +0.07% |
| 256 | 0.879 | 0.953 | 30.45 | +0.01% |
| 512 | n/a | n/a | 30.45 | +0.01% |
K=512 has no meaningful mass/recall average on this WikiText slice because almost no same-segment queries have 512 valid causal keys. The quality result is still useful: with filler slots masked out of the sparse-attention softmax, the block-causal exact path is effectively at full-attention parity. The clean result supports "quality-preserving sparse substitution" rather than the leaked pipeline's stronger denoising claim.
Clean block-causal per-layer compare_retrieval at K=128:
| Layer | raw-QK oracle mass | learned d128 mass |
|---|---|---|
| 4 | 0.956 | 0.950 |
| 8 | 0.977 | 0.976 |
| 12 | 0.970 | 0.977 |
| 16 | 0.964 | 0.970 |
| 20 | 0.970 | 0.983 |
| 24 | 0.978 | 0.984 |
| avg | 0.969 | 0.973 |
This changes the per-layer interpretation from the leakage-confounded pilot: with segment isolation, early trained layers are not diffuse or uniquely hard. All six trained layers have high oracle mass, and learned projections match or slightly exceed raw-QK retrieval across the set. The deployment hypothesis for the next run is therefore "substitute all tested layers" rather than "keep early layers as full attention," pending a broader all-layer run.
quest_sweep.py implements a Quest-style min/max page selector for comparison:
page size 16, native post-RoPE Q/K, same block-causal token eligibility mask,
and the same sparse-attention gather path. This is a correctness baseline, not
an optimized Quest runtime.
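The min/max page heuristic described above can be sketched as follows. This is an illustration of the general Quest-style bound, not the code in `quest_sweep.py`: the function names are hypothetical, and causal and segment masking as well as partial trailing pages are ignored for brevity.

```python
import numpy as np

def quest_page_bounds(q, keys, page_size=16):
    """Upper bound on q.k for every key in each page, from per-page min/max."""
    n_pages = keys.shape[0] // page_size
    pages = keys[: n_pages * page_size].reshape(n_pages, page_size, -1)
    k_min, k_max = pages.min(axis=1), pages.max(axis=1)    # [n_pages, d]
    # Per dimension, the largest possible q_d * k_d is attained at min or max.
    return np.maximum(q * k_min, q * k_max).sum(axis=-1)   # [n_pages]

def select_tokens(q, keys, budget, page_size=16):
    """Pick the highest-bound pages until the token budget is filled."""
    bounds = quest_page_bounds(q, keys, page_size)
    n_select = max(1, budget // page_size)
    top_pages = np.sort(np.argsort(-bounds)[:n_select])
    return np.concatenate(
        [np.arange(p * page_size, (p + 1) * page_size) for p in top_pages])

rng = np.random.default_rng(0)
keys = rng.standard_normal((64, 8))   # 4 pages of 16 keys (toy sizes)
q = rng.standard_normal(8)
bounds = quest_page_bounds(q, keys)
selected = select_tokens(q, keys, budget=32)
```

Scoring touches two vectors per page rather than one per key, which is where the `(N / page_size) * 2 * d_head` cost of page scoring comes from.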
python quest_sweep.py \
--ckpt /tmp/checkpoints_block_d128/search_step_1000.pt \
--K 128,256,512 \
--page-size 16

On the same 16-batch block-causal eval slice:
| Method | K | Recall@K | mass@K | PPL | PPL gap |
|---|---|---|---|---|---|
| learned search exact | 128 | 0.744 | 0.787 | 30.47 | +0.07% |
| Quest-style page | 128 | 0.669 | 0.727 | 30.41 | -0.11% |
| learned search exact | 256 | 0.879 | 0.953 | 30.45 | +0.01% |
| Quest-style page | 256 | 0.838 | 0.909 | 30.45 | +0.03% |
Both methods are effectively full-attention parity on PPL. The learned search space recovers more teacher attention mass at the same token budget, especially at K=128, while Quest remains a strong non-trained heuristic baseline. This keeps the contribution narrow: learned projections improve retrieval fidelity and support standard ANN indexing; they do not yet show a clean PPL advantage over Quest on this slice.
Paired 32-batch NLL evaluation gives a sharper comparison:
| K | full PPL | learned PPL | Quest PPL | learned - Quest NLL delta (95% bootstrap CI) | Read |
|---|---|---|---|---|---|
| 128 | 28.03 | 28.07 | 28.01 | +0.00205 [+0.00160, +0.00251] | Quest slightly better |
| 256 | 28.03 | 28.04 | 28.04 | -0.00005 [-0.00029, +0.00018] | statistical tie |
So the current clean result is: learned search has higher teacher-attention mass, but PPL is either tied with Quest (K=256) or slightly worse (K=128) on this paired WikiText slice. The paper claim should be "retrieval-fidelity and ANN-compatibility advantages," not "PPL advantage over Quest."
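For readers unfamiliar with paired bootstrap intervals, a minimal sketch follows. It assumes the resampling unit is a single paired NLL observation and uses a percentile interval; the repo's evaluation may resample at a different granularity (for example per sequence), and the synthetic data here is purely illustrative.

```python
import numpy as np

def paired_bootstrap_ci(nll_a, nll_b, n_boot=1000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for the mean of paired differences nll_a - nll_b."""
    delta = np.asarray(nll_a) - np.asarray(nll_b)
    rng = np.random.default_rng(seed)
    n = delta.shape[0]
    boot_means = np.empty(n_boot)
    for b in range(n_boot):
        # Resample pairs with replacement; pairing is kept because we resample deltas.
        boot_means[b] = delta[rng.integers(0, n, size=n)].mean()
    lo, hi = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return delta.mean(), lo, hi

# Synthetic paired NLLs (hypothetical values, not the repo's eval outputs).
rng = np.random.default_rng(1)
nll_quest = rng.normal(3.33, 0.5, size=5000)
nll_learned = nll_quest + rng.normal(0.002, 0.01, size=5000)
mean_delta, lo, hi = paired_bootstrap_ci(nll_learned, nll_quest)
```

Because both methods are scored on the same tokens, the shared per-token variance cancels in the difference, which is why a delta as small as +0.002 NLL can still be resolved.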
The first block-causal FAISS prototype used one global index followed by segment filtering, which produced pathological filler rates after filtering. The current FAISS path builds per-segment indexes when a 4D block-causal mask is present. With that fix, CPU FAISS/HNSW tracks exact learned search on the same 16-batch clean eval slice:
| Method | K | PPL | PPL gap | FAISS filler rate |
|---|---|---|---|---|
| learned exact | 128 | 30.47 | +0.07% | n/a |
| learned FAISS/HNSW | 128 | 30.47 | +0.09% | 0.447 |
| learned exact | 256 | 30.45 | +0.01% | n/a |
| learned FAISS/HNSW | 256 | 30.46 | +0.04% | 0.683 |
The remaining filler rate is expected for short same-segment prefixes where fewer than K valid causal keys exist; filler slots are masked out of the sparse attention softmax. This demonstrates off-the-shelf ANN compatibility in the clean block-causal setting, but not production wall-clock speedup.
artifacts/scaling_analysis.md gives a deterministic operation-count proxy
for the per-query candidate scoring step. This is the cost of identifying
which keys to attend to, before the sparse attention softmax and value
multiply over the selected keys.
Assumptions:
- Full attention scoring: N * d_head = N * 128.
- Quest-style page scoring: (N / page_size) * 2 * d_head = N * 16 with page_size=16.
- Learned HNSW scoring: M * ef_search * log2(N) * d_search with M=32, ef_search=64, and d_search=128.
| Context | Full ops/query | Quest ops/query | Learned HNSW ops/query | Quest / learned |
|---|---|---|---|---|
| 4K | 512,000 | 64,000 | 3,136,759 | 0.02x |
| 8K | 1,024,000 | 128,000 | 3,398,903 | 0.04x |
| 16K | 2,048,000 | 256,000 | 3,661,047 | 0.07x |
| 32K | 4,096,000 | 512,000 | 3,923,191 | 0.13x |
| 64K | 8,192,000 | 1,024,000 | 4,185,335 | 0.24x |
| 128K | 16,384,000 | 2,048,000 | 4,447,479 | 0.46x |
| 256K | 32,768,000 | 4,096,000 | 4,709,623 | 0.87x |
| 512K | 65,536,000 | 8,192,000 | 4,971,767 | 1.65x |
| 1M | 128,000,000 | 16,000,000 | 5,224,942 | 3.06x |
| 2M | 256,000,000 | 32,000,000 | 5,487,086 | 5.83x |
| 4M | 512,000,000 | 64,000,000 | 5,749,230 | 11.13x |
Under these conservative HNSW constants, Quest is cheaper below the few-hundred-thousand-token regime and learned-projection scoring becomes cheaper beyond roughly 300K tokens. At 1M context, the operation-count proxy is about 3x in favor of learned projections. This supports the theoretical scaling claim only; production speed claims still require GPU-resident retrieval and KV-cache/decode integration.
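The proxy is simple enough to recompute directly. The sketch below applies the three stated formulas. The reported values are consistent with "4K" meaning N = 4,000 (decimal thousands), which is an inference from the table rather than something stated in `artifacts/scaling_analysis.md`. The `crossover` helper is a hypothetical addition that bisects for the break-even context length.

```python
import math

D_HEAD, PAGE_SIZE = 128, 16
HNSW_M, EF_SEARCH, D_SEARCH = 32, 64, 128

def full_ops(n):
    return n * D_HEAD

def quest_ops(n):
    return (n / PAGE_SIZE) * 2 * D_HEAD

def hnsw_ops(n):
    return HNSW_M * EF_SEARCH * math.log2(n) * D_SEARCH

def crossover(lo=1_000, hi=4_000_000):
    """Bisect for the context length where Quest stops being cheaper than HNSW."""
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if quest_ops(mid) < hnsw_ops(mid):
            lo = mid   # Quest still cheaper: break-even lies further out
        else:
            hi = mid
    return hi

for n in (4_000, 512_000, 1_000_000):
    ratio = quest_ops(n) / hnsw_ops(n)
    print(n, full_ops(n), quest_ops(n), round(hnsw_ops(n)), round(ratio, 2))
print("break-even context length:", crossover())
```

The break-even point depends heavily on the assumed `M` and `ef_search`; smaller values would move it to shorter contexts.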
The current ANN wrapper is prefill-only (use_cache=False), so a true
generation-time dynamic-index benchmark still requires cache integration. As a
first capability proxy, dynamic_index_proxy.py splits clean block-causal eval
sequences into a prefill prefix and decode-like suffix, then compares learned
retrieval mass under two masks:
- dynamic index: suffix queries can retrieve from all same-segment prior keys;
- static index: suffix queries can retrieve from prefill keys plus a 256-token recent local suffix window, but not older suffix keys.
On the clean d128 block-causal checkpoint, using K=128, prefill length 1024, local window 256, and 8 eval batches:
| Setting | Teacher mass captured |
|---|---|
| Dynamic proxy | 0.972 |
| Static proxy | 0.928 |
| Static teacher mass available | 0.954 |
| Dynamic - static | +0.044 |
Per-layer dynamic-minus-static mass ranges from +0.022 (L04) to +0.058 (L08).
This does not establish task accuracy or decode latency, but it shows that a
frozen prefill-plus-local index loses measurable teacher-attention mass on
decode-like suffix queries. The raw result is in
artifacts/dynamic_proxy_8b.json.
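The two eligibility masks can be sketched as below. This is an illustration rather than the code in `dynamic_index_proxy.py`: it assumes a single segment, assumes the local window includes the query position itself, and uses a uniform toy teacher distribution.

```python
import numpy as np

def proxy_masks(seq_len, prefill_len, local_window):
    """Key-eligibility masks for decode-like suffix queries (rows)."""
    i = np.arange(seq_len)[:, None]   # query positions
    j = np.arange(seq_len)[None, :]   # key positions
    causal = j <= i
    dynamic = causal                                            # all prior keys
    static = causal & ((j < prefill_len) | (i - j < local_window))
    suffix = np.arange(seq_len) >= prefill_len
    return dynamic[suffix], static[suffix]

def available_mass(teacher_probs, mask):
    """Mean teacher-attention mass that falls on eligible keys."""
    return (teacher_probs * mask).sum(axis=-1).mean()

# Toy sizes: 8 tokens, 4-token prefill, 2-token local window.
dynamic, static = proxy_masks(seq_len=8, prefill_len=4, local_window=2)
teacher = dynamic / dynamic.sum(axis=-1, keepdims=True)   # uniform causal teacher
mass_dynamic = available_mass(teacher, dynamic)
mass_static = available_mass(teacher, static)
```

Suffix keys that are older than the local window but newer than the prefill are the ones a frozen index cannot reach; the gap between the two masses measures how much the teacher cares about them.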
The first broad substitution run trained d_search=128 projections for all 36
attention layers under the clean block-causal pipeline:
python train.py --config all36_d128_block

It is feasible, but it is not yet the quality-preserving headline. Training was stopped after step 800 to inspect checkpoints and start a better targeted follow-up.
| Step | Recall@K eval | PPL gap | Read |
|---|---|---|---|
| 250 | 0.805 | +6.27% | ranking still poor |
| 500 | 0.816 | +3.23% | best checkpoint so far |
| 750 | 0.817 | +3.96% | PPL regressed despite stable recall |
The protected step-500 checkpoint is mirrored as:
checkpoints_all36_d128_block/protected/search_step_500_keep.pt
Per-layer compare_retrieval at K=128 shows high average retrieval fidelity
but clear early-layer weakness:
| Layer | raw-QK | learned | delta |
|---|---|---|---|
| L00 | 0.922 | 0.780 | -0.142 |
| L01 | 0.918 | 0.851 | -0.067 |
| L02 | 0.939 | 0.899 | -0.040 |
| L03 | 0.939 | 0.924 | -0.015 |
| L04 | 0.944 | 0.933 | -0.011 |
| L05 | 0.964 | 0.947 | -0.017 |
| L06 | 0.956 | 0.936 | -0.020 |
| L07 | 0.982 | 0.982 | +0.000 |
| L08 | 0.971 | 0.970 | -0.001 |
| L09 | 0.959 | 0.976 | +0.017 |
| L10 | 0.974 | 0.970 | -0.004 |
| L11 | 0.976 | 0.975 | -0.001 |
| L12 | 0.961 | 0.969 | +0.008 |
| L13 | 0.971 | 0.971 | +0.000 |
| L14 | 0.973 | 0.973 | -0.000 |
| L15 | 0.968 | 0.972 | +0.004 |
| L16 | 0.956 | 0.962 | +0.006 |
| L17 | 0.959 | 0.966 | +0.007 |
| L18 | 0.965 | 0.972 | +0.007 |
| L19 | 0.961 | 0.968 | +0.007 |
| L20 | 0.959 | 0.975 | +0.016 |
| L21 | 0.966 | 0.979 | +0.014 |
| L22 | 0.963 | 0.970 | +0.007 |
| L23 | 0.979 | 0.984 | +0.005 |
| L24 | 0.971 | 0.978 | +0.007 |
| L25 | 0.986 | 0.988 | +0.002 |
| L26 | 0.978 | 0.985 | +0.008 |
| L27 | 0.978 | 0.983 | +0.005 |
| L28 | 0.979 | 0.985 | +0.005 |
| L29 | 0.982 | 0.987 | +0.005 |
| L30 | 0.988 | 0.986 | -0.002 |
| L31 | 0.984 | 0.984 | -0.001 |
| L32 | 0.979 | 0.979 | +0.000 |
| L33 | 0.977 | 0.970 | -0.007 |
| L34 | 0.976 | 0.960 | -0.016 |
| L35 | 0.980 | 0.967 | -0.013 |
| avg | 0.966 | 0.960 | -0.006 |
This motivated the reserved-edge follow-up run:
python train.py --config all32_d128_block

all32_d128_block reserves full attention on [0, 1, 2, 35] and trains
layers 3..34. This tests whether broad substitution fails mainly because of
the weak edge layers, or because small approximation errors compound across
many otherwise-good layers. The run finished cleanly at step 1000 with 20.97M
trained search-projection parameters:
| Step | Recall@K eval | PPL gap | Read |
|---|---|---|---|
| 250 | 0.812 | +2.283% | already better than all36 best training eval |
| 500 | 0.823 | +1.753% | converged to the final quality band |
| 750 | 0.825 | +1.943% | small eval fluctuation |
| 1000 | 0.825 | +1.746% | final checkpoint; essentially tied with step 500 |
The all32 result is the current broad-substitution headline. It is not
full-attention parity at K=128 in training eval, but it cuts the all36 quality
cost roughly in half while still substituting 32 of 36 layers. Post-hoc
compare_retrieval on the final checkpoint shows the reserved-edge hypothesis
did what it was supposed to do: on the substituted layers, learned retrieval
matches raw-QK retrieval mass (K=128 learned 0.971 vs raw 0.969; K=256 learned
0.993 vs raw 0.994). The remaining PPL cost is therefore more likely compound
approximation error than a single bad substituted layer.
Exact K-sweep on the final all32 checkpoint, 2-batch clean block-causal slice
(PPL_full = 20.5349):
| K | mass@K | Recall@K | sparse PPL | PPL gap |
|---|---|---|---|---|
| 16 | 0.546 | 0.518 | 24.86 | +21.064% |
| 32 | 0.627 | 0.572 | 21.85 | +6.422% |
| 64 | 0.722 | 0.652 | 20.94 | +1.974% |
| 128 | 0.807 | 0.746 | 20.66 | +0.590% |
| 256 | 0.902 | 0.876 | 20.52 | -0.062% |
K=512 is intentionally omitted from this table. The current script produced a valid sparse-attention PPL line for K=512 but zero mass/recall, which is an edge-case bug in the metric path when K exceeds the number of valid causal keys for most same-segment queries. It should be rerun after fixing the metric handling; the publishable sweep for now is K <= 256.
The emerging coverage picture is more useful than a single number:
| Configuration | Layers substituted | Coverage | PPL gap | Read |
|---|---|---|---|---|
| 6-layer clean pilot | 6/36 | 17% | +0.07% at K=128 | quality-preserving pilot |
| all32 reserved-edge | 32/36 | 89% | +1.746% train eval; +0.590% exact sweep | near-parity broad substitution |
| all36 | 36/36 | 100% | +3.23% best observed | full substitution costs quality |
This suggests layer coverage is itself a Pareto knob. The current data is not enough to claim an optimal coverage ratio, but it strongly suggests the best deployment point is intermediate rather than "sparsify everything." A 12/18/20 layer coverage sweep is the next clean experiment.
L = 4096. Compute reduction refers to the attention scoring step only and is ≈ L / K.
These are FLOP estimates, not measured wall-clock times: the FAISS path in this
repo is a research prototype that does CPU index builds and GPU↔CPU
transfers, so it is not the right thing to time. A GPU-resident topk
kernel is the natural next step.
| K | PPL gap | Attention scoring reduction |
|---|---|---|
| 512 | -5.66% (exact top-K over learned search space) | ~8x |
| 256 | -7.83% (exact top-K over learned search space) | ~16x |
| 128 | -9.36% exact; -1.81% FAISS/training eval | ~32x |
| 64 | +0.46% | ~64x |
| 32 | +0.03% | ~128x |
| 16 | +5.63% | ~256x |
Eval scope for the d_search table: 16 packed validation batches at 4K context
for PPL/recall during training, and 12 packed batches for compare_retrieval
mass@K. Numbers should be read as "what we observed on this slice", not
population-level estimates.
A few things the pilot does not yet establish, and that the next iteration will:
- Packing: the d_search ablation table is still from the packed leakage-confounded run and is best read as a capacity comparison. The clean block-causal d128 rerun removes cross-document leakage and should be used for quality claims.
- Exact-topK oracle: the obvious follow-up is a four-way Pareto: full attention vs. exact top-K (true QK^T argmax-K, then attention) vs. search-topK (our projections, exact distance) vs. search-ANN (FAISS HNSW). That separates "denoising from any sparsity" from "denoising from learned projections."
- Wall-clock: the compute-reduction table above is FLOP-counted. The FAISS path here is a research prototype (CPU index per forward, GPU↔CPU transfer) and is the wrong thing to time. A GPU-resident topk kernel is the next-step engineering.
- Broad substitution: all-36 is viable but not parity (+3.23% best observed PPL gap). The all32 reserved-edge run reduces the cost to +1.746% in training eval and +0.590% at K=128 in the post-hoc exact sweep, with parity at K=256 on the small sweep.
The recall@K and mass@K reported here come from a 12-batch eval slice, not a population-level estimate. Confidence intervals and downstream tasks (LongBench / RULER / needle-in-haystack) are the natural next evals.
Two broad-layer configs are now wired:
- all36_d128_block: trains all 36 layers, clean block-causal, d128. Best observed checkpoint is step 500 at +3.23% PPL gap.
- all32_d128_block: trains layers 3..34, reserves [0, 1, 2, 35] as full attention. Final step-1000 checkpoint: recall@K=0.825 and +1.746% PPL gap in training eval; exact sweep reaches +0.590% at K=128 and parity at K=256 on a 2-batch clean slice.
Checkpoints are mirrored at
datasysdev/ann-sparseattention.
The closest prior work is RetrievalAttention (Liu et al., 2024). They show that vanilla ANN over the model's native Q and K vectors fails because Q and K live in mismatched distributions: they were never trained to be each other's nearest neighbors, only to score correctly via the dot product. Their fix is at index time: an attention-aware graph construction (RoarGraph-style) that compensates for the Q/K out-of-distribution problem at search time.
This work attacks the same problem from the opposite direction. Instead of
patching the index over hostile vectors, we train a tiny shared
low-dimensional projection (W_Qs, W_Ks → R^128 in the recommended pilot)
so that q_search and k_search do live in the same distribution by construction.
Off-the-shelf FAISS HNSW with default parameters is then sufficient; there is no
attention-aware index trick.
| Method | Search space | Index | Trainable |
|---|---|---|---|
| Raw Q/K + vanilla ANN | original Q/K | off-the-shelf | no — fails (Q/K OOD) |
| RetrievalAttention | original Q/K | attention-aware graph | no |
| This work | learned Q_s / K_s | off-the-shelf | yes (~2-11M params) |
The contribution claim: eliminate the Q/K mismatch at index-build time via distillation, instead of patching it at search time. The clean experiment to validate this — vanilla FAISS over raw Q/K vs. vanilla FAISS over learned Q_s/K_s vs. exact teacher top-K, all at the same K — is the next planned run. The current pilot establishes that the learned projections retrieve attention-relevant keys; the comparison run isolates how much of that came from the projection vs. the ANN approximation.
The paper now frames the method against four kinds of competitors:
- Closest asymptotic ancestor: Reformer. Reformer uses untrained LSH hashing to find candidate keys, then attends within buckets. This work keeps the retrieve-then-attend shape but trains the retrieval space from teacher attention, instead of relying on random hashes over native Q/K.
- Closest practical baseline: Quest. Quest is query-aware and strong at moderate context, but scans page metadata linearly in the number of pages. This work is weaker than Quest at K=128 PPL on the current slice, tied at K=256, but has a different long-context scaling target via ANN retrieval.
- Linear/smooth approximations: Performer. Performer changes the attention computation with random feature approximations to softmax. This work preserves exact softmax over the retrieved set and approximates only candidate selection.
- Fixed sparse patterns: Longformer/BigBird. These are efficient but not fully query-aware in the sense used here; remote keys are not selected because the current query needs them.
Qualitative property table from the paper:
| Method | Selection mechanism | Query-aware | Trained | Asymptotic | Exact softmax |
|---|---|---|---|---|---|
| Full attention | all keys | n/a | n/a | O(N²) | yes |
| Reformer | LSH hashing | yes | no | O(N log N) | over bucket |
| Performer | random features | n/a | no | O(N) | no |
| BigBird | window + random + global | mostly no | no | O(N) | over pattern |
| Longformer | sliding window + global | mostly no | no | O(N) | over pattern |
| NSA-style methods | block compression/selection | partial | partial | O(N²) proxy | yes |
| Quest | min/max page heuristic | yes | no | O(N) | over pages |
| This work | trained low-dim retrieval | yes | yes | O(N log N) | over retrieved set |
This is a design-positioning table, not a completed empirical win. The current clean results prove the row for the six-layer pilot and show that broad substitution becomes usable when weak edge layers are reserved. All-layer quality is not yet parity.
The method also targets a different deployment scenario than native sliding-window or state-space/hybrid architectures such as Mistral-style sliding window, Mamba, or Qwen3.6 Gated DeltaNet hybrids. Those approaches are trained from scratch with their sparse or hybrid mechanism in place. This work is post-hoc: train a base model with full attention for maximum expressivity, then add lightweight retrieval projections afterward to make candidate selection sub-linear per query without changing base weights. The intended benefit is decoupling training-time architecture from inference-time architecture.
For each full-attention layer i we train two linear projections
W_Qs^i, W_Ks^i ∈ R^{d_model × d_search} (recommended pilot: d_search=128),
so that for any hidden state h, with

q_search = W_Qs^i h,    k_search = W_Ks^i h,

softmax(q_search · k_search^T) ranks the same keys as softmax(QK^T / √d_head), the teacher's attention.
Two losses, summed across layers:
- InfoNCE with teacher-derived positives (top-K_pos keys from the teacher's attention serve as positives for each query).
- KL(teacher ‖ student) on the full attention distribution.
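The two losses can be sketched in NumPy as follows. This is one plausible formulation, not necessarily the one in `model.py`: the InfoNCE term is written as a multi-positive variant (negative log of the student probability summed over the teacher's top-`k_pos` keys), and temperature, causal masking, and per-layer summation are omitted. All names are hypothetical.

```python
import numpy as np

def log_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def distill_losses(teacher_probs, student_logits, k_pos):
    """Return (multi-positive InfoNCE, KL(teacher || student)), averaged over queries."""
    log_student = log_softmax(student_logits)
    # KL(teacher || student), treating 0 * log 0 as 0.
    safe_log_teacher = np.log(np.where(teacher_probs > 0, teacher_probs, 1.0))
    kl = (teacher_probs * (safe_log_teacher - log_student)).sum(axis=-1).mean()
    # Positives: the teacher's top-k_pos keys for each query.
    pos_idx = np.argsort(-teacher_probs, axis=-1)[:, :k_pos]
    pos_logp = np.take_along_axis(log_student, pos_idx, axis=-1)
    # -log(sum of student probability on positives), via a stable log-sum-exp.
    m = pos_logp.max(axis=-1, keepdims=True)
    info_nce = -(m[:, 0] + np.log(np.exp(pos_logp - m).sum(axis=-1))).mean()
    return info_nce, kl

rng = np.random.default_rng(0)
teacher = np.exp(log_softmax(rng.standard_normal((4, 10))))
nce_match, kl_match = distill_losses(teacher, np.log(teacher), k_pos=3)
nce_rand, kl_rand = distill_losses(teacher, rng.standard_normal((4, 10)), k_pos=3)
nce_all, _ = distill_losses(teacher, rng.standard_normal((4, 10)), k_pos=10)
```

The KL term vanishes when the student reproduces the teacher distribution exactly, while the InfoNCE term only asks that the student put its probability on the teacher's top keys, which is closer to what top-K retrieval needs.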
At inference, we monkey-patch each trained layer's attention forward to:
- Compute q_search, k_search from the same hidden state.
- Build a per-batch FAISS HNSW index over k_search (default params).
- Retrieve top-K_retrieve positions (causal-respecting) per query.
- Run standard attention restricted to those K_retrieve keys.
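The steps above can be sketched for a single head in NumPy. To keep the example dependency-free, exact top-K search in the learned space stands in for the FAISS HNSW index (the repo also exposes an exact path via `--no-use-faiss`). Shapes, names, and the row-vector convention `h @ W` are assumptions made for illustration.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def full_causal_attention(Q, K, V):
    T, d_head = Q.shape
    scores = Q @ K.T / np.sqrt(d_head)
    scores = np.where(np.tril(np.ones((T, T), dtype=bool)), scores, -np.inf)
    return softmax(scores) @ V

def retrieve_then_attend(h, W_qs, W_ks, Q, K, V, k_retrieve):
    T, d_head = Q.shape
    # Step 1: project the same hidden states into the search space.
    q_search, k_search = h @ W_qs, h @ W_ks
    # Steps 2-3: exact causal top-K in the search space (stand-in for HNSW).
    causal = np.tril(np.ones((T, T), dtype=bool))
    search_scores = np.where(causal, q_search @ k_search.T, -np.inf)
    idx = np.argsort(-search_scores, axis=-1)[:, : min(k_retrieve, T)]
    valid = np.take_along_axis(causal, idx, axis=-1)   # False marks filler slots
    # Step 4: standard softmax attention restricted to the retrieved keys.
    scores = np.einsum("td,tkd->tk", Q, K[idx]) / np.sqrt(d_head)
    scores = np.where(valid, scores, -np.inf)          # filler leaves the softmax
    return np.einsum("tk,tkd->td", softmax(scores), V[idx])

rng = np.random.default_rng(0)
T, d_model, d_head, d_search = 12, 32, 16, 8
h = rng.standard_normal((T, d_model))
W_qs, W_ks = rng.standard_normal((2, d_model, d_search))
Q, K, V = rng.standard_normal((3, T, d_head))
out_sparse = retrieve_then_attend(h, W_qs, W_ks, Q, K, V, k_retrieve=4)
out_all = retrieve_then_attend(h, W_qs, W_ks, Q, K, V, k_retrieve=T)
```

When `k_retrieve` covers every causal key, the sparse path reduces to full causal attention, which is a convenient sanity check for any gather implementation.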
The base model's parameters are never touched. The recommended d128 pilot trains 3.93M parameters total.
config.py Run config (pilot defaults; make_headline_config() for follow-up)
model.py SearchProjection, FrozenForwardCapture (with QK reconstruction
trick: capture (Q, K) post-RoPE while the forward stays in FA),
contrastive + KL distillation losses
data.py Long-context dataloader (packing off by default to avoid
cross-segment attention leakage; pin_memory, prefetch)
inference.py ANN-substituted attention (exact top-K for analysis;
CPU-FAISS HNSW prototype path — not a deployable kernel)
eval.py recall@K curve, mass@K curve, full-vs-ANN PPL,
MoE router stability
train.py Training loop, Liger setup, FA-3→FA-2→SDPA→eager fallback,
base-model freeze + drift check, auto-resume from latest ckpt
tests/ QK reconstruction verification + 50-step smoke test
pip install -r requirements.txt
export WANDB_API_KEY=<key> # only — never check it in
export HF_TOKEN=<token> # for faster Hub downloads
# Pre-launch checks
python -c "from transformers import AutoConfig; \
print(AutoConfig.from_pretrained('Qwen/Qwen3-4B-Instruct-2507'))"
python tests/test_qk_reconstruction.py
python tests/smoke_test.py
# Packed d_search ablation
bash scripts/run_packed_ablation.sh
# Default clean pilot (packing off; data-sparse on WikiText articles)
python train.py --config pilot_d64_clean

The default Config is the 1-day pilot:
| Knob | Pilot | Broad clean runs |
|---|---|---|
| seq_len | 4096 | 8192 |
| batch_size | 8 | 2 with grad accumulation 4 |
| total_steps | 1000 | 1000 for current all36/all32 pilots |
| layers trained | 6 ([4,8,12,16,20,24]) | all36 or all32 ([3..34]) |
| trainable params | 1.97M at d64; 3.93M at d128 | 23.59M all36 d128; ~20.97M all32 d128 |
| d_search | 64 default; d128 recommended from ablation | d128 |
| K_retrieve_eval | 128 | 128 |
Pilot is the proof-of-concept. The broad clean runs test whether the technique scales from a six-layer subset to dense application across most or all layers.
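The trainable-parameter counts in the table follow from two bias-free projections per trained layer. The sketch below assumes a hidden size of `d_model = 2560` for Qwen3-4B; that value is not stated in this README, but it reproduces every reported count.

```python
def search_projection_params(n_layers, d_search, d_model=2560):
    """Parameters in W_Qs and W_Ks across all trained layers.

    Assumes two bias-free d_model x d_search matrices per layer and
    d_model=2560 (assumed hidden size of Qwen3-4B).
    """
    return 2 * d_model * d_search * n_layers

counts = {
    "pilot d64 (6 layers)": search_projection_params(6, 64),
    "pilot d128 (6 layers)": search_projection_params(6, 128),
    "all32 d128": search_projection_params(32, 128),
    "all36 d128": search_projection_params(36, 128),
}
for name, n in counts.items():
    print(f"{name}: {n / 1e6:.2f}M")
```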
Use make_pilot_d128_packed_config() to reproduce the current recommended
packed capacity pilot, pilot_d128_block for the clean six-layer result,
all36_d128_block for all-layer substitution, or all32_d128_block for the
data-driven reserved-edge broad-substitution run.
- attn_implementation resolves at load time as flash_attention_3 → flash_attention_2 → sdpa → eager. On B200 with no flash-attn package installed, SDPA wins: its built-in flash backend is ~80-90% of FA-2's throughput with zero build dependency.
- Liger kernels applied via apply_liger_kernel_to_qwen3 (RMSNorm, SwiGLU, RoPE fused; typically 30-50% faster forward).
- The QK-reconstruction trick keeps SDPA/FA fast on the trained layers: we monkey-patch them to capture (Q, K) post-RoPE, then reconstruct softmax(QK^T/√d) ourselves after the forward returns. The forward never sets output_attentions=True (which would force eager).
- torch.compile(search_module, mode="max-autotune") on the search projections; base model uncompiled (works but flaky for novel architectures).
- bf16 throughout; loss math cast to fp32 for numerical stability of softmax.
The post-RoPE Q/K capture must match what the model's eager attention computes, or the distillation supervision is wrong. The test asserts top-32 agreement above 99% per layer:
python tests/test_qk_reconstruction.py --model Qwen/Qwen3-4B-Instruct-2507
# layer 0: PASS max|Δ|=2.54e-02 top-32 agree=0.9963
# layer 1: PASS max|Δ|=5.27e-02 top-32 agree=0.9941
# ...
# QK reconstruction verified.

The bf16 max-abs differences (~0.05) are just numerical noise; the ranking of attention positions matches.
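The kind of check this test performs can be illustrated with a small NumPy sketch. It is not the code in `tests/test_qk_reconstruction.py`: the function names are hypothetical, attention is non-causal for brevity, and float16 rounding of the inputs stands in for bf16, which NumPy does not provide natively.

```python
import numpy as np

def attention_probs(Q, K):
    """Reconstruct softmax(QK^T / sqrt(d)) from captured Q and K."""
    scores = (Q @ K.T) / np.sqrt(Q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

def topk_agreement(a, b, k=32):
    """Mean fraction of top-k key positions shared between two attention maps."""
    top_a = np.argsort(-a, axis=-1)[:, :k]
    top_b = np.argsort(-b, axis=-1)[:, :k]
    overlap = [len(set(x) & set(y)) for x, y in zip(top_a.tolist(), top_b.tolist())]
    return float(np.mean(overlap)) / k

rng = np.random.default_rng(0)
Q = rng.standard_normal((64, 128)).astype(np.float32)
K = rng.standard_normal((64, 128)).astype(np.float32)
reference = attention_probs(Q, K)
# Round inputs through float16 to mimic a lower-precision capture path.
Q_lp = Q.astype(np.float16).astype(np.float32)
K_lp = K.astype(np.float16).astype(np.float32)
low_precision = attention_probs(Q_lp, K_lp)

agreement = topk_agreement(reference, low_precision)
max_abs_diff = float(np.abs(reference - low_precision).max())
```

Ranking agreement is the relevant criterion here because the distillation targets depend on which keys the teacher prefers, not on the exact probability values.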
git clone git@github.com:unixsysdev/ann-sparseattention.git
cd ann-sparseattention
pip install -r requirements.txt
python train.py --config pilot_d128_packed

Hardware: a single H100/H200/B200, with 8GB of GPU memory for the 4B model plus ~10GB for activations at 4K context, batch 8.
MIT.