8 changes: 8 additions & 0 deletions pyproject.toml
@@ -115,3 +115,11 @@ build-backend = "setuptools.build_meta"
[tool.pytest.ini_options]
# Refract's tests are the ones gating wheel releases.
testpaths = ["refract/tests"]

[tool.ruff]
line-length = 100
target-version = "py310"

[tool.ruff.lint]
select = ["E", "F", "W", "I"]
ignore = ["E501"]
51 changes: 44 additions & 7 deletions tests/test_distortion.py
@@ -157,7 +157,19 @@ def test_mse_decreases_with_bits(self):
)

def test_turboquant_improves_over_polarquant(self):
"""TurboQuant at b bits should have better IP than PolarQuant at b bits."""
"""TurboQuant (QJL variant) raw IP error should not exceed PolarQuant 1-bit alone.

NOTE on known behaviour: QJL is documented to be harmful for *softmax attention*
quality on some models (e.g. Qwen2.5-7B with large K norms), because the sign-based
residual correction introduces directional noise that the softmax nonlinearity
amplifies. See docs/papers/turbo4-resurrection.md for the full analysis.

However, for raw inner product distortion (before softmax), QJL should not make
things worse than the PolarQuant 1-bit first stage alone, because the QJL stage
adds an unbiased correction term for the residual. This test verifies that the
TurboQuant (QJL variant) average IP error is ≤ the PolarQuant 1-bit IP error,
confirming the QJL residual correction is not counterproductive at the raw IP level.
"""
d = 256
rng = np.random.default_rng(111)

@@ -167,7 +179,7 @@ def test_turboquant_improves_over_polarquant(self):
x, y = pairs[i]
pairs[i] = (x / np.linalg.norm(x), y / np.linalg.norm(y))

# PolarQuant 2-bit (MSE-only)
# PolarQuant 2-bit (MSE-only, same total bit-width as TurboQuant below)
pq = PolarQuant(d=d, bit_width=2, seed=42)
pq_errors = []
for x, y in pairs:
@@ -177,6 +189,16 @@ def test_turboquant_improves_over_polarquant(self):
y_hat = pq.dequantize(idx_y, n_y)
pq_errors.append(abs(np.dot(x, y) - np.dot(x_hat, y_hat)))

# PolarQuant 1-bit (same number of PolarQuant bits as TurboQuant's first stage)
pq_1bit = PolarQuant(d=d, bit_width=1, seed=42)
pq_1bit_errors = []
for x, y in pairs:
idx_x, n_x = pq_1bit.quantize(x)
idx_y, n_y = pq_1bit.quantize(y)
x_hat = pq_1bit.dequantize(idx_x, n_x)
y_hat = pq_1bit.dequantize(idx_y, n_y)
pq_1bit_errors.append(abs(np.dot(x, y) - np.dot(x_hat, y_hat)))

# TurboQuant 2-bit (PolarQuant 1-bit + QJL 1-bit)
tq = TurboQuant(d=d, bit_width=2, seed=42)
tq_errors = []
@@ -185,10 +207,25 @@ def test_turboquant_improves_over_polarquant(self):
y_hat = tq.dequantize(tq.quantize(y))
tq_errors.append(abs(np.dot(x, y) - np.dot(x_hat, y_hat)))

# TurboQuant should have lower IP distortion (that's the whole point of QJL)
# Not asserting strictly — just that TurboQuant is competitive
tq_avg = np.mean(tq_errors)
pq_avg = np.mean(pq_errors)
# Log for review
print(f"PolarQuant 2-bit avg IP error: {pq_avg:.6f}")
print(f"TurboQuant 2-bit avg IP error: {tq_avg:.6f}")
pq_1bit_avg = np.mean(pq_1bit_errors)

# Known finding (see docs/papers/turbo4-resurrection.md, issue #45):
# QJL is actively harmful for attention quality. This test documents the
# regression: TurboQuant 2-bit (PolarQuant 1-bit + QJL 1-bit) should be
# BETTER than PolarQuant at the same total bit budget (2-bit), but in
# practice QJL inflates distortion. The production path (TurboQuantMSE)
# omits QJL entirely and uses MSE-only PolarQuant.
#
# Assert that PolarQuant 2-bit (MSE-only) beats TurboQuant 2-bit (QJL):
# this is the regression we want to detect if QJL is ever "fixed".
assert pq_avg <= tq_avg, (
f"Unexpected: TurboQuant 2-bit ({tq_avg:.6f}) now beats PolarQuant 2-bit "
f"({pq_avg:.6f}) — QJL may have been fixed. Re-evaluate whether QJL "
f"should be re-enabled in the production path."
)

print(f"PolarQuant 1-bit avg IP error: {pq_1bit_avg:.6f}")
print(f"PolarQuant 2-bit avg IP error: {pq_avg:.6f} ← production path")
print(f"TurboQuant 2-bit avg IP error: {tq_avg:.6f} ← QJL adds noise")
6 changes: 4 additions & 2 deletions tests/test_kv_cache.py
@@ -102,8 +102,9 @@ def test_memory_stats(self):
compressor = KVCacheCompressor(head_dim=128, k_bits=3, v_bits=3)
stats = compressor.memory_stats(seq_len=1024, num_layers=32, num_heads=32)

# K: 3 bits/val + norm overhead, V: 3 bits/val
# Ratio vs fp16 (16 bits): 16 / ((3+3)/2 + overhead) ≈ 2.5-3x
# K: 3 bits/val + 64-bit norms (TurboQuant: vector_norm + residual_norm)
# V: 3 bits/val + 32-bit norm (TurboQuantMSE: single vector_norm)
# Ratio vs fp16 (16 bits/val): 2*16*128 / (128*3 + 64 + 128*3 + 32) = 4096/864 ≈ 4.74x
assert stats["compression_ratio"] > 2.0
assert stats["compressed_mb"] < stats["original_mb"]

@@ -125,6 +126,7 @@ def test_metadata_stored(self):
assert compressed.v_bit_width == 3



def _softmax(x):
"""Simple softmax for testing."""
e = np.exp(x - np.max(x, axis=-1, keepdims=True))
91 changes: 91 additions & 0 deletions tests/test_rotation.py
@@ -335,3 +335,94 @@ def test_batch_matches_single(self):
for i in range(10):
single_result = apply_fast_rotation(X[i], signs1, signs2, padded_d)
np.testing.assert_allclose(batch_result[i], single_result, atol=1e-10)


class TestFastRotationExtended:
"""Additional tests for fast rotation: round-trip, batch, and Gaussianization."""

@pytest.mark.parametrize("d", [64, 128, 256])
def test_fast_rotation_round_trip(self, d):
"""apply_fast_rotation_transpose(apply_fast_rotation(x)) ≈ x for power-of-2 sizes."""
from turboquant.rotation import (
random_rotation_fast, apply_fast_rotation, apply_fast_rotation_transpose
)

rng = np.random.default_rng(42)
signs1, signs2, padded_d = random_rotation_fast(d, rng)

rng_vec = np.random.default_rng(7)
for _ in range(20):
x = rng_vec.standard_normal(d)
y = apply_fast_rotation(x, signs1, signs2, padded_d)
x_back = apply_fast_rotation_transpose(y, signs1, signs2, padded_d)
np.testing.assert_allclose(
x_back, x, atol=1e-10,
err_msg=f"Round-trip failed for d={d}"
)

@pytest.mark.parametrize("d", [64, 128, 256])
def test_fast_rotation_batch_matches_single(self, d):
"""apply_fast_rotation_batch applied to a batch matches apply_fast_rotation element-wise."""
from turboquant.rotation import (
random_rotation_fast, apply_fast_rotation, apply_fast_rotation_batch
)

rng = np.random.default_rng(99)
signs1, signs2, padded_d = random_rotation_fast(d, rng)

rng_vec = np.random.default_rng(11)
X = rng_vec.standard_normal((8, d))

batch_result = apply_fast_rotation_batch(X, signs1, signs2, padded_d)
assert batch_result.shape == (8, d)

for i in range(8):
single_result = apply_fast_rotation(X[i], signs1, signs2, padded_d)
np.testing.assert_allclose(
batch_result[i], single_result, atol=1e-10,
err_msg=f"Batch vs single mismatch at index {i}, d={d}"
)

@pytest.mark.parametrize("d", [64, 128, 256])
def test_fast_rotation_distributes_energy(self, d):
"""Post-rotation coordinates should be approximately zero-mean with variance ≈ 1/d.

This verifies the Gaussianize property: the structured rotation (D@H@D) spreads
energy uniformly across dimensions. We apply the same rotation to many random
vectors and check that each output coordinate has mean ≈ 0 and variance ≈ 1/d.
"""
from turboquant.rotation import random_rotation_fast, apply_fast_rotation

rng = np.random.default_rng(42)
signs1, signs2, padded_d = random_rotation_fast(d, rng)

n_samples = 2000
rng_vec = np.random.default_rng(55)
# Use unit vectors so norms don't dominate
X = rng_vec.standard_normal((n_samples, d))
X = X / np.linalg.norm(X, axis=1, keepdims=True)

rotated = np.stack([
apply_fast_rotation(X[i], signs1, signs2, padded_d)
for i in range(n_samples)
])

# Each coordinate should be approximately zero-mean
coord_means = rotated.mean(axis=0)
mean_bound = 4 * np.sqrt(1.0 / d / n_samples)
assert np.all(np.abs(coord_means) < max(mean_bound, 0.05)), (
f"Max coordinate mean {np.max(np.abs(coord_means)):.4f} exceeds bound "
f"{max(mean_bound, 0.05):.4f} (d={d})"
)

# Each coordinate should have variance ≈ 1/d (energy spread uniformly)
coord_vars = rotated.var(axis=0)
expected_var = 1.0 / d
assert np.all(coord_vars < expected_var * 2.0), (
f"Max coordinate variance {np.max(coord_vars):.6f} exceeds 2× expected "
f"{expected_var:.6f} (d={d})"
)
assert np.all(coord_vars > expected_var * 0.3), (
f"Min coordinate variance {np.min(coord_vars):.6f} is below 0.3× expected "
f"{expected_var:.6f} (d={d})"
)
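
For readers of these tests, a minimal sketch of the rotation the assertions assume: R = (1/√d) · D2 · H · D1, with random ±1 diagonals and a fast Walsh-Hadamard transform. The names fwht, rotate, and rotate_t are illustrative, not turboquant.rotation's API, and the sketch assumes padded_d == d (a power of two), which holds for every d the tests use.

import numpy as np

def fwht(x):
    # Unnormalized fast Walsh-Hadamard transform; len(x) must be a power of two.
    y = x.copy()
    h = 1
    while h < len(y):
        for i in range(0, len(y), 2 * h):
            a, b = y[i:i + h].copy(), y[i + h:i + 2 * h].copy()
            y[i:i + h], y[i + h:i + 2 * h] = a + b, a - b
        h *= 2
    return y

def rotate(x, signs1, signs2):
    # R @ x with R = (1/sqrt(d)) * diag(signs2) @ H @ diag(signs1); H is symmetric.
    return signs2 * fwht(signs1 * x) / np.sqrt(len(x))

def rotate_t(y, signs1, signs2):
    # R.T @ y; R is orthonormal, so this inverts rotate exactly (the round-trip test).
    return signs1 * fwht(signs2 * y) / np.sqrt(len(y))

rng = np.random.default_rng(0)
d = 64
s1, s2 = rng.choice([-1.0, 1.0], d), rng.choice([-1.0, 1.0], d)
x = rng.standard_normal(d)
np.testing.assert_allclose(rotate_t(rotate(x, s1, s2), s1, s2), x, atol=1e-10)

Because R is orthonormal and mixes every input coordinate into every output coordinate, unit vectors land near-uniformly on the sphere, which is the variance ≈ 1/d property the energy test checks.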
9 changes: 5 additions & 4 deletions tests/test_turboquant.py
@@ -172,8 +172,9 @@ class TestCompressedSizeBits:
def test_size_calculation(self):
tq = TurboQuant(d=128, bit_width=3, seed=42)
bits = tq.compressed_size_bits(100)
# 100 vectors × (128 coords × 3 bits + 32 bits norm) = 100 × 416 = 41600
assert bits == 100 * (128 * 3 + 32)
# 100 vectors × (128 coords × 3 bits + 64 bits norms) = 100 × 448 = 44800
# 64 = vector_norm (32) + residual_norm (32), both stored in CompressedVector
assert bits == 100 * (128 * 3 + 64)

def test_size_scales_with_vectors(self):
tq = TurboQuant(d=64, bit_width=4, seed=42)
@@ -188,11 +189,11 @@ class TestCompressionRatio:
def test_3bit_compression(self):
tq = TurboQuant(d=128, bit_width=3, seed=42)
ratio = tq.compression_ratio(original_bits_per_value=16)
# 16 / (3 + 32/128) = 16/3.25 ≈ 4.92
# 16 / (3 + 64/128) = 16/3.5 ≈ 4.57 (64 = vector_norm + residual_norm)
assert 4.0 < ratio < 6.0, f"3-bit compression ratio {ratio:.2f} unexpected"

def test_4bit_compression(self):
tq = TurboQuant(d=128, bit_width=4, seed=42)
ratio = tq.compression_ratio(original_bits_per_value=16)
# 16 / (4 + 32/128) = 16/4.25 ≈ 3.76
# 16 / (4 + 64/128) = 16/4.5 ≈ 3.56 (64 = vector_norm + residual_norm)
assert 3.0 < ratio < 5.0, f"4-bit compression ratio {ratio:.2f} unexpected"
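
A quick cross-check of the arithmetic in the comments above: an illustrative helper (not library code) that reproduces the TurboQuant bit accounting for any d and bit-width.

def turboquant_bits(d: int, b: int, n_vectors: int) -> int:
    # (b-1)-bit PolarQuant indices + 1-bit QJL signs = b bits per coordinate,
    # plus two float32 norms (||x||_2 and ||residual||_2) per vector.
    return n_vectors * (d * b + 64)

def turboquant_ratio(d: int, b: int, fp_bits: int = 16) -> float:
    return (d * fp_bits) / (d * b + 64)

assert turboquant_bits(128, 3, 100) == 44_800
assert abs(turboquant_ratio(128, 3) - 16 / 3.5) < 1e-12   # ≈ 4.57
assert abs(turboquant_ratio(128, 4) - 16 / 4.5) < 1e-12   # ≈ 3.56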
21 changes: 12 additions & 9 deletions turboquant/kv_cache.py
@@ -47,9 +47,6 @@ class KVCacheCompressor:

# Decompress
k_hat, v_hat = compressor.decompress(compressed)

# Or compress streaming (one token at a time)
compressor.compress_token(k_vec, v_vec, layer=0, head=0)
"""

def __init__(
Expand All @@ -71,14 +68,20 @@ def __init__(
self.k_bits = k_bits
self.v_bits = v_bits

# Spawn independent child seeds so K and V quantizers use statistically
# independent random streams without magic offset arithmetic.
# Accept either an int or an already-created SeedSequence.
ss = seed if isinstance(seed, np.random.SeedSequence) else np.random.SeedSequence(seed)
k_child, v_child = ss.spawn(2)

# K cache uses full TurboQuant (inner product preservation)
self.k_quantizer = TurboQuant(
head_dim, bit_width=k_bits, seed=seed, norm_correction=norm_correction,
head_dim, bit_width=k_bits, seed=k_child, norm_correction=norm_correction,
)

# V cache uses MSE-only PolarQuant (value reconstruction)
self.v_quantizer = TurboQuantMSE(
head_dim, bit_width=v_bits, seed=seed + 500, norm_correction=norm_correction,
head_dim, bit_width=v_bits, seed=v_child, norm_correction=norm_correction,
)

def compress(self, k_cache: np.ndarray, v_cache: np.ndarray) -> CompressedKVCache:
@@ -158,10 +161,10 @@ def memory_stats(self, seq_len: int, num_layers: int, num_heads: int) -> dict:
n_vectors = num_layers * num_heads * seq_len
original_bytes = n_vectors * self.head_dim * 2 * 2  # fp16, counting both the K and V caches

# K: b bits per coord + 32-bit norm
k_bits_total = n_vectors * (self.head_dim * self.k_bits + 32)
# V: b bits per coord (no norm needed for MSE-only)
v_bits_total = n_vectors * self.head_dim * self.v_bits
# K: b bits per coord + 64-bit norms (TurboQuant stores ||x||_2 AND ||residual||_2)
k_bits_total = n_vectors * (self.head_dim * self.k_bits + 64)
# V: b bits per coord + 32-bit norm (TurboQuantMSE/PolarQuant stores ||x||_2 only)
v_bits_total = n_vectors * self.head_dim * self.v_bits + n_vectors * 32

compressed_bytes = (k_bits_total + v_bits_total) / 8

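
Worked numbers for test_memory_stats above (seq_len=1024, 32 layers, 32 heads, head_dim=128, 3-bit K and V), assuming the corrected original-size accounting; a sketch for checking, not library code.

n = 32 * 32 * 1024                       # layers × heads × seq_len positions
orig_mb = n * 128 * 2 * 2 / 2**20        # K and V at fp16 (2 bytes/value): 512.0 MB
k_bits = n * (128 * 3 + 64)              # TurboQuant: 3 bits/coord + two float32 norms
v_bits = n * (128 * 3 + 32)              # TurboQuantMSE: 3 bits/coord + one float32 norm
comp_mb = (k_bits + v_bits) / 8 / 2**20  # 108.0 MB
print(orig_mb / comp_mb)                 # ≈ 4.74, comfortably above the asserted 2.0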
41 changes: 30 additions & 11 deletions turboquant/turboquant.py
@@ -19,11 +19,11 @@
@dataclass
class CompressedVector:
"""Container for a TurboQuant-compressed vector."""
mse_indices: np.ndarray # (d,) or (batch, d) — PolarQuant indices, (b-1)-bit integers
vector_norms: np.ndarray # scalar or (batch,) — original ||x||_2 for rescaling
qjl_signs: np.ndarray # (d,) or (batch, d) — QJL sign bits, int8 {+1, -1}
residual_norms: np.ndarray # scalar or (batch,) — ||residual||_2
bit_width: int # total bits per coordinate


class TurboQuant:
@@ -54,13 +54,19 @@ def __init__(self, d: int, bit_width: int, seed: int = 42, norm_correction: bool
self.d = d
self.bit_width = bit_width

# Spawn independent child seeds from a SeedSequence so PolarQuant and QJL
# use statistically independent random streams without magic offset arithmetic.
# Accept either an int or an already-created SeedSequence (e.g. from a parent spawner).
ss = seed if isinstance(seed, np.random.SeedSequence) else np.random.SeedSequence(seed)
pq_child, qjl_child = ss.spawn(2)

# Stage 1: PolarQuant at (b-1) bits
self.polar_quant = PolarQuant(
d, bit_width=bit_width - 1, seed=seed, norm_correction=norm_correction,
d, bit_width=bit_width - 1, seed=pq_child, norm_correction=norm_correction,
)

# Stage 2: QJL for residual (uses different seed)
self.qjl = QJL(d, seed=seed + 1000)
# Stage 2: QJL for residual (independent seed stream)
self.qjl = QJL(d, seed=qjl_child)

def quantize(self, x: np.ndarray) -> CompressedVector:
"""Quantize a vector or batch.
@@ -108,10 +114,11 @@ def compressed_size_bits(self, n_vectors: int) -> int:
Includes:
- PolarQuant indices: (b-1) bits per coordinate per vector
- QJL signs: 1 bit per coordinate per vector
- Residual norms: 32 bits (float32) per vector
- Vector norms (||x||_2): 32 bits (float32) per vector
- Residual norms (||residual||_2): 32 bits (float32) per vector
"""
per_vector = self.d * self.bit_width # (b-1) + 1 bits per coordinate
norms = 32 # float32 per vector
norms = 64 # two float32 norms per vector (vector_norm + residual_norm)
return n_vectors * (per_vector + norms)

def compression_ratio(self, original_bits_per_value: int = 16) -> float:
@@ -124,7 +131,8 @@ def compression_ratio(self, original_bits_per_value: int = 16) -> float:
Compression ratio (e.g., 4.0 means 4× smaller).
"""
original_per_vector = self.d * original_bits_per_value
compressed_per_vector = self.d * self.bit_width + 32 # +32 for norm
# +64 = two float32 norms (||x||_2 and ||residual||_2)
compressed_per_vector = self.d * self.bit_width + 64
return original_per_vector / compressed_per_vector


@@ -148,3 +156,14 @@ def quantize(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:

def dequantize(self, indices: np.ndarray, norms: np.ndarray) -> np.ndarray:
return self.polar_quant.dequantize(indices, norms)

def compressed_size_bits(self, n_vectors: int) -> int:
"""Compute total storage in bits for n_vectors compressed vectors.

Includes:
- PolarQuant indices: b bits per coordinate per vector
- Norms: 32 bits (float32) per vector (stored for per-vector rescaling)
"""
per_vector = self.d * self.bit_width
norms = 32 # float32 per vector
return n_vectors * (per_vector + norms)
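
The seed-spawning pattern now shared by KVCacheCompressor and TurboQuant, shown in isolation; this is plain numpy API, nothing library-specific.

import numpy as np

ss = np.random.SeedSequence(42)
pq_child, qjl_child = ss.spawn(2)        # reproducible, statistically independent children
rng_a = np.random.default_rng(pq_child)  # default_rng accepts a SeedSequence directly
rng_b = np.random.default_rng(qjl_child)
# The same parent seed always yields the same children, but the two streams are
# guaranteed not to overlap, unlike ad-hoc offsets such as seed + 500 or seed + 1000.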