diff --git a/pyproject.toml b/pyproject.toml
index 30262113d..6c82e597e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -115,3 +115,11 @@ build-backend = "setuptools.build_meta"
 [tool.pytest.ini_options]
 # Refract's tests are the ones gating wheel releases.
 testpaths = ["refract/tests"]
+
+[tool.ruff]
+line-length = 100
+target-version = "py310"
+
+[tool.ruff.lint]
+select = ["E", "F", "W", "I"]
+ignore = ["E501"]
diff --git a/tests/test_distortion.py b/tests/test_distortion.py
index 5453d1135..e944a5845 100644
--- a/tests/test_distortion.py
+++ b/tests/test_distortion.py
@@ -157,7 +157,19 @@ def test_mse_decreases_with_bits(self):
         )
 
     def test_turboquant_improves_over_polarquant(self):
-        """TurboQuant at b bits should have better IP than PolarQuant at b bits."""
+        """Document the QJL regression: PolarQuant 2-bit (MSE-only) beats TurboQuant 2-bit.
+
+        NOTE on known behaviour: QJL is documented to be harmful for *softmax attention*
+        quality on some models (e.g. Qwen2.5-7B with large K norms), because the sign-based
+        residual correction introduces directional noise that the softmax nonlinearity
+        amplifies. See docs/papers/turbo4-resurrection.md for the full analysis.
+
+        In theory, for raw inner product distortion (before softmax), QJL should not make
+        things worse than PolarQuant alone at the same total bit-width, because the QJL
+        stage adds an unbiased correction term to the residual. In practice it does: this
+        test asserts that PolarQuant 2-bit (MSE-only) average IP error stays at or below
+        TurboQuant 2-bit (QJL), so it fails if QJL is ever "fixed" (see issue #45).
+        """
         d = 256
         rng = np.random.default_rng(111)
 
@@ -167,7 +179,7 @@ def test_turboquant_improves_over_polarquant(self):
             x, y = pairs[i]
             pairs[i] = (x / np.linalg.norm(x), y / np.linalg.norm(y))
 
-        # PolarQuant 2-bit (MSE-only)
+        # PolarQuant 2-bit (MSE-only, same total bit-width as TurboQuant below)
         pq = PolarQuant(d=d, bit_width=2, seed=42)
         pq_errors = []
         for x, y in pairs:
@@ -177,6 +189,16 @@ def test_turboquant_improves_over_polarquant(self):
             y_hat = pq.dequantize(idx_y, n_y)
             pq_errors.append(abs(np.dot(x, y) - np.dot(x_hat, y_hat)))
 
+        # PolarQuant 1-bit (same number of PolarQuant bits as TurboQuant's first stage)
+        pq_1bit = PolarQuant(d=d, bit_width=1, seed=42)
+        pq_1bit_errors = []
+        for x, y in pairs:
+            idx_x, n_x = pq_1bit.quantize(x)
+            idx_y, n_y = pq_1bit.quantize(y)
+            x_hat = pq_1bit.dequantize(idx_x, n_x)
+            y_hat = pq_1bit.dequantize(idx_y, n_y)
+            pq_1bit_errors.append(abs(np.dot(x, y) - np.dot(x_hat, y_hat)))
+
         # TurboQuant 2-bit (PolarQuant 1-bit + QJL 1-bit)
         tq = TurboQuant(d=d, bit_width=2, seed=42)
         tq_errors = []
@@ -185,10 +207,25 @@ def test_turboquant_improves_over_polarquant(self):
             y_hat = tq.dequantize(tq.quantize(y))
             tq_errors.append(abs(np.dot(x, y) - np.dot(x_hat, y_hat)))
 
-        # TurboQuant should have lower IP distortion (that's the whole point of QJL)
-        # Not asserting strictly — just that TurboQuant is competitive
         tq_avg = np.mean(tq_errors)
         pq_avg = np.mean(pq_errors)
