mlx-kvarn

Variance-Normalized KV-Cache Quantization for Apple Silicon

Up to 4.7× KV-cache compression · near-FP16 decode speed · drop-in for mlx-lm

The first MLX-native port of KVarN (Muller et al., 2026) — 4-bit-key / 2-bit-value KV-cache quantization that runs on the Apple Silicon GPU. Fit longer contexts and bigger models in the same unified memory, with accuracy that matches FP16 on GSM8K within ~2% and greedy output that is token-identical to FP16 in the common case.

What this is: an independent community port of the KVarN method to MLX. The method and the original CUDA/vLLM implementation are by the KVarN authors at Huawei (repo · paper). Not affiliated with them — this brings their work to a platform they didn't target.

Quickstart

```bash
pip install -e .
```

Requires Python ≥ 3.10, MLX ≥ 0.20.0, Apple Silicon (M1/M2/M3/M4).

```python
from mlx_lm import load, generate
from mlx_kvarn import patch_model_for_kvarn, make_kvarn_cache

model, tokenizer = load("Qwen/Qwen2.5-1.5B-Instruct-4bit")
patch_model_for_kvarn(model)          # one call; everything else is standard mlx-lm
caches = make_kvarn_cache(model)

print(generate(model, tokenizer, "Explain quantum computing in simple terms:",
               prompt_cache=caches, max_tokens=200))
```

Why This Matters

On a Mac, the KV cache is what limits how long a context you can hold and how big a model fits — unified memory is fast but finite and shared with the system. KVarN quantizes the cache to 4-bit keys / 2-bit values, using Hadamard rotation and Sinkhorn variance normalization to keep accuracy high. The result: up to ~4.7× more KV-cache capacity at near-FP16 decode speed. Until now there was no MLX implementation — this is it.

How this compares to alternatives

mlx-lm ships a built-in QuantizedKVCache (affine 4/8-bit). KVarN goes lower (2-bit values) while protecting accuracy with variance normalization. On GSM8K (n=200, Qwen2.5-3B-Instruct):

| Cache | Accuracy | Compression |
|---|---|---|
| FP16 (baseline) | 66.5% | 1.0× |
| KVarN k4v2 | 68.5% | 3.3× |
| KVarN k4v4 | 69.0% | 4.7× |
| mlx-lm QuantizedKVCache (8-bit) | 65.0% | 2.0× |

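KVarN matches FP16 within statistical noise (for n=200 the 95% CI is roughly ±7 points) and scores 3.5 points above mlx-lm's built-in 8-bit cache at higher compression, which is the project's core premise. That 3.5-point gap is also inside the confidence interval, so the safe reading is "at least as accurate at a smaller cache size", not a proven accuracy win.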

Performance

Apple Silicon (M-series), greedy, temp 0:

| Model | FP16 tok/s | KVarN tok/s | Speed | FP16 cache | KVarN cache | Compression |
|---|---|---|---|---|---|---|
| Qwen2.5-1.5B | 131 | 122 | 93% | 21 MB | 6 MB | 3.5× |
| Llama-3.2-3B | 120 | 120 | 100% | 28 MB | 8 MB | 3.5× |
| Qwen2.5-3B | 99 | 99 | 100% | 9 MB | 3 MB | 3.5× |
| Qwen2.5-7B | 90 | 87 | 97% | 14 MB | 2 MB | 6.2× |

On larger models KVarN approaches FP16 decode speed as the fixed ~1.5 ms/step Hadamard overhead amortizes. Prefill is ~15× slower (~80 ms vs ~5 ms for 512 tokens) due to Sinkhorn — paid once per prompt, negligible against long-context decode savings.

Accuracy (read this)

KVarN is lossy quantization. We describe it the honest way, which is also the strongest.

Reasoning accuracy (GSM8K, n=200, Qwen2.5-3B-Instruct, greedy): KVarN k4v2 scores 68.5% vs FP16's 66.5% — a +2.0% difference that is within statistical noise for n=200 (not a real improvement; the honest reading is "matches FP16"). This is consistent with the paper's finding of FP16-level accuracy.
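The "within noise" reading can be checked with a quick back-of-the-envelope calculation. The sketch below uses the normal-approximation (Wald) interval for a proportion, which is an assumption about how the ~±7% figure was obtained; the accuracies are the ones reported above.

```python
import math

def ci_half_width(p: float, n: int, z: float = 1.96) -> float:
    """Half-width of the normal-approximation (Wald) interval for a proportion."""
    return z * math.sqrt(p * (1.0 - p) / n)

n = 200
fp16_acc, kvarn_acc = 0.665, 0.685  # GSM8K accuracies reported above

half = ci_half_width(fp16_acc, n)
gap = kvarn_acc - fp16_acc

print(f"95% CI half-width: +/-{half * 100:.1f} points")
print(f"observed gap:      {gap * 100:+.1f} points")
print("within noise" if abs(gap) < half else "outside noise")
```

With n=200 the half-width comes out at roughly 6.5 points, so a 2-point gap cannot be distinguished from zero at this sample size.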

Greedy token-match: Most of the time KVarN greedy output is token-identical to FP16. Where it differs, the difference is a single argmax tie-break, not accumulating error:

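| Prompt | 128 | 256 | 512 | 1024 | 2048 | 4096 | 8192 |
|---|---|---|---|---|---|---|---|
| Transformer explanation | match | match | match | match | match | match | match |
| Count numbers | match | match | div@286 | | | | |
| History of computing | div@112 | | | | | | |
| Prime numbers proof | match | match | div@262 | | | | |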

Why this is tie-breaking, not accumulation: the divergence index is fixed regardless of generation length — token 286 diverges at 286 whether you generate 512 tokens or 8192. Accumulating error would move divergence earlier with length; it doesn't. The divergent tokens are ones where the top-2 logits are within ~0.05.
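One way to spot such tie-break positions is to look at the gap between the two largest logits at each step. The following is a minimal sketch in plain Python; the helper names and the example logits are made up for illustration, and the 0.05 threshold is taken from the figure quoted above.

```python
def top2_margin(logits):
    """Return (argmax index, gap between the largest and second-largest logit)."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    best, runner_up = order[0], order[1]
    return best, logits[best] - logits[runner_up]

def is_near_tie(logits, threshold=0.05):
    """True when a small perturbation could flip the greedy choice."""
    _, margin = top2_margin(logits)
    return margin < threshold

# Hypothetical logit vectors for two decode steps.
confident = [2.10, 0.30, -1.20, 0.95]   # clear winner at index 0
near_tie = [1.52, 1.49, -0.40, 0.10]    # top two differ by about 0.03

print(is_near_tie(confident))  # margin is about 1.15
print(is_near_tie(near_tie))   # margin is about 0.03
```

Logging this margin for the FP16 run at the divergent index is a cheap way to confirm that a mismatch is a near-tie and not drift.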

All 5 tested models produce exact 64-token greedy matches. For accuracy-critical use, prefer the k4v4 preset (4-bit values). Full data: eval/gsm8k_c1_results.json.

How It Works

KVarN processes the KV cache in tiles as tokens arrive:

On write (as tokens arrive):

  1. Each token is Hadamard-rotated along the channel dimension.
  2. The first 128 tokens go to an fp16 sink pool (never quantized).
  3. Subsequent tokens accumulate in an fp16 tail pool until a 128-token tile fills.
  4. When a tile fills: Sinkhorn normalize → RTN quantize → pack to uint8, appended to history.
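The write path above can be sketched as a small routing class. This is a toy model in plain Python, not the library's cache: `ToyKVPool` and `quantize_tile` are hypothetical names, the token vectors are stand-ins, and the quantization step is a placeholder.

```python
from dataclasses import dataclass, field

TILE = 128          # tokens per quantized tile (the group size)
SINK_TOKENS = 128   # leading tokens kept in fp16 and never quantized

def quantize_tile(tile):
    """Placeholder for Sinkhorn normalize -> RTN quantize -> pack to uint8."""
    return ("packed", len(tile))

@dataclass
class ToyKVPool:
    sink: list = field(default_factory=list)     # fp16, permanent
    tail: list = field(default_factory=list)     # fp16, waiting for a full tile
    history: list = field(default_factory=list)  # quantized tiles

    def append(self, token_vec):
        # Steps 2 and 3: fill the sink first, then accumulate in the tail.
        if len(self.sink) < SINK_TOKENS:
            self.sink.append(token_vec)
            return
        self.tail.append(token_vec)
        # Step 4: a full tile is quantized and moved to history.
        if len(self.tail) == TILE:
            self.history.append(quantize_tile(self.tail))
            self.tail = []

pool = ToyKVPool()
for t in range(400):
    pool.append([float(t)])  # stand-in for a Hadamard-rotated K or V vector

print(len(pool.sink), len(pool.history), len(pool.tail))  # 128 2 16
```

After 400 tokens, 128 sit in the sink, 256 have been flushed as two tiles, and 16 are still waiting in the tail.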

On read (every decode step):

  1. A single batch Metal dequant kernel expands all uint8 history blocks to fp16.
  2. History + sink + tail are concatenated and passed to SDPA.
  3. The attention output is Hadamard un-rotated back to the original frame.
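Un-rotating the attention output works because the normalized Hadamard transform is orthonormal and its own inverse. The sketch below shows this with a plain-Python fast Walsh-Hadamard transform; it also assumes, as any rotation-based scheme must, that queries are rotated the same way as keys so that attention scores are unchanged. The vectors are made-up examples.

```python
import math

def hadamard(x):
    """Orthonormal fast Walsh-Hadamard transform; len(x) must be a power of 2."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of 2")
    y = list(x)
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = y[i], y[i + h]
                y[i], y[i + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)
    return [v * scale for v in y]

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

q = [0.5, -1.0, 2.0, 0.0, 3.0, -0.5, 1.0, 4.0]
k = [8.0, 0.1, -0.2, 0.0, 0.1, 0.0, -0.1, 0.2]  # one large outlier channel

q_rot, k_rot = hadamard(q), hadamard(k)

print(dot(q, k), dot(q_rot, k_rot))                  # scores agree up to rounding
print(max(abs(v) for v in k), max(abs(v) for v in k_rot))  # outlier is spread out
restored = hadamard(k_rot)                           # applying it twice undoes it
```

The outlier channel in `k` is spread across all channels after rotation, which is what makes the rotated tile easier to quantize.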

The four stages

  1. Hadamard rotation — orthonormal transform spreading energy uniformly across channels.
  2. Sinkhorn normalization — 4 iterations of dual-scaling balancing tile variance.
  3. Asymmetric RTN — per-row round-to-nearest, scale/zero-point folded into Sinkhorn scales.
  4. Batch Metal dequant — one GPU kernel dequantizes all history blocks at once.
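Stages 2 and 3 can be illustrated on a tiny tile. This is a generic sketch, not the library's kernel: it uses alternating row/column RMS scaling as a stand-in for the Sinkhorn step, keeps the RTN scale and zero-point separate instead of folding them into the Sinkhorn scales, and all function names and the 4×4 tile are hypothetical.

```python
import math

def rms(values):
    return math.sqrt(sum(v * v for v in values) / len(values)) or 1.0

def sinkhorn_normalize(tile, iters=4):
    """Alternately rescale rows and columns toward unit RMS.

    Returns (normalized, row_scales, col_scales) such that
    tile[i][j] == normalized[i][j] * row_scales[i] * col_scales[j].
    """
    rows, cols = len(tile), len(tile[0])
    x = [list(r) for r in tile]
    row_s, col_s = [1.0] * rows, [1.0] * cols
    for _ in range(iters):
        for i in range(rows):
            s = rms(x[i])
            row_s[i] *= s
            x[i] = [v / s for v in x[i]]
        for j in range(cols):
            s = rms([x[i][j] for i in range(rows)])
            col_s[j] *= s
            for i in range(rows):
                x[i][j] /= s
    return x, row_s, col_s

def rtn_quantize_row(row, bits):
    """Asymmetric round-to-nearest: integer codes in [0, 2**bits - 1]."""
    lo, hi = min(row), max(row)
    levels = (1 << bits) - 1
    scale = (hi - lo) / levels or 1.0
    codes = [min(levels, max(0, round((v - lo) / scale))) for v in row]
    return codes, scale, lo

def dequantize(codes, scale, zero, row_scale, col_scales):
    return [(c * scale + zero) * row_scale * cs for c, cs in zip(codes, col_scales)]

tile = [  # rows are tokens, columns are channels; the last channel is an outlier
    [0.10, -0.20, 0.30, 8.00],
    [0.20, 0.10, -0.10, 6.00],
    [-0.30, 0.20, 0.10, 7.00],
    [0.05, -0.10, 0.20, 9.00],
]
norm, row_s, col_s = sinkhorn_normalize(tile)

all_codes, recon = [], []
for i, row in enumerate(norm):
    codes, scale, zero = rtn_quantize_row(row, bits=4)
    all_codes.append(codes)
    recon.append(dequantize(codes, scale, zero, row_s[i], col_s))

max_err = max(abs(a - b) for ra, rb in zip(tile, recon) for a, b in zip(ra, rb))
print(f"max reconstruction error: {max_err:.4f}")
```

The point of the normalization is that every row and column of `norm` has a similar spread, so a single per-row scale and zero-point wastes fewer quantization levels on an outlier channel.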

Memory Layout

Each (block, kv_head) tile is 13,824 bytes, versus 65,536 bytes for fp16 — a 4.74× per-tile ratio:

| Component | K | V |
|---|---|---|
| Packed weights (uint8) | 8,192 B | 4,096 B |
| Column scales (fp16) | 256 B | 256 B |
| Zero-points (fp16) | 256 B | 256 B |
| Row scales (fp16) | 256 B | 256 B |
| Subtotal | 8,960 B | 4,864 B |

Keys get 4 bits, values 2 bits (the k4v2 default); the scale/zero-point overhead is small relative to the packed data.
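The byte counts in the table can be reproduced with a few lines of arithmetic. The sketch assumes a 128-token tile with head_dim 128 (the numbers only work out for that shape) and one fp16 zero-point per row; the nibble-packing helpers show one plausible uint8 layout for 4-bit codes, which is an assumption and not necessarily the layout the kernels use.

```python
TILE_TOKENS = 128
HEAD_DIM = 128   # assumption: the table corresponds to D=128
FP16_BYTES = 2

def packed_bytes(bits):
    return TILE_TOKENS * HEAD_DIM * bits // 8

def side_bytes(bits):
    # column scales (one per channel) + zero-points and row scales (one per token)
    meta = (HEAD_DIM + TILE_TOKENS + TILE_TOKENS) * FP16_BYTES
    return packed_bytes(bits) + meta

k_bytes = side_bytes(4)
v_bytes = side_bytes(2)
tile_bytes = k_bytes + v_bytes
fp16_bytes = 2 * TILE_TOKENS * HEAD_DIM * FP16_BYTES  # K and V in fp16
ratio = fp16_bytes / tile_bytes

print(k_bytes, v_bytes, tile_bytes, fp16_bytes, round(ratio, 2))
# 8960 4864 13824 65536 4.74

def pack_nibbles(codes):
    """Pack 4-bit codes two per byte, low nibble first (layout is an assumption)."""
    if len(codes) % 2:
        raise ValueError("need an even number of codes")
    return bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))

def unpack_nibbles(data):
    out = []
    for b in data:
        out.extend((b & 0x0F, b >> 4))
    return out
```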

Compression vs Context Length

The first 128 tokens (the fp16 "sink") are never quantized, so effective compression grows with context:

| Tokens | 256 | 512 | 1024 | 4096 | long-context limit |
|---|---|---|---|---|---|
| Effective | 2.2× | 3.3× | 4.1× | 4.6× | 4.74× |

At short context the fp16 sink dominates; at long context you approach the per-tile 4.74×.

Presets

| Preset | K bits | V bits | Best for |
|---|---|---|---|
| kvarn_k4v2_g128 | 4 | 2 | Default: best compression; matches FP16 on GSM8K within ~2% |
| kvarn_k4v4_g128 | 4 | 4 | Accuracy-critical: higher fidelity, still 4.7× compression |
| kvarn_k2v2_g128 | 2 | 2 | Experimental: untested on reasoning tasks; prefer k4v4 if accuracy matters. ~5× compression. |

Configuration

| Environment variable | Description | Default |
|---|---|---|
| KVARN_SINKHORN_ITERS | Sinkhorn iterations | 4 |
| KVARN_SINK_TOKENS | Tokens kept in fp16 sink | 128 |

The default of 4 Sinkhorn iterations is verified — the imbalance metric converges to ~2.0 within 4 iterations (later iterations are no-ops).
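For reference, here is a minimal sketch of how the two variables might be parsed with their documented defaults. `read_kvarn_env` is a hypothetical helper, not part of mlx_kvarn; how and when the library actually reads these variables is not described here, so setting them before launching Python is the safe assumption.

```python
import os

def read_kvarn_env(environ=os.environ):
    """Parse the two documented variables, falling back to the documented defaults."""
    return {
        "sinkhorn_iters": int(environ.get("KVARN_SINKHORN_ITERS", "4")),
        "sink_tokens": int(environ.get("KVARN_SINK_TOKENS", "128")),
    }

# Passing a plain dict keeps the example independent of the real environment.
defaults = read_kvarn_env({})
overridden = read_kvarn_env({"KVARN_SINKHORN_ITERS": "6", "KVARN_SINK_TOKENS": "256"})

print(defaults)
print(overridden)
```

From a shell, the equivalent would be something like `KVARN_SINK_TOKENS=256 python examples/01_basic_usage.py`.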

Examples

  • examples/01_basic_usage.py — minimal working example
  • examples/02_compare_kvarn_vs_fp16.py — side-by-side comparison
  • examples/03_long_context.py — long-context demonstration

Reproducing Results

```bash
python -m pytest tests/ -v           # unit tests
python eval/run_gsm8k.py             # GSM8K accuracy (the headline result)
python tools/test_long_context.py    # long-context greedy match
python benchmarks/compare.py         # performance benchmark
```

Models must be cached locally (HF_HUB_OFFLINE=1) or available via the network.

Limitations

  • Greedy decoding is the verified path. Sampling also runs, and the next-token distribution matches FP16 closely (KL < 0.001). Coherence with the stock sampler at temp 0.7 was confirmed against mlx-lm, but high-temperature behavior has had limited end-to-end testing.
  • head_dim must be a power of 2 (Hadamard). Tested D=64, D=128.
  • Group size = 128 only.
  • Prefill ~15× slower than FP16 (Sinkhorn) — paid once per prompt.
  • Reasoning accuracy verified on GSM8K (n=200); broader benchmark coverage is future work.
  • k2v2 preset is untested on reasoning tasks — prefer k4v4 if accuracy matters.
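The power-of-2 constraint is easy to check before patching a model. The helper below is a hypothetical pre-flight check, not part of the library; how you obtain head_dim (for example hidden size divided by the number of attention heads) depends on the model config.

```python
def is_power_of_two(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0

def check_head_dim(head_dim: int) -> None:
    """Raise early if the Hadamard rotation cannot be applied to this head size."""
    if not is_power_of_two(head_dim):
        raise ValueError(f"head_dim={head_dim} is not a power of 2")

for d in (64, 96, 128):
    try:
        check_head_dim(d)
        print(d, "ok")
    except ValueError as err:
        print(d, "rejected:", err)
```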

What Was Investigated and Not Adopted

A fused dequant+attention Metal kernel was prototyped. It was 38% slower than the current two-kernel approach (single-threadgroup loses MLX's multi-threadgroup parallelism) and had a confirmed multi-KV-head indexing bug (max_diff=165 on head 1 — a stride error). It is disabled (retained for reference, not in any active path). A correct multi-threadgroup version is the main path to FP16-parity speed; see the technical report.

Pre-rotating weights to remove the Hadamard overhead is blocked: RoPE and the channel Hadamard do not commute (measured max_diff=43.76), so absorbing the rotation into Q/K weights is invalid. Partial absorption on the V/output path is possible future work.
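The non-commutation is easy to reproduce on a toy vector. The sketch below applies RoPE and the Hadamard rotation in both orders and compares the results. It assumes the interleaved-pair RoPE convention and an 8-channel example, so the magnitude of the difference will not match the measurement quoted above; only the fact that it is non-zero carries over.

```python
import math

def hadamard(x):
    """Orthonormal fast Walsh-Hadamard transform; len(x) must be a power of 2."""
    n = len(x)
    y = list(x)
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = y[i], y[i + h]
                y[i], y[i + h] = a + b, a - b
        h *= 2
    return [v / math.sqrt(n) for v in y]

def rope(x, pos, base=10000.0):
    """Rotate channel pairs (2i, 2i+1) by the angle pos * base**(-2i/d)."""
    d = len(x)
    out = list(x)
    for i in range(d // 2):
        theta = pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out

x = [1.0, 2.0, -1.0, 0.5, 3.0, -2.0, 0.0, 1.5]

def order_gap(pos):
    """Max elementwise difference between Hadamard-after-RoPE and RoPE-after-Hadamard."""
    a = hadamard(rope(x, pos))
    b = rope(hadamard(x), pos)
    return max(abs(u - v) for u, v in zip(a, b))

print(order_gap(0))  # position 0: RoPE is the identity, so the orders agree
print(order_gap(5))  # any other position: the two orders give different vectors
```

RoPE acts on fixed channel pairs with position-dependent angles, while the Hadamard transform mixes all channels, so a rotation baked into the Q/K weights would be applied on the wrong side of RoPE.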

Project History

| Version | Key achievement |
|---|---|
| v0.1 | Initial implementation: 1.0 tok/s (55× slower than FP16) |
| v0.2 | Dispatch-graph restructuring: 122 tok/s (71% of FP16) |
| v1.0 | GSM8K-validated (matches FP16 within ~2%; beats mlx-lm built-in cache); kernel fix for all bit-widths; long-context verified |

See TECHNICAL_REPORT.md for the full history.

Citation

This project implements the method from:

```bibtex
@misc{muller2026kvarnvariancenormalizedkvcachequantization,
  title={KVarN: Variance-Normalized KV-Cache Quantization Mitigates Error Accumulation in Reasoning Tasks},
  author={Lorenz K. Muller and Philippe Bich and Chiara Boretti and Hyun-Min Chang and Jiawei Zhuang and Lukas Cavigelli},
  year={2026},
  eprint={2606.03458},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2606.03458},
}
```

Original CUDA/vLLM implementation: https://github.com/huawei-csl/KVarN

If you use this MLX port, please also link this repository.

License

Apache 2.0. See LICENSE.

A Note on How This Was Built

This was a side project, built almost entirely by an AI coding agent over many rounds of implementation, testing, and fixes — roughly 24 hours of API time across about 5 calendar days.

What made it interesting to me: I ran most of it on Qwen 3.6/3.7 Plus models rather than a top-tier US model. The same work driven by a frontier model would likely have cost well over US$6,000 in API usage; instead it used under 5% of the monthly quota on a US$30 plan. So it doubled as a small experiment in the gap between the leading US models and recent Chinese open models.

The honest takeaway isn't "they're equivalent" — a frontier model might well have reached a better solution faster, and the long trail of bugs and re-validation in this repo's history is partly the cost of using a cheaper model. But for a project that can be done on a shoestring, spending thousands wasn't obviously the smart trade. Give it another six months and the cheaper models will likely close most of the remaining gap.


Author

Edward Tsang — blockchain & AI engineer. Open to consulting → Email · LinkedIn