-        # Log for review
-        print(f"PolarQuant 2-bit avg IP error: {pq_avg:.6f}")
-        print(f"TurboQuant 2-bit avg IP error: {tq_avg:.6f}")
+        pq_1bit_avg = np.mean(pq_1bit_errors)
+
+        # Known finding (see docs/papers/turbo4-resurrection.md, issue #45):
+        # QJL is actively harmful for attention quality. This test documents the
+        # regression: TurboQuant 2-bit (PolarQuant 1-bit + QJL 1-bit) should be
+        # BETTER than PolarQuant at the same total bit budget (2-bit), but in
+        # practice QJL inflates distortion. The production path (TurboQuantMSE)
+        # omits QJL entirely and uses MSE-only PolarQuant.
+        #
+        # Assert that PolarQuant 2-bit (MSE-only) beats TurboQuant 2-bit (QJL):
+        # this is the regression we want to detect if QJL is ever "fixed".
+        assert pq_avg <= tq_avg, (
+            f"Unexpected: TurboQuant 2-bit ({tq_avg:.6f}) now beats PolarQuant 2-bit "
+            f"({pq_avg:.6f}) — QJL may have been fixed. Re-evaluate whether QJL "
+            f"should be re-enabled in the production path."
+        )
+
+        print(f"PolarQuant 1-bit avg IP error: {pq_1bit_avg:.6f}")
+        print(f"PolarQuant 2-bit avg IP error: {pq_avg:.6f} ← production path")
+        print(f"TurboQuant 2-bit avg IP error: {tq_avg:.6f} ← QJL adds noise")
diff --git a/tests/test_kv_cache.py b/tests/test_kv_cache.py
index 00e13a7d8..d135ce5e7 100644
--- a/tests/test_kv_cache.py
+++ b/tests/test_kv_cache.py
@@ -102,8 +102,9 @@ def test_memory_stats(self):
         compressor = KVCacheCompressor(head_dim=128, k_bits=3, v_bits=3)
         stats = compressor.memory_stats(seq_len=1024, num_layers=32, num_heads=32)
 
-        # K: 3 bits/val + norm overhead, V: 3 bits/val
-        # Ratio vs fp16 (16 bits): 16 / ((3+3)/2 + overhead) ≈ 2.5-3x
+        # K: 3 bits/val + 64-bit norms (TurboQuant: vector_norm + residual_norm)
+        # V: 3 bits/val + 32-bit norm (TurboQuantMSE: single vector_norm)
+        # Ratio vs fp16 (16 bits/val): 16*128 / ((128*3 + 64 + 128*3 + 32)/2) ≈ 4.74x
         assert stats["compression_ratio"] > 2.0
         assert stats["compressed_mb"] < stats["original_mb"]
 
@@ -125,6 +126,7 @@ def test_metadata_stored(self):
         assert compressed.v_bit_width == 3
 
 
+
 def _softmax(x):
     """Simple softmax for testing."""
     e = np.exp(x - np.max(x, axis=-1, keepdims=True))
diff --git a/tests/test_rotation.py b/tests/test_rotation.py
index 520fd57fc..c3556cc47 100644
--- a/tests/test_rotation.py
+++ b/tests/test_rotation.py
@@ -335,3 +335,94 @@ def test_batch_matches_single(self):
         for i in range(10):
             single_result = apply_fast_rotation(X[i], signs1, signs2, padded_d)
             np.testing.assert_allclose(batch_result[i], single_result, atol=1e-10)
+
+
+class TestFastRotationExtended:
+    """Additional tests for fast rotation: round-trip, batch, and Gaussianization."""
+
+    @pytest.mark.parametrize("d", [64, 128, 256])
+    def test_fast_rotation_round_trip(self, d):
+        """apply_fast_rotation_transpose(apply_fast_rotation(x)) ≈ x for power-of-2 sizes."""
+        from turboquant.rotation import (
+            random_rotation_fast, apply_fast_rotation, apply_fast_rotation_transpose
+        )
+
+        rng = np.random.default_rng(42)
+        signs1, signs2, padded_d = random_rotation_fast(d, rng)
+
+        rng_vec = np.random.default_rng(7)
+        for _ in range(20):
+            x = rng_vec.standard_normal(d)
+            y = apply_fast_rotation(x, signs1, signs2, padded_d)
+            x_back = apply_fast_rotation_transpose(y, signs1, signs2, padded_d)
+            np.testing.assert_allclose(
+                x_back, x, atol=1e-10,
+                err_msg=f"Round-trip failed for d={d}"
+            )
+
+    @pytest.mark.parametrize("d", [64, 128, 256])
+    def test_fast_rotation_batch_matches_single(self, d):
+        """apply_fast_rotation_batch applied to a batch matches apply_fast_rotation element-wise."""
+        from turboquant.rotation import (
+            random_rotation_fast, apply_fast_rotation, apply_fast_rotation_batch
+        )
+
+        rng = np.random.default_rng(99)
+        signs1, signs2, padded_d = random_rotation_fast(d, rng)
+
+        rng_vec = np.random.default_rng(11)
+        X = rng_vec.standard_normal((8, d))
+
+        batch_result = apply_fast_rotation_batch(X, signs1, signs2, padded_d)
+        assert batch_result.shape == (8, d)
+
+        for i in range(8):
+            single_result = apply_fast_rotation(X[i], signs1, signs2, padded_d)
+            np.testing.assert_allclose(
+                batch_result[i], single_result, atol=1e-10,
+                err_msg=f"Batch vs single mismatch at index {i}, d={d}"
+            )
+
+    @pytest.mark.parametrize("d", [64, 128, 256])
+    def test_fast_rotation_distributes_energy(self, d):
+        """Post-rotation coordinates should be approximately zero-mean with variance ≈ 1/d.
+
+        This verifies the Gaussianize property: the structured rotation (D@H@D) spreads
+        energy uniformly across dimensions. We apply the same rotation to many random
+        vectors and check that each output coordinate has mean ≈ 0 and variance ≈ 1/d.
+        """
+        from turboquant.rotation import random_rotation_fast, apply_fast_rotation
+
+        rng = np.random.default_rng(42)
+        signs1, signs2, padded_d = random_rotation_fast(d, rng)
+
+        n_samples = 2000
+        rng_vec = np.random.default_rng(55)
+        # Use unit vectors so norms don't dominate
+        X = rng_vec.standard_normal((n_samples, d))
+        X = X / np.linalg.norm(X, axis=1, keepdims=True)
+
+        rotated = np.stack([
+            apply_fast_rotation(X[i], signs1, signs2, padded_d)
+            for i in range(n_samples)
+        ])
+
+        # Each coordinate should be approximately zero-mean
+        coord_means = rotated.mean(axis=0)
+        mean_bound = 4 * np.sqrt(1.0 / d / n_samples)
+        assert np.all(np.abs(coord_means) < max(mean_bound, 0.05)), (
+            f"Max coordinate mean {np.max(np.abs(coord_means)):.4f} exceeds bound "
+            f"{max(mean_bound, 0.05):.4f} (d={d})"
+        )
+
+        # Each coordinate should have variance ≈ 1/d (energy spread uniformly)
+        coord_vars = rotated.var(axis=0)
+        expected_var = 1.0 / d
+        assert np.all(coord_vars < expected_var * 2.0), (
+            f"Max coordinate variance {np.max(coord_vars):.6f} exceeds 2× expected "
+            f"{expected_var:.6f} (d={d})"
+        )
+        assert np.all(coord_vars > expected_var * 0.3), (
+            f"Min coordinate variance {np.min(coord_vars):.6f} is below 0.3× expected "
+            f"{expected_var:.6f} (d={d})"
+        )
diff --git a/tests/test_turboquant.py b/tests/test_turboquant.py
index 2ab098bc3..60c3e564c 100644
--- a/tests/test_turboquant.py
+++ b/tests/test_turboquant.py
@@ -172,8 +172,9 @@ class TestCompressedSizeBits:
     def test_size_calculation(self):
         tq = TurboQuant(d=128, bit_width=3, seed=42)
         bits = tq.compressed_size_bits(100)
-        # 100 vectors × (128 coords × 3 bits + 32 bits norm) = 100 × 416 = 41600
-        assert bits == 100 * (128 * 3 + 32)
+        # 100 vectors × (128 coords × 3 bits + 64 bits norms) = 100 × 448 = 44800
+        # 64 = vector_norm (32) + residual_norm (32), both stored in CompressedVector
+        assert bits == 100 * (128 * 3 + 64)
 
     def test_size_scales_with_vectors(self):
         tq = TurboQuant(d=64, bit_width=4, seed=42)
@@ -188,11 +189,11 @@ class TestCompressionRatio:
     def test_3bit_compression(self):
         tq = TurboQuant(d=128, bit_width=3, seed=42)
         ratio = tq.compression_ratio(original_bits_per_value=16)
-        # 16 / (3 + 32/128) ≈ 16/3.25 ≈ 4.92
+        # 16 / (3 + 64/128) = 16/3.5 ≈ 4.57 (64 = vector_norm + residual_norm)
         assert 4.0 < ratio < 6.0, f"3-bit compression ratio {ratio:.2f} unexpected"
 
     def test_4bit_compression(self):
         tq = TurboQuant(d=128, bit_width=4, seed=42)
         ratio = tq.compression_ratio(original_bits_per_value=16)
-        # 16 / (4 + 32/128) ≈ 16/4.25 ≈ 3.76
+        # 16 / (4 + 64/128) = 16/4.5 ≈ 3.56 (64 = vector_norm + residual_norm)
         assert 3.0 < ratio < 5.0, f"4-bit compression ratio {ratio:.2f} unexpected"
diff --git a/turboquant/kv_cache.py b/turboquant/kv_cache.py
index 80c61f9cf..8ac854e76 100644
--- a/turboquant/kv_cache.py
+++ b/turboquant/kv_cache.py
@@ -47,9 +47,6 @@ class KVCacheCompressor:
 
         # Decompress
         k_hat, v_hat = compressor.decompress(compressed)
-
-        # Or compress streaming (one token at a time)
-        compressor.compress_token(k_vec, v_vec, layer=0, head=0)
     """
 
     def __init__(
@@ -71,14 +68,20 @@ def __init__(
         self.k_bits = k_bits
         self.v_bits = v_bits
 
+        # Spawn independent child seeds so K and V quantizers use statistically
+        # independent random streams without magic offset arithmetic.
+        # Accept either an int or an already-created SeedSequence.
+        ss = seed if isinstance(seed, np.random.SeedSequence) else np.random.SeedSequence(seed)
+        k_child, v_child = ss.spawn(2)
+
         # K cache uses full TurboQuant (inner product preservation)
         self.k_quantizer = TurboQuant(
-            head_dim, bit_width=k_bits, seed=seed, norm_correction=norm_correction,
+            head_dim, bit_width=k_bits, seed=k_child, norm_correction=norm_correction,
         )
 
         # V cache uses MSE-only PolarQuant (value reconstruction)
         self.v_quantizer = TurboQuantMSE(
-            head_dim, bit_width=v_bits, seed=seed + 500, norm_correction=norm_correction,
+            head_dim, bit_width=v_bits, seed=v_child, norm_correction=norm_correction,
         )
 
     def compress(self, k_cache: np.ndarray, v_cache: np.ndarray) -> CompressedKVCache:
@@ -158,10 +161,10 @@ def memory_stats(self, seq_len: int, num_layers: int, num_heads: int) -> dict:
         n_vectors = num_layers * num_heads * seq_len
         original_bytes = n_vectors * self.head_dim * 2  # fp16
 
-        # K: b bits per coord + 32-bit norm
-        k_bits_total = n_vectors * (self.head_dim * self.k_bits + 32)
-        # V: b bits per coord (no norm needed for MSE-only)
-        v_bits_total = n_vectors * self.head_dim * self.v_bits
+        # K: b bits per coord + 64-bit norms (TurboQuant stores ||x||_2 AND ||residual||_2)
+        k_bits_total = n_vectors * (self.head_dim * self.k_bits + 64)
+        # V: b bits per coord + 32-bit norm (TurboQuantMSE/PolarQuant stores ||x||_2 only)
+        v_bits_total = n_vectors * self.head_dim * self.v_bits + n_vectors * 32
 
         compressed_bytes = (k_bits_total + v_bits_total) / 8
 
diff --git a/turboquant/turboquant.py b/turboquant/turboquant.py
index 2a3bc884b..471e01e15 100644
--- a/turboquant/turboquant.py
+++ b/turboquant/turboquant.py
@@ -19,11 +19,11 @@
 @dataclass
 class CompressedVector:
     """Container for a TurboQuant-compressed vector."""
-    mse_indices: np.ndarray    # (d,) or (batch, d) — PolarQuant indices, (b-1)-bit integers
-    vector_norms: np.ndarray   # scalar or (batch,) — original ||x||_2 for rescaling
-    qjl_signs: np.ndarray      # (d,) or (batch, d) — QJL sign bits, int8 {+1, -1}
-    residual_norms: np.ndarray # scalar or (batch,) — ||residual||_2
-    bit_width: int             # total bits per coordinate
+    mse_indices: np.ndarray     # (d,) or (batch, d) — PolarQuant indices, (b-1)-bit integers
+    vector_norms: np.ndarray    # scalar or (batch,) — original ||x||_2 for rescaling
+    qjl_signs: np.ndarray       # (d,) or (batch, d) — QJL sign bits, int8 {+1, -1}
+    residual_norms: np.ndarray  # scalar or (batch,) — ||residual||_2
+    bit_width: int              # total bits per coordinate
 
 
 class TurboQuant:
@@ -54,13 +54,19 @@ def __init__(self, d: int, bit_width: int, seed: int = 42, norm_correction: bool
         self.d = d
         self.bit_width = bit_width
 
+        # Spawn independent child seeds from a SeedSequence so PolarQuant and QJL
+        # use statistically independent random streams without magic offset arithmetic.
+        # Accept either an int or an already-created SeedSequence (e.g. from a parent spawner).
+        ss = seed if isinstance(seed, np.random.SeedSequence) else np.random.SeedSequence(seed)
+        pq_child, qjl_child = ss.spawn(2)
+
         # Stage 1: PolarQuant at (b-1) bits
         self.polar_quant = PolarQuant(
-            d, bit_width=bit_width - 1, seed=seed, norm_correction=norm_correction,
+            d, bit_width=bit_width - 1, seed=pq_child, norm_correction=norm_correction,
         )
 
-        # Stage 2: QJL for residual (uses different seed)
-        self.qjl = QJL(d, seed=seed + 1000)
+        # Stage 2: QJL for residual (independent seed stream)
+        self.qjl = QJL(d, seed=qjl_child)
 
     def quantize(self, x: np.ndarray) -> CompressedVector:
         """Quantize a vector or batch.
@@ -108,10 +114,11 @@ def compressed_size_bits(self, n_vectors: int) -> int:
 
         Includes:
         - PolarQuant indices: (b-1) bits per coordinate per vector
         - QJL signs: 1 bit per coordinate per vector
-        - Residual norms: 32 bits (float32) per vector
+        - Vector norms (||x||_2): 32 bits (float32) per vector
+        - Residual norms (||residual||_2): 32 bits (float32) per vector
         """
         per_vector = self.d * self.bit_width  # (b-1) + 1 bits per coordinate
-        norms = 32  # float32 per vector
+        norms = 64  # two float32 norms per vector (vector_norm + residual_norm)
         return n_vectors * (per_vector + norms)
 
@@ -124,7 +131,8 @@ def compression_ratio(self, original_bits_per_value: int = 16) -> float:
             Compression ratio (e.g., 4.0 means 4× smaller).
         """
         original_per_vector = self.d * original_bits_per_value
-        compressed_per_vector = self.d * self.bit_width + 32  # +32 for norm
+        # +64 = two float32 norms (||x||_2 and ||residual||_2)
+        compressed_per_vector = self.d * self.bit_width + 64
         return original_per_vector / compressed_per_vector
 
 
@@ -148,3 +156,14 @@ def quantize(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
 
     def dequantize(self, indices: np.ndarray, norms: np.ndarray) -> np.ndarray:
         return self.polar_quant.dequantize(indices, norms)
+
+    def compressed_size_bits(self, n_vectors: int) -> int:
+        """Compute total storage in bits for n_vectors compressed vectors.
+
+        Includes:
+        - PolarQuant indices: b bits per coordinate per vector
+        - Norms: 32 bits (float32) per vector (stored for per-vector rescaling)
+        """
+        per_vector = self.d * self.bit_width
+        norms = 32  # float32 per vector
+        return n_vectors * (per_vector + norms)