diff --git a/.claude/AMX_GOTCHAS.md b/.claude/AMX_GOTCHAS.md index 79cd0110..1776104e 100644 --- a/.claude/AMX_GOTCHAS.md +++ b/.claude/AMX_GOTCHAS.md @@ -43,6 +43,7 @@ int8 2048³ = 169.7 GMAC/s, 600× scalar, single-thread | runs fine, `correct=false` | operand index/sign convention mirrored | Gotcha 12 | | compile error `unstable x86_amx_intrinsics` | used nightly intrinsics | Gotcha 1, 8 | | compile error `rbx is used internally by LLVM` | inline-asm CPUID | Gotcha 3 | +| exact when idle, silently wrong under CPU load (VM) | tile state lost across host vCPU switch | Gotcha 14 | --- @@ -222,6 +223,50 @@ once is correct even under rayon. `cpu_model()` is cached the same way. --- +## Gotcha 14: on oversubscribed VMs, tile state is silently corrupted under host CPU contention ⚑ + +Observed 2026-07-02 on this remote VM (4 vCPU, EMR-class Xeon, guest kernel +6.18.5) by `examples/onebrc_cascade_probe.rs`, reproduced on demand: + +``` +idle: 413/413 stations bit-exact (10M and 100M rows) +4 busy-loop competitors: 89/413, 152/413 exact — whole rows LOST, no fault +probe pinned to core 0, +load pinned to cores 1-3: 124/413 exact — pinning does NOT mitigate +idle control right after: 413/413 exact again +``` + +Signature: **no crash, no SIGSEGV/SIGILL — results are silently wrong, and +only under load.** An AVX-512 path in the same process, same run, stays +bit-exact, isolating the corruption to TMM tile state (the tmm0 accumulator +loses in-flight partial sums). Because guest-side pinning doesn't help, the +suspected mechanism is the **host** hypervisor's vCPU context switch failing +to save/restore guest `XTILEDATA` when the host multiplexes oversubscribed +pCPUs (idle guests keep their vCPUs resident → no corruption; loaded guests +get switched → corruption). Guest-side `arch_prctl` permission (Gotcha 4) is +correctly granted — this is a layer below the guest kernel. 
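The reproduction above (run the kernel while a few busy loops compete for CPU, then compare against an idle run) can be sketched as a small harness. This is a minimal sketch, not the probe's actual code: `run_under_contention` and `stand_in_kernel` are hypothetical names, and the stand-in kernel is a plain integer group-by sum rather than a real tile GEMM, so on its own it only demonstrates the harness shape.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;

/// Runs `kernel` on the calling thread while `competitors` busy-loop threads
/// compete for CPU time, then stops the competitors and returns the result.
/// (Hypothetical helper, not part of the repository.)
fn run_under_contention<R>(competitors: usize, kernel: impl FnOnce() -> R) -> R {
    let stop = Arc::new(AtomicBool::new(false));
    let handles: Vec<_> = (0..competitors)
        .map(|_| {
            let stop = Arc::clone(&stop);
            thread::spawn(move || {
                // Pure spin with no sleeping, so the scheduler sees real load.
                let mut x = 0u64;
                while !stop.load(Ordering::Relaxed) {
                    x = std::hint::black_box(x.wrapping_mul(6364136223846793005).wrapping_add(1));
                }
                x
            })
        })
        .collect();
    let out = kernel();
    stop.store(true, Ordering::Relaxed);
    for h in handles {
        let _ = h.join();
    }
    out
}

/// Hypothetical stand-in for the kernel under test: exact per-station sums.
/// A real parity test would call the tile kernel here instead.
fn stand_in_kernel(rows: &[(usize, i64)], n_stations: usize) -> Vec<i64> {
    let mut sums = vec![0i64; n_stations];
    for &(sid, t) in rows {
        sums[sid] += t;
    }
    sums
}

fn main() {
    let n_stations = 413;
    let rows: Vec<(usize, i64)> = (0..100_000usize)
        .map(|i| (i % n_stations, (i as i64 % 1999) - 999))
        .collect();

    let idle = stand_in_kernel(&rows, n_stations);
    let loaded = run_under_contention(4, || stand_in_kernel(&rows, n_stations));

    // Per-station comparison: any mismatch under load, with an exact idle
    // run, matches the signature described above.
    let exact = idle.iter().zip(&loaded).filter(|(a, b)| a == b).count();
    println!("{exact}/{n_stations} stations exact under load");
    assert_eq!(exact, n_stations);
}
```

Comparing the loaded run against an idle run of the same kernel (and re-running idle afterwards as a control) separates load-dependent corruption from an ordinary kernel bug, which would fail in both conditions.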
+ +Consequences: + +- **Never certify AMX numerics from a shared/oversubscribed VM.** Bare metal + or a dedicated-CPU instance only. A "PASS on my cloud box" is worthless + under this gotcha unless the box was provably idle. +- **Extend Gotcha 9's discipline**: a parity test for a tile kernel must ALSO + run under deliberate CPU contention (a few busy loops are enough — see the + reproduction above). Exact-when-idle is necessary, not sufficient. +- **Keep tile residency short.** Long accumulation loops that live in tmm + across many iterations (the 16×16×K pattern holds tmm0 for K/32 iterations) + maximize the exposure window. Draining accumulators to memory more often + shrinks it but does NOT close it — treat it as harm reduction, not a fix. +- Production dispatch on virtualized hosts should either avoid AMX or pair it + with a checksum/parity channel (e.g. a redundant ones-row whose expected + value is known — the onebrc probe's count row doubles as exactly that). + +Fault signature: `correct=true` in every quiet test, sporadic wrong results +in production under load, AVX-512 siblings unaffected. + +--- + ## Hardware tiers ``` diff --git a/.claude/blackboard.md b/.claude/blackboard.md index 7bafad5f..61d0e1fc 100644 --- a/.claude/blackboard.md +++ b/.claude/blackboard.md @@ -3,6 +3,114 @@ > **Read this first.** The "Polyglot Notebook" architecture below is a > separate/older program, not the current epoch. +## 2026-07-02 (later) — bf16 tile GEMM: VDPBF16PS middle tier + PackedBf16B (loose end closed) + +Closed the [LOOSE END] from the 1BRC entry below. `hpc/bf16_tile_gemm.rs` +is now a three-tier ladder — **AMX TDPBF16PS → AVX-512 VDPBF16PS → +decode+FMA polyfill** — with the polyfill kernel (`simd_ops.rs`) untouched: + +- **VDPBF16PS tier** (`avx512bf16_path`, private): bf16 pairs multiplied + natively per zmm (no bf16→f32 decode), f32 lane accumulators, SAME VNNI + operand layout as the AMX tile → one packed buffer serves both tile + tiers. 
`_mm512_dpbf16_ps` verified stable on Rust 1.94. Runtime + `is_x86_feature_detected!("avx512bf16")` (EMR box has it). +- **`PackedBf16B`** + **`bf16_tile_gemm_16x16_packed`**: VNNI pack (and + its per-call allocation) hoisted out of hot loops; `vnni_index(row,col) + = (row/2)·32 + 2·col + (row&1)` supports staging B DIRECTLY in VNNI + layout (zero pack cost — the right shape for one-hot/sparse staging). +- **`bf16_tile_gemm_tier()`**: names the tier that will run (Gotcha 9 + reporting). Re-exports via `ndarray::simd::*` (W1a surface). +- **Exactness boundary preserved (operator condition):** bit-exact across + ALL tiers for bf16-exact integer operands with accumulation < 2^24 — + asserted with `assert_eq!` in the new parity tests (vnni_index vs + vnni_pack_bf16; packed==unpacked==i64 reference; VDPBF16PS exact + + tolerance-parity vs polyfill on floats; accumulate semantics). Gotcha-14 + contention parity test included as `#[ignore]` (fails on oversubscribed + VMs BY DESIGN; run `--ignored` on dedicated silicon). + +[MEASURED] onebrc probe GEMM leg with direct-VNNI staging: **3.6 → 21.3 +Mrows/s (5.9×), 23.7 → 141.9 GMAC/s** (single thread — near the 169.7 +GMAC/s int8 AMX anchor in AMX_GOTCHAS). 413/413 stations still EXACT; +8/8 lib tests + 2 doctests green; clippy/fmt clean. + +[NOTE] Dispatch-behavior change signed off by operator: the row-major +entry `bf16_tile_gemm_16x16` now routes avx512bf16-without-AMX hosts +through VDPBF16PS instead of decode+FMA (bit-exact within the integer +boundary; BF16-precision-class accumulation-order differences on general +floats, same as any tier change). + +[ADDED, same day] **LE byte contract on `PackedBf16B`** (operator "Go" — +first brick of the SoA-Morton batch-writer / write-hiding design): +`as_le_bytes()` (zero-cost reinterpret; LE by construction — the module +is x86_64-only) + `from_le_bytes()` (endian-correct anywhere, plain copy +on LE). 
This is the persistence/mailbox face per lance-graph's +SoaEnvelope discipline (envelope bytes LE from creation to tombstone). +Test `le_byte_view_roundtrips_and_is_truly_le` asserts byte 2i = low +byte of lane i AND that a GEMM over the roundtripped buffer stays +bit-exact. 9/9 lib tests green. Next bricks (lance-graph side): batch +writer flushing tile buffers as envelope tenants; write-hiding = stage +morsel N+1's VNNI writes while morsel N's tiles compute. + +## 2026-07-02 — 1BRC-on-substrate probe (`examples/onebrc_cascade_probe.rs`) + +1BRC workload (min/mean/max per station) restated on the substrate, as a +sibling of `morton_cascade_probe`. Branch `claude/1brc-lance-graph-xfx5tu`. +Three paths certified bit-for-bit against a scalar integer reference +(413 stations, integer tenths → exact in f32/f64 by construction): + +- **Morton scatter**: stations minted as cells on a 64×64 Morton grid + (4×4 tile = one F32x16), morsel-batched (64K rows) scatter into + L1-resident SoA accumulators, (min,max,Σ,n) monoid fold. +- **AMX BF16 tile-GEMM group-by**: (Σ,n) as `C += A[16×K]·B[K×16]` via + the NEW `ndarray::simd::bf16_tile_gemm_16x16_amx` re-export (W1a: the + AMX-dispatching hpc wrapper surfaced through the canonical polyfill, + same pattern as `matmul_i8_to_i32`; the `_amx` suffix disambiguates + from the pure-FMA `simd::bf16_tile_gemm_16x16`) — B = per-row one-hot + station indicator (26 column-blocks of 16), A rows = {1, hi(t), lo(t), + bf16-RNE(t)} with the exactness split `hi=(t/256)·256, lo=t−hi` (both + bf16-exact; f32 tile accumulation exact for K=4096). Clear-by-undo + keeps B staging O(rows). AMX **actually ran** (amx_available()==true + printed per Gotcha 9 discipline; EMR-class Xeon, kernel 6.18.5). +- **Aggregate pyramid** over the tile grid: hierarchical (min,mean,max) + per tile/region/root in the same pass + band-prune queries + (Belichtungsmesser on the MIN channel). 
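The hi/lo exactness split used by the tile-GEMM path can be checked exhaustively over the whole temperature domain. The sketch below is illustrative only: `is_bf16_exact` and `split_hi_lo` are hypothetical helper names, and bf16-exactness is tested the same way the probe's `bf16_exact` asserts it, by requiring the low 16 bits of the f32 encoding to be zero.

```rust
/// True if `v` survives f32 -> bf16 truncation unchanged, i.e. the low
/// 16 bits of its f32 encoding are zero. (Hypothetical helper.)
fn is_bf16_exact(v: f32) -> bool {
    v.to_bits() & 0xFFFF == 0
}

/// Splits integer tenths into `hi` (a multiple of 256) and `lo = t - hi`.
/// Rust integer division truncates toward zero, so |lo| <= 255 for any sign.
fn split_hi_lo(t: i32) -> (i32, i32) {
    let hi = (t / 256) * 256;
    (hi, t - hi)
}

fn main() {
    let mut inexact_without_split = 0usize;
    for t in -999..=999 {
        let (hi, lo) = split_hi_lo(t);
        assert_eq!(hi + lo, t, "split must reconstruct t");
        assert!(lo.abs() <= 255);
        assert!(is_bf16_exact(hi as f32), "hi not bf16-exact for t={t}");
        assert!(is_bf16_exact(lo as f32), "lo not bf16-exact for t={t}");
        if !is_bf16_exact(t as f32) {
            inexact_without_split += 1;
        }
    }
    // f32 accumulation of integers stays exact below 2^24; the worst-case
    // magnitude of one K = 4096 sub-morsel sum is well inside that bound.
    assert!(4096 * 999 < (1 << 24));
    println!("{inexact_without_split} of 1999 tenths values need the split to be bf16-exact");
}
```

The split costs nothing extra in the tile because the `hi` and `lo` operands occupy A rows that would otherwise be zero.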
+ +[MEASURED] 10M rows, 4-core Xeon EMR VM, single thread: +reference 453 Mrows/s | morton scatter 443 Mrows/s (**substrate tax ≈ 2%**) +| tile-GEMM 3.6 Mrows/s = 23.7 GMAC/s (dense one-hot indicator = the +honest price of group-by-as-matmul; per-call `vnni_pack_bf16` alloc in +`bf16_tile_gemm_16x16` is a visible overhead) | pyramid fold 0.02 ms | +band query prune 90.2%. All 413 stations EXACT on both paths; PASS. +Also EXACT at 100M rows (idle). **"Is BF16 precise enough?" — measured:** +the naive bf16-RNE row through the same tile gives max per-station +|Δmean| = 0.0123 tenths (0.0012 °C, N≈24k/station — quantization bias +averages out); single readings off by ≤ 2 tenths (half-ulp of bf16 at +|t|∈[512,1024)). Verdict: bf16-direct fine for means, hi/lo split (free — +spare A rows) required for min/max + exactness certification. + +[FINDING → **Gotcha 14**, `.claude/AMX_GOTCHAS.md`] On this oversubscribed +VM, **AMX tile state silently corrupts under host CPU contention**: idle += 413/413 exact at 100M rows; with 4 busy-loop competitors = 89-152/413 +(whole rows lost, no fault); guest-side core pinning does NOT mitigate +(124/413); AVX-512 scatter path in the same run stays exact → isolated +to TMM state; suspected host-vCPU-switch XTILEDATA loss. Consequences +written into the gotcha: never certify AMX numerics on shared VMs; parity +tests must also run under deliberate load (Gotcha 9 extension); short +tile residency = harm reduction only. + +[CROSS-REPO] Algebraic certification (partition/regroup invariance of the +monoid fold, bf16 hi/lo decomposition exactness) lands as a diagnostic +probe in `lance-graph/crates/jc` (`onebrc_agg`) — kernels here, proof +there, per the architecture rule (ndarray = hardware, jc = proof). + +[LOOSE END] AMX has no min/max tile op → min/max stay on the scatter +path by construction. 
`bf16_tile_gemm_16x16` allocates + VNNI-packs B on +every call — a pre-packed-B variant would lift the GEMM leg +substantially; file under W1-adjacent if the group-by-as-GEMM shape +recurs. Text-ingest leg (SWAR/SIMD parse of the 13 GB file) deliberately +NOT probed here — separate probe if pursued (would exercise +`byte_scan.rs`). + ## 2026-06-28 — WASM SIMD128 backend filled in (`src/simd_wasm.rs`) Replaced the commented-out scaffolding in `src/simd_wasm.rs` with a real diff --git a/Cargo.toml b/Cargo.toml index 8098710b..2a0b660a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,12 @@ required-features = ["std"] name = "morton_cascade_probe" required-features = ["std"] +# 1BRC-on-substrate probe: Morton group-by + AMX BF16 tile-GEMM mean path +# (imports `ndarray::simd` + `ndarray::hpc`, both std-gated). +[[example]] +name = "onebrc_cascade_probe" +required-features = ["std"] + [[example]] name = "golden_helix_probe" required-features = ["std"] diff --git a/examples/onebrc_cascade_probe.rs b/examples/onebrc_cascade_probe.rs new file mode 100644 index 00000000..1ee9c79b --- /dev/null +++ b/examples/onebrc_cascade_probe.rs @@ -0,0 +1,570 @@ +//! 1BRC-on-substrate probe — min/mean/max per station as a monoid group-by +//! over the gridlake Morton substrate, with an AMX BF16 tile-GEMM mean path +//! and a Belichtungsmesser band-prune demo on the aggregate pyramid. +//! +//! ## What this validates (probe-first, sibling of `morton_cascade_probe`) +//! +//! The One Billion Row Challenge workload (station;temperature → min/mean/max +//! per station) restated on this substrate: +//! +//! - **Stations are addresses, not hash keys.** Each station is minted a cell +//! on a `(4T)²` Morton grid (4×4 leaf tile = one `F32x16` = 64 bytes), so +//! the per-station accumulators live at canonical Z-order addresses in a +//! `MultiLaneColumn`-backed SoA — the gridlake carrier, not a hashmap. +//! - **Aggregation is a commutative monoid fold** `(min, max, Σ, n)` — the +//! 
same algebra blasgraph's semiring folds rely on. Morsel-batched scatter +//! (64K rows/morsel) into L1-resident accumulators is the "blasgraph-like +//! cache algorithm": the hot set is `n_stations × 24 B ≈ 10 KB`, so the +//! scatter never leaves L1; the Morton layout makes the pyramid fold +//! cache-oblivious. +//! - **Group-by as tile GEMM (the AMX leg).** `(Σ, n)` per station is also a +//! matmul: `C[16×16] += A[16×K] · B[K×16]` with B a per-row one-hot station +//! indicator (stations in column-blocks of 16) and three live A rows — +//! row 0 = 1 (count), row 1 = hi(temp), row 2 = lo(temp). BF16 has an 8-bit +//! significand, so integer tenths in [-999, 999] are NOT bf16-exact; the +//! split `hi = (t/256)·256, lo = t − hi` makes every operand exact in BF16 +//! (hi ∈ {0, ±256, ±512, ±768}, |lo| ≤ 255 < 2^8), and per-tile f32 +//! accumulation stays exact (≤ K·999 < 2^24 for K = 4096). A-row 3 carries +//! naive bf16-RNE temps through the SAME tile — the "is BF16 precise +//! enough?" experiment, measured instead of argued (the extra row is free). +//! Runs through `ndarray::simd::bf16_tile_gemm_16x16_amx` (TDPBF16PS when +//! `amx_available()`, the F32x16 FMA polyfill otherwise) — all imports via +//! the canonical `ndarray::simd::*` surface, per the W1a consumer contract. +//! Per AMX Gotcha 9 ("a skipped test is not a passing test") the probe +//! PRINTS which tier actually ran. min/max stay on the scatter path — AMX +//! has no min/max tile op (TDPBF16PS/TDPBUSD are dot-product accumulates +//! only). +//! +//! ⚠ Gotcha 14 (DISCOVERED BY THIS PROBE, 2026-07-02): on this oversubscribed +//! VM, AMX tile state is silently corrupted under host CPU contention — +//! idle runs are bit-exact at 100M rows; with 4 busy-loop competitors the +//! GEMM leg drops whole rows (89-152/413 stations exact), and pinning the +//! probe to an uncontended core does NOT help. The scatter path (AVX-512) +//! 
stays exact under the same load, isolating the corruption to TMM state. +//! A FAIL of the GEMM leg on a loaded box is the probe working as designed. +//! See `.claude/AMX_GOTCHAS.md` § Gotcha 14. +//! - **The cascade is a reduction pyramid here, not a search cascade.** Every +//! row must be touched (nothing to skip on input); what the pyramid buys is +//! free *hierarchical* aggregates — min/mean/max per tile / region / root in +//! the same pass — plus band-pruned queries over the result ("which stations +//! have min ≤ q" visits only intersecting subtrees). +//! - **Exactness is provable, not approximate.** Temperatures are integer +//! tenths in [-999, 999] → exact in `f32`; sums of integer tenths stay +//! < 2^53 → exact in `f64`. Both substrate paths must match the scalar +//! reference bit-for-bit. The algebraic side (partition/regroup invariance +//! of the monoid fold, the BF16 hi/lo decomposition) is certified +//! independently in `lance-graph/crates/jc` (`onebrc_agg` probe) — +//! kernels here, proof there. +//! +//! cargo run --release --example onebrc_cascade_probe +//! ONEBRC_ROWS=100000000 cargo run --release --example onebrc_cascade_probe +//! +//! PASS: both substrate paths (Morton scatter; AMX/BF16 tile GEMM) match the +//! scalar reference exactly for every station, root invariants hold, and the +//! band-prune query returns the brute-force station set. Throughput and +//! prune-rate lines are the measured "boost". 
+ +use std::sync::Arc; +use std::time::Instant; + +use ndarray::simd::{bf16_tile_gemm_16x16_packed, bf16_tile_gemm_tier, F64x8, MultiLaneColumn, PackedBf16B}; + +// ── Deterministic RNG (same SplitMix64 as jc / hpc::pillar) ───────────────── + +const SEED: u64 = 0x1BC_0FFEE; + +fn splitmix64(state: &mut u64) -> u64 { + *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15); + let mut z = *state; + z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + z ^ (z >> 31) +} + +// ── Morton addressing (identical to morton_cascade_probe) ─────────────────── + +fn morton2d(x: u32, y: u32, bits: u32) -> u32 { + let mut m = 0u32; + for b in 0..bits { + m |= ((x >> b) & 1) << (2 * b); + m |= ((y >> b) & 1) << (2 * b + 1); + } + m +} + +/// cell (x,y) on a (4T)² grid → flat Morton index with 4×4-tile granularity: +/// tile (x>>2, y>>2) ordered by morton2d, 16 cells per tile ordered by +/// morton2d(x&3, y&3, 2). Each tile is one contiguous 16-lane chunk. 
+fn cell_index(x: u32, y: u32, k: u32) -> usize { + let (tx, ty) = (x >> 2, y >> 2); + let (ix, iy) = (x & 3, y & 3); + (morton2d(tx, ty, k) as usize) * 16 + morton2d(ix, iy, 2) as usize +} + +// ── Workload parameters ───────────────────────────────────────────────────── + +const T: u32 = 16; // tiles per side → grid 64×64 = 4096 cells, 256 tiles +const N_STATIONS: usize = 413; // the classic 1BRC station count +const MORSEL: usize = 1 << 16; // 64K rows per batch (scatter path) +const GEMM_K: usize = 4096; // tile-GEMM sub-morsel (multiple of 32) +const N_GROUPS: usize = N_STATIONS.div_ceil(16); // 16-station column blocks + +// ── Scalar reference (ground truth, integer domain) ───────────────────────── + +#[derive(Clone, Copy)] +struct RefAgg { + min_t: i16, // tenths + max_t: i16, + sum_t: i64, + cnt: u64, +} + +impl RefAgg { + const IDENTITY: RefAgg = RefAgg { + min_t: i16::MAX, + max_t: i16::MIN, + sum_t: 0, + cnt: 0, + }; +} + +// ── Substrate accumulators (SoA over Morton cells) ────────────────────────── + +struct MortonAgg { + min_c: Vec<f32>, // +INF identity + max_c: Vec<f32>, // -INF identity + sum_c: Vec<f64>, // integer tenths, exact while < 2^53 + cnt_c: Vec<u64>, +} + +impl MortonAgg { + fn new(n_cells: usize) -> Self { + MortonAgg { + min_c: vec![f32::INFINITY; n_cells], + max_c: vec![f32::NEG_INFINITY; n_cells], + sum_c: vec![0.0; n_cells], + cnt_c: vec![0; n_cells], + } + } +} + +/// One pyramid node: the same (min, max, Σ, n) monoid element, per subtree. +#[derive(Clone, Copy)] +struct Node { + min: f32, + max: f32, + sum: f64, + cnt: u64, +} + +/// Aggregate pyramid over the T² tiles in Morton order. Level 0 = per-tile +/// fold of 16 cells; level l = fold of 4 level-(l−1) nodes; root = global. +struct Pyramid { + levels: Vec<Vec<Node>>, + k: u32, +} + +impl Pyramid { + /// Level-0 min/max folds go through the gridlake carrier: 16 cells = + /// one 64-byte `F32x16` chunk of a `MultiLaneColumn`, reduced in-register. 
+ fn build(agg: &MortonAgg, t: u32) -> Self { + let k = t.trailing_zeros(); + let n_tiles = (t * t) as usize; + + // Wrap the min/max channels in the SoA byte carrier (LE f32 lanes). + let min_col = column_from_f32(&agg.min_c); + let max_col = column_from_f32(&agg.max_c); + + let mut lvl0 = Vec::with_capacity(n_tiles); + for (tile, (min_v, max_v)) in min_col.iter_f32x16().zip(max_col.iter_f32x16()).enumerate() { + let base = tile * 16; + let mut sum = 0.0f64; + let mut cnt = 0u64; + for c in 0..16 { + sum += agg.sum_c[base + c]; + cnt += agg.cnt_c[base + c]; + } + // Two F64x8 loads cross-check the scalar Σ (exact: integer tenths). + let s_lo = F64x8::from_array(agg.sum_c[base..base + 8].try_into().unwrap()); + let s_hi = F64x8::from_array(agg.sum_c[base + 8..base + 16].try_into().unwrap()); + debug_assert_eq!(s_lo.reduce_sum() + s_hi.reduce_sum(), sum); + let _ = (s_lo, s_hi); + lvl0.push(Node { + min: min_v.reduce_min(), + max: max_v.reduce_max(), + sum, + cnt, + }); + } + + let mut levels = vec![lvl0]; + for l in 1..=k as usize { + let prev = &levels[l - 1]; + let mut cur = Vec::with_capacity(prev.len() / 4); + for q in prev.chunks_exact(4) { + cur.push(Node { + min: q.iter().map(|n| n.min).fold(f32::INFINITY, f32::min), + max: q.iter().map(|n| n.max).fold(f32::NEG_INFINITY, f32::max), + sum: q.iter().map(|n| n.sum).sum(), + cnt: q.iter().map(|n| n.cnt).sum(), + }); + } + levels.push(cur); + } + Pyramid { levels, k } + } + + fn root(&self) -> Node { + self.levels[self.k as usize][0] + } + + /// Band-prune descent on the MIN channel: visit only subtrees whose + /// min ≤ q; return (leaf tiles visited, matching cell indices). 
+ fn stations_with_min_le(&self, q: f32, agg: &MortonAgg) -> (usize, Vec<usize>) { + let mut visited = 0usize; + let mut hits = Vec::new(); + let mut stack = vec![(self.k as usize, 0usize)]; + while let Some((level, node)) = stack.pop() { + if self.levels[level][node].min > q { + continue; // whole subtree pruned + } + if level == 0 { + visited += 1; + let base = node * 16; + for c in 0..16 { + if agg.min_c[base + c] <= q { + hits.push(base + c); + } + } + } else { + let base = node * 4; + for c in 0..4 { + stack.push((level - 1, base + c)); + } + } + } + hits.sort_unstable(); + (visited, hits) + } +} + +fn column_from_f32(vals: &[f32]) -> MultiLaneColumn { + let raw: Vec<u8> = vals.iter().flat_map(|v| v.to_le_bytes()).collect(); + MultiLaneColumn::new(Arc::from(raw.into_boxed_slice())).unwrap() +} + +// ── AMX BF16 tile-GEMM mean path ──────────────────────────────────────────── + +/// f32 → bf16 by truncation. Every value fed through here is exactly +/// representable in BF16 (0, ±1, hi multiples of 256 ≤ 768, |lo| ≤ 255), +/// so truncation is lossless and equals round-to-nearest-even. +#[inline(always)] +fn bf16_exact(v: f32) -> u16 { + let bits = v.to_bits(); + debug_assert_eq!(bits & 0xFFFF, 0, "value not bf16-exact: {v}"); + (bits >> 16) as u16 +} + +const BF16_ONE: u16 = 0x3F80; // 1.0f32 >> 16 + +/// f32 → bf16 with round-to-nearest-even — the conversion a naive "just store +/// the temperature as BF16" pipeline would use. NOT exact for |tenths| > 255; +/// A-row 3 measures exactly how much that costs (the "is BF16 precise enough?" +/// experiment — see module doc). +#[inline(always)] +fn bf16_rne(v: f32) -> u16 { + let bits = v.to_bits(); + ((bits.wrapping_add(0x7FFF).wrapping_add((bits >> 16) & 1)) >> 16) as u16 +} + +/// Group-by as tile GEMM over one sub-morsel of `rows` (≤ GEMM_K) rows. +/// +/// A[16×K] (bf16, row-major): row 0 = 1.0 (count), row 1 = hi(temp), +/// row 2 = lo(temp), row 3 = bf16-RNE(temp) (the naive-precision experiment), +/// rows 4-15 = 0. 
B_blocks[g][K×16] (bf16, row-major): per-row one-hot +/// indicator for station group g (stations g·16 ..= g·16+15). +/// C[i][j] = Σ_r A[i][r]·B[r][j] gives per station j of group g: +/// C[0][j] = n, C[1][j] = Σhi, C[2][j] = Σlo — all exact (see module doc) — +/// and C[3][j] = Σ bf16(temp), whose deviation from Σhi+Σlo is the measured +/// cost of skipping the hi/lo split. The extra row is free: same tile, same +/// GEMM call. +/// +/// B blocks are zeroed ONCE at allocation; each sub-morsel sets exactly one +/// entry per row and clears the same entry afterwards (clear-by-undo), so the +/// per-morsel cost is O(rows), not O(rows × groups) — the same L1-resident +/// hot-set discipline as the scatter path. +struct GemmGroupBy { + a: Vec<u16>, // 16 × GEMM_K, rows 0 and 4-15 pre-set + b_blocks: Vec<PackedBf16B>, // N_GROUPS × VNNI-packed (GEMM_K × 16) + c: Vec<f32>, // 16 × 16 output tile + sum_t: Vec<i64>, // per-station Σ (tenths), drained exactly + cnt: Vec<u64>, // per-station n + sum_bf16: Vec<f64>, // per-station Σ of bf16-rounded temps (row 3) +} + +impl GemmGroupBy { + fn new() -> Self { + let mut a = vec![0u16; 16 * GEMM_K]; + a[..GEMM_K].fill(BF16_ONE); // row 0 = ones → counts + GemmGroupBy { + a, + b_blocks: (0..N_GROUPS).map(|_| PackedBf16B::zeroed(GEMM_K)).collect(), + c: vec![0.0f32; 256], + sum_t: vec![0i64; N_STATIONS], + cnt: vec![0u64; N_STATIONS], + sum_bf16: vec![0.0f64; N_STATIONS], + } + } + + fn fold_sub_morsel(&mut self, rows: &[(u16, i16)]) { + debug_assert!(rows.len() <= GEMM_K); + // Stage A rows 1-2 (hi/lo split) and the one-hot B entries — written + // DIRECTLY in VNNI layout via `PackedBf16B::vnni_index`, so the tile + // tiers never pack (and never allocate) per call. Rows of a partial + // sub-morsel beyond `rows.len()` keep stale A values — their B + // indicator is never set, so they contribute exact zeros. 
+ for (r, &(sid, temp)) in rows.iter().enumerate() { + let hi = (temp as i32 / 256) * 256; + let lo = temp as i32 - hi; + self.a[GEMM_K + r] = bf16_exact(hi as f32); + self.a[2 * GEMM_K + r] = bf16_exact(lo as f32); + self.a[3 * GEMM_K + r] = bf16_rne(temp as f32); + let (g, j) = (sid as usize / 16, sid as usize % 16); + self.b_blocks[g].data_mut()[PackedBf16B::vnni_index(r, j)] = BF16_ONE; + } + + for g in 0..N_GROUPS { + self.c.fill(0.0); // gemm ACCUMULATES; each group starts clean + bf16_tile_gemm_16x16_packed(&self.a, &self.b_blocks[g], &mut self.c); + for j in 0..16 { + let s = g * 16 + j; + if s >= N_STATIONS { + break; + } + // Every C entry in rows 0-2 is an exact integer (see module + // doc), so the drain into i64/u64 is lossless. Row 3 carries + // the bf16-quantized sums (exact f32 sums of INEXACT inputs). + self.cnt[s] += self.c[j] as u64; + self.sum_t[s] += (self.c[16 + j] as f64 + self.c[32 + j] as f64) as i64; + self.sum_bf16[s] += self.c[48 + j] as f64; + } + } + + // Clear-by-undo: reset exactly the B entries this sub-morsel set. + for (r, &(sid, _)) in rows.iter().enumerate() { + let (g, j) = (sid as usize / 16, sid as usize % 16); + self.b_blocks[g].data_mut()[PackedBf16B::vnni_index(r, j)] = 0; + } + } +} + +// ── Probe ─────────────────────────────────────────────────────────────────── + +fn main() { + let rows: usize = std::env::var("ONEBRC_ROWS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(10_000_000); + + let side = 4 * T; + let n_cells = (side * side) as usize; + let k = T.trailing_zeros(); + + println!("== 1BRC cascade probe (Morton {side}×{side}, {N_STATIONS} stations, {rows} rows) ==\n"); + + // 1. Mint station addresses: distinct cells on the Morton grid, plus a + // true mean per station (integer tenths in [-400, 400]). 
+ let mut st = SEED; + let mut taken = vec![false; n_cells]; + let mut station_cell = Vec::with_capacity(N_STATIONS); + let mut station_mean_t = Vec::with_capacity(N_STATIONS); + while station_cell.len() < N_STATIONS { + let x = (splitmix64(&mut st) % side as u64) as u32; + let y = (splitmix64(&mut st) % side as u64) as u32; + let idx = cell_index(x, y, k); + if !taken[idx] { + taken[idx] = true; + station_cell.push(idx); + station_mean_t.push((splitmix64(&mut st) % 801) as i16 - 400); + } + } + + // 2. Morsel-batched folds: generate rows deterministically, feed the SAME + // morsel buffer to all three paths (scalar reference; Morton scatter; + // tile-GEMM group-by). Data-flow: read-only morsel slices in, owned + // accumulators, no shared `&mut` during the fold. + let mut reference = vec![RefAgg::IDENTITY; N_STATIONS]; + let mut agg = MortonAgg::new(n_cells); + let mut gemm = GemmGroupBy::new(); + let mut morsel: Vec<(u16, i16)> = Vec::with_capacity(MORSEL); + + let (mut t_ref, mut t_sub, mut t_gemm) = (0.0f64, 0.0f64, 0.0f64); + let mut produced = 0usize; + while produced < rows { + let batch = MORSEL.min(rows - produced); + morsel.clear(); + for _ in 0..batch { + let sid = (splitmix64(&mut st) % N_STATIONS as u64) as u16; + let noise = (splitmix64(&mut st) % 201) as i16 - 100; // ±10.0 °C + let temp = (station_mean_t[sid as usize] + noise).clamp(-999, 999); + morsel.push((sid, temp)); + } + produced += batch; + + let t0 = Instant::now(); + for &(sid, temp) in &morsel { + let r = &mut reference[sid as usize]; + r.min_t = r.min_t.min(temp); + r.max_t = r.max_t.max(temp); + r.sum_t += temp as i64; + r.cnt += 1; + } + t_ref += t0.elapsed().as_secs_f64(); + + let t1 = Instant::now(); + for &(sid, temp) in &morsel { + let idx = station_cell[sid as usize]; + let v = temp as f32; + agg.min_c[idx] = agg.min_c[idx].min(v); + agg.max_c[idx] = agg.max_c[idx].max(v); + agg.sum_c[idx] += temp as f64; + agg.cnt_c[idx] += 1; + } + t_sub += t1.elapsed().as_secs_f64(); + + let 
t2 = Instant::now(); + for sub in morsel.chunks(GEMM_K) { + gemm.fold_sub_morsel(sub); + } + t_gemm += t2.elapsed().as_secs_f64(); + } + + // 3. Aggregate pyramid (hierarchical min/mean/max for free). + let t3 = Instant::now(); + let pyr = Pyramid::build(&agg, T); + let t_pyr = t3.elapsed().as_secs_f64(); + + // 4a. Certify Morton-scatter path == reference, bit-for-bit, per station. + let mut scatter_mism = 0usize; + for (s, r) in reference.iter().enumerate() { + let idx = station_cell[s]; + let ok = agg.min_c[idx] == r.min_t as f32 + && agg.max_c[idx] == r.max_t as f32 + && agg.sum_c[idx] == r.sum_t as f64 + && agg.cnt_c[idx] == r.cnt; + if !ok { + scatter_mism += 1; + if scatter_mism <= 3 { + println!( + " SCATTER MISMATCH station {s}: sub(min={} max={} sum={} n={}) ref(min={} max={} sum={} n={})", + agg.min_c[idx], agg.max_c[idx], agg.sum_c[idx], agg.cnt_c[idx], r.min_t, r.max_t, r.sum_t, r.cnt + ); + } + } + } + + // 4b. Certify tile-GEMM path (Σ, n) == reference, exactly, per station. + let mut gemm_mism = 0usize; + for (s, r) in reference.iter().enumerate() { + if gemm.cnt[s] != r.cnt || gemm.sum_t[s] != r.sum_t { + gemm_mism += 1; + if gemm_mism <= 3 { + println!( + " GEMM MISMATCH station {s}: gemm(sum={} n={}) ref(sum={} n={})", + gemm.sum_t[s], gemm.cnt[s], r.sum_t, r.cnt + ); + } + } + } + + // 4c. The "is BF16 precise enough?" measurement: A-row 3 carried naive + // bf16-RNE temperatures through the same tile; compare the resulting + // per-station means against the exact (hi/lo-split) means. + let mut max_mean_err_t = 0.0f64; // tenths + for (s, r) in reference.iter().enumerate() { + if r.cnt == 0 { + continue; + } + let exact = r.sum_t as f64 / r.cnt as f64; + let naive = gemm.sum_bf16[s] / r.cnt as f64; + max_mean_err_t = max_mean_err_t.max((naive - exact).abs()); + } + + // Root invariants: count, global min/max, global Σ. 
+ let root = pyr.root(); + let g_min = reference.iter().map(|r| r.min_t).min().unwrap(); + let g_max = reference.iter().map(|r| r.max_t).max().unwrap(); + let g_sum: i64 = reference.iter().map(|r| r.sum_t).sum(); + let root_ok = + root.cnt == rows as u64 && root.min == g_min as f32 && root.max == g_max as f32 && root.sum == g_sum as f64; + + // 5. Band-prune query on the pyramid: stations with min ≤ q. + let q = (g_min + 50) as f32; // a band 5.0 °C above the coldest reading + let (visited, hits) = pyr.stations_with_min_le(q, &agg); + let mut brute: Vec<usize> = (0..N_STATIONS) + .filter(|&s| reference[s].min_t as f32 <= q) + .map(|s| station_cell[s]) + .collect(); + brute.sort_unstable(); + let query_ok = hits == brute; + let n_tiles = (T * T) as usize; + let prune = 100.0 * (1.0 - visited as f64 / n_tiles as f64); + + // 6. Report (PillarReport style: deterministic seed, measured vs expected). + let pass = scatter_mism == 0 && gemm_mism == 0 && root_ok && query_ok; + let tier = bf16_tile_gemm_tier(); + println!(" seed=0x{SEED:X} stations={N_STATIONS} rows={rows}"); + println!( + " morton scatter (min,max,Σ,n): {}/{} exact → {}", + N_STATIONS - scatter_mism, + N_STATIONS, + if scatter_mism == 0 { "EXACT" } else { "MISMATCH" } + ); + println!( + " bf16 tile-GEMM (Σ,n) [{tier}]: {}/{} exact → {}", + N_STATIONS - gemm_mism, + N_STATIONS, + if gemm_mism == 0 { "EXACT" } else { "MISMATCH" } + ); + println!( + " bf16-direct row (no hi/lo split): max |Δmean| = {:.4} tenths = {:.5} °C \ + (single reading off by ≤ 2 tenths: half-ulp of bf16 at |t| ∈ [512, 1024), certified in jc)", + max_mean_err_t, + max_mean_err_t / 10.0 + ); + println!( + " root invariants (n={}, min={:.1}°C, max={:.1}°C): {}", + root.cnt, + root.min / 10.0, + root.max / 10.0, + if root_ok { "OK" } else { "WRONG" } + ); + println!( + " band query min ≤ {:.1}°C: {} stations, visited {visited}/{n_tiles} tiles → prune {prune:.1}% {}", + q / 10.0, + hits.len(), + if query_ok { "OK" } else { "WRONG" } + ); + 
println!( + " scatter: reference {:.0} Mrows/s | morton {:.0} Mrows/s | tile-GEMM {:.1} Mrows/s | pyramid {:.2} ms", + rows as f64 / t_ref / 1e6, + rows as f64 / t_sub / 1e6, + rows as f64 / t_gemm / 1e6, + t_pyr * 1e3 + ); + // Effective MAC rate of the GEMM formulation (dense-indicator overhead + // is the honest price of group-by-as-matmul: N_GROUPS·16·16·K MACs per + // K-row sub-morsel ≈ 6.7 kMAC/row). + let macs = rows as f64 * (N_GROUPS * 256) as f64; + println!( + " tile-GEMM effective rate: {:.1} GMAC/s (dense one-hot indicator, {} groups)", + macs / t_gemm / 1e9, + N_GROUPS + ); + println!( + " hierarchical bonus: {} pyramid levels of regional (min,mean,max) in the same pass", + pyr.levels.len() + ); + println!("\n{}", if pass { "✓ PASS" } else { "✗ FAIL" }); + std::process::exit(i32::from(!pass)); +} diff --git a/src/hpc/bf16_tile_gemm.rs b/src/hpc/bf16_tile_gemm.rs index cc951852..3a2d2c8f 100644 --- a/src/hpc/bf16_tile_gemm.rs +++ b/src/hpc/bf16_tile_gemm.rs @@ -1,21 +1,29 @@ -//! BF16 tile GEMM polyfill — AMX (TDPBF16PS) with AVX-512 F32x16 fallback. +//! BF16 tile GEMM — three-tier runtime dispatch, no decode on the tile tiers. //! -//! Same API, runtime tier dispatch via `amx_available()`. The AMX path uses -//! the raw primitives in `hpc::amx_matmul`. The fallback decodes BF16→f32 -//! and uses `crate::simd::F32x16` + `mul_add` (VFMADD231PS on AVX-512, -//! emulated as 2× F32x8 FMA on AVX2). +//! Tier ladder (selected per host, report via [`bf16_tile_gemm_tier`]): +//! 1. **AMX `TDPBF16PS`** — bf16×bf16 tile pairs, f32 tile accumulator +//! (raw primitives in `hpc::amx_matmul`). +//! 2. **AVX-512 `VDPBF16PS`** — bf16×bf16 register pairs, f32 zmm +//! accumulator. Same VNNI operand layout as tier 1, no bf16→f32 decode. +//! 3. **Decode + `F32x16` FMA polyfill** (`crate::simd::bf16_tile_gemm_16x16`) +//! — bf16→f32 decode then FMA; the polyfill owns AVX-512/AVX2/NEON/scalar. //! -//! 
Pattern: one dispatch check per call; caller supplies preallocated -//! output and (for AMX) VNNI-packed B. +//! All tiers accumulate (`C += A·B`) and are **bit-exact for bf16-exact +//! integer operands** with accumulation below 2^24 (bf16 products are exact +//! in f32) — asserted by the parity tests below with `assert_eq!`, not +//! tolerance. For general float operands the tiers agree up to accumulation +//! order (BF16-precision class). //! -//! Tile shape: M=16, N=16, K = multiple of 32. +//! Tile shape: M=16, N=16, K = multiple of 32. Hot loops should pre-pack B +//! ([`PackedBf16B`]) or stage it directly in VNNI layout +//! ([`PackedBf16B::vnni_index`]) and call [`bf16_tile_gemm_16x16_packed`] — +//! the row-major entry [`bf16_tile_gemm_16x16`] VNNI-packs (and allocates) +//! per call on the tile tiers. //! -//! Usage: -//! ```ignore -//! use ndarray::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16; -//! let mut c = vec![0.0f32; 16*16]; -//! bf16_tile_gemm_16x16(&a_bf16, &b_bf16_row_major, &mut c, k); -//! ``` +//! ⚠ Gotcha 14 (`.claude/AMX_GOTCHAS.md`): on oversubscribed VMs, AMX tile +//! state silently corrupts under host CPU contention. Certify tier-1 +//! numerics on dedicated silicon; see `tile_parity_under_cpu_contention` +//! (`--ignored`). use crate::hpc::amx_matmul::{ amx_available, tile_dpbf16ps, tile_load, tile_loadconfig, tile_release, tile_store, tile_zero, vnni_pack_bf16, @@ -29,28 +37,246 @@ use crate::hpc::amx_matmul::{ /// Compute C[16, 16] += A[16, K] × B[K, 16] where A, B are BF16 row-major /// and C is f32 row-major. K must be a multiple of 32. /// -/// Tier dispatch (runtime): -/// AMX available → TDPBF16PS tile GEMM (16×16 × K/32 tile iterations) -/// AMX unavailable → AVX-512 F32x16 FMA fallback (decode BF16→f32, gemm) -/// -/// Both paths produce identical results up to BF16 precision (~1/128 per -/// multiply, O(sqrt(K)) accumulated). 
+/// Tier dispatch (runtime): AMX `TDPBF16PS` → AVX-512 `VDPBF16PS` → decode + +/// `F32x16` FMA polyfill (see module doc). Bit-exact across tiers for +/// bf16-exact integer operands; identical up to BF16-precision accumulation +/// order otherwise. On the tile tiers this entry VNNI-packs B (and +/// allocates) per call — hot loops should use [`bf16_tile_gemm_16x16_packed`]. pub fn bf16_tile_gemm_16x16(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) { assert_eq!(k % 32, 0, "K must be multiple of 32"); assert_eq!(a_bf16.len(), 16 * k); assert_eq!(b_bf16.len(), k * 16); assert_eq!(c.len(), 16 * 16); - if amx_available() { - // AMX path: pack B into VNNI, call tile GEMM + if amx_available() || is_x86_feature_detected!("avx512bf16") { + // Tile tiers want VNNI-packed B; pack per call (see `PackedBf16B` + + // `bf16_tile_gemm_16x16_packed` to hoist this out of hot loops). let mut b_vnni = vec![0u16; k * 16]; vnni_pack_bf16(b_bf16, &mut b_vnni, k, 16); - // SAFETY: amx_available() just confirmed CPUID + XCR0 + prctl. + dispatch_vnni(a_bf16, &b_vnni, c, k); + } else { + // Pure-FMA tier consumes row-major B directly — no pack needed. + fallback_path(a_bf16, b_bf16, c, k); + } +} + +/// Which kernel `bf16_tile_gemm_16x16` / `_packed` will take on this host, +/// as a human-readable tier name. For run reports per AMX Gotcha 9 ("a +/// skipped test is not a passing test" — always PRINT which tier ran). 
+pub fn bf16_tile_gemm_tier() -> &'static str { + if amx_available() { + "AMX TDPBF16PS" + } else if is_x86_feature_detected!("avx512bf16") { + "AVX-512 VDPBF16PS" + } else { + "F32x16 FMA polyfill (bf16→f32 decode)" + } +} + +// ═════════════════════════════════════════════════════════════════════ +// Pre-packed B — hoists VNNI packing (and its allocation) out of hot loops +// ═════════════════════════════════════════════════════════════════════ + +/// B\[K, 16\] operand held in the VNNI pair layout shared by AMX `TDPBF16PS` +/// and AVX-512 `VDPBF16PS`: pair-row `i` holds `(B[2i, j], B[2i+1, j])` +/// interleaved over `j` — 32 bf16 = 64 bytes = one tile row / one zmm. +/// +/// Two ways to fill it: +/// - [`PackedBf16B::pack`] / [`PackedBf16B::pack_from`] — from row-major B +/// (`pack_from` reuses the buffer: no per-call allocation). +/// - Stage **directly** in VNNI layout via [`PackedBf16B::vnni_index`] + +/// [`PackedBf16B::data_mut`] — zero packing cost. This is the right shape +/// for sparse staging (e.g. one-hot group-by indicators): write the few +/// live entries, run the GEMM, clear the same entries. +/// +/// # Examples +/// ``` +/// use ndarray::simd::{bf16_tile_gemm_16x16_packed, PackedBf16B}; +/// let k = 32; +/// let mut b = PackedBf16B::zeroed(k); +/// // B[5][3] = 1.0 staged directly in VNNI layout (0x3F80 = bf16 1.0): +/// let idx = PackedBf16B::vnni_index(5, 3); +/// b.data_mut()[idx] = 0x3F80; +/// let a = vec![0x3F80u16; 16 * k]; // A = all 1.0 +/// let mut c = vec![0.0f32; 256]; +/// bf16_tile_gemm_16x16_packed(&a, &b, &mut c); +/// assert_eq!(c[3], 1.0); // C[0][3] = Σ_k A[0][k]·B[k][3] = A[0][5]·1.0 +/// ``` +pub struct PackedBf16B { + data: Vec<u16>, + k: usize, +} + +impl PackedBf16B { + /// All-zero packed B for direct VNNI staging. `k` must be a multiple of 32. 
+ pub fn zeroed(k: usize) -> Self { + assert_eq!(k % 32, 0, "K must be multiple of 32"); + PackedBf16B { + data: vec![0u16; k * 16], + k, + } + } + + /// Pack row-major B\[K, 16\] (allocates once; prefer [`Self::pack_from`] + /// for repeated packs). + pub fn pack(b_bf16_row_major: &[u16], k: usize) -> Self { + let mut p = Self::zeroed(k); + p.pack_from(b_bf16_row_major); + p + } + + /// Re-pack row-major B\[K, 16\] into the existing buffer — no allocation. + pub fn pack_from(&mut self, b_bf16_row_major: &[u16]) { + assert_eq!(b_bf16_row_major.len(), self.k * 16); + vnni_pack_bf16(b_bf16_row_major, &mut self.data, self.k, 16); + } + + /// Flat index of logical `B[row, col]` inside the VNNI buffer. + /// `vnni[(row/2)·32 + 2·col + (row & 1)] == B[row, col]`. + #[inline(always)] + pub const fn vnni_index(row: usize, col: usize) -> usize { + (row / 2) * 32 + 2 * col + (row & 1) + } + + /// The K dimension this buffer was sized for. + pub fn k(&self) -> usize { + self.k + } + + /// Read access to the raw VNNI buffer. + pub fn data(&self) -> &[u16] { + &self.data + } + + /// Mutable access to the raw VNNI buffer, for direct staging via + /// [`Self::vnni_index`]. + pub fn data_mut(&mut self) -> &mut [u16] { + &mut self.data + } + + // ── Little-endian byte contract ───────────────────────────────────── + // + // The persistence/mailbox face of the buffer (the lance-graph + // SoaEnvelope discipline: envelope bytes are LE from creation to + // tombstone). This module is `cfg(target_arch = "x86_64")`, an LE-only + // ISA, so the native `u16` lanes ARE the LE bytes — the view below is + // a zero-cost reinterpret, and the contract costs nothing to state + // explicitly. The read side decodes with `u16::from_le_bytes`, which + // is endian-correct anywhere and compiles to a plain copy here. + + /// The buffer as **little-endian** bytes — `2·k·16` bytes, each bf16 + /// lane low-byte-first. Zero-copy on this (LE-only) architecture. 
This + /// is the face a batch writer hands to columnar storage; pair with + /// [`Self::from_le_bytes`] to round-trip. + pub fn as_le_bytes(&self) -> &[u8] { + // SAFETY: `&[u16]` → `&[u8]` reinterpret is always valid (alignment + // requirement shrinks from 2 to 1; length doubles within the same + // allocation). Byte order is LE by construction: x86_64 is + // little-endian and this module is x86_64-only. + unsafe { core::slice::from_raw_parts(self.data.as_ptr() as *const u8, self.data.len() * 2) } + } + + /// Rebuild from **little-endian** bytes (the inverse of + /// [`Self::as_le_bytes`]). `bytes.len()` must be `2·k·16` and `k` a + /// multiple of 32. Endian-correct on any architecture; compiles to a + /// straight copy on LE targets. + pub fn from_le_bytes(bytes: &[u8], k: usize) -> Self { + assert_eq!(k % 32, 0, "K must be multiple of 32"); + assert_eq!(bytes.len(), 2 * k * 16, "expected 2·K·16 LE bytes"); + let data = bytes + .chunks_exact(2) + .map(|p| u16::from_le_bytes([p[0], p[1]])) + .collect(); + PackedBf16B { data, k } + } +} + +/// `C[16, 16] += A[16, K] × B[K, 16]` with B pre-packed — the hot-loop +/// sibling of [`bf16_tile_gemm_16x16`] with the per-call VNNI pack (and its +/// allocation) hoisted into [`PackedBf16B`]. +/// +/// Tier dispatch (runtime): AMX `TDPBF16PS` → AVX-512 `VDPBF16PS` → decode + +/// `F32x16` FMA polyfill. All tiers accumulate into `c` and are exact for +/// bf16-exact-integer operands (products of bf16 values are exact in f32; +/// integer accumulation is exact below 2^24). Use +/// [`bf16_tile_gemm_tier`] to report which tier ran. +pub fn bf16_tile_gemm_16x16_packed(a_bf16: &[u16], b: &PackedBf16B, c: &mut [f32]) { + let k = b.k; + assert_eq!(a_bf16.len(), 16 * k); + assert_eq!(c.len(), 16 * 16); + dispatch_vnni(a_bf16, &b.data, c, k); +} + +/// Shared tier dispatch over a VNNI-packed B. 
+fn dispatch_vnni(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) { + if amx_available() { + // SAFETY: amx_available() confirmed CPUID + XCR0 + arch_prctl grant. unsafe { - amx_path(a_bf16, &b_vnni, c, k); + amx_path(a_bf16, b_vnni, c, k); + } + } else if is_x86_feature_detected!("avx512bf16") { + // SAFETY: runtime detection confirmed avx512f+avx512bf16. + unsafe { + avx512bf16_path(a_bf16, b_vnni, c, k); } } else { - fallback_path(a_bf16, b_bf16, c, k); + // Slow tier from a packed buffer: unpack back to row-major and take + // the decode+FMA polyfill. Allocates — acceptable off the tile tiers + // (a host without AMX and without avx512bf16 calling the packed entry + // is not a hot-path configuration). + let mut b_rm = vec![0u16; k * 16]; + for row in 0..k { + for col in 0..16 { + b_rm[row * 16 + col] = b_vnni[PackedBf16B::vnni_index(row, col)]; + } + } + fallback_path(a_bf16, &b_rm, c, k); + } +} + +// ═════════════════════════════════════════════════════════════════════ +// AVX-512 BF16 path (VDPBF16PS) — the no-decode middle tier +// ═════════════════════════════════════════════════════════════════════ + +/// AVX-512 BF16 GEMM over VNNI-packed B — no bf16→f32 decode roundtrip: +/// `VDPBF16PS` multiplies bf16 pairs natively, accumulating into f32 lanes +/// (same accumulator semantics as the AMX tile, register-sized instead of +/// tile-sized). Per output row: one zmm accumulator preloaded from `c` +/// (accumulate semantics), K/2 pair steps, one store. +/// +/// Numerics: bf16×bf16 products are exact in f32; accumulation order is +/// row-major over pair index (matches the AMX per-row order). VDPBF16PS +/// flushes input denormals to zero — irrelevant for the integer workloads +/// this crate certifies, documented for float callers. +/// +/// # Safety +/// Caller must have verified `is_x86_feature_detected!("avx512bf16")`. 
+#[target_feature(enable = "avx512f,avx512bf16")] +unsafe fn avx512bf16_path(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) { + use core::arch::x86_64::{ + __m512bh, __m512i, _mm512_dpbf16_ps, _mm512_loadu_ps, _mm512_loadu_si512, _mm512_set1_epi32, _mm512_storeu_ps, + }; + debug_assert_eq!(k % 2, 0); + let pairs = k / 2; + for i in 0..16 { + // SAFETY (whole block): a_bf16 is 16×k and c is 16×16 (asserted by + // the public callers), b_vnni is k/2 pair-rows × 16 pairs (k·16 + // u16 lanes); every pointer below stays inside those bounds. + // read_unaligned handles the u16→u32 pair load without alignment + // requirements. + unsafe { + let mut acc = _mm512_loadu_ps(c.as_ptr().add(i * 16)); + let a_row = a_bf16.as_ptr().add(i * k); + for p in 0..pairs { + // Broadcast the A pair (a[2p], a[2p+1]) to all 16 lanes; lane j + // then accumulates A[i][2p]·B[2p][j] + A[i][2p+1]·B[2p+1][j]. + let pair = (a_row.add(2 * p) as *const u32).read_unaligned(); + let av: __m512bh = core::mem::transmute(_mm512_set1_epi32(pair as i32)); + let bv: __m512bh = + core::mem::transmute(_mm512_loadu_si512(b_vnni.as_ptr().add(p * 32) as *const __m512i)); + acc = _mm512_dpbf16_ps(acc, av, bv); + } + _mm512_storeu_ps(c.as_mut_ptr().add(i * 16), acc); + } } } @@ -188,4 +414,199 @@ mod tests { assert_eq!(*v, 0.0); } } + + /// bf16 encode by truncation — exact for the small integers used below. + fn bf16_int(v: i32) -> u16 { + let f = v as f32; + debug_assert_eq!(f.to_bits() & 0xFFFF, 0); + (f.to_bits() >> 16) as u16 + } + + /// Deterministic small-integer operands (all bf16-exact) + i64 reference. + /// Every tier must reproduce the reference EXACTLY on this input class. 
+ fn integer_case(k: usize) -> (Vec<u16>, Vec<u16>, Vec<f32>) { + let mut a = vec![0u16; 16 * k]; + let mut b = vec![0u16; k * 16]; + let mut s = 0x1BC0FFEEu64; + let mut next = move || { + s = s + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + ((s >> 33) % 15) as i32 - 7 // integers in [-7, 7] + }; + for v in a.iter_mut() { + *v = bf16_int(next()); + } + for v in b.iter_mut() { + *v = bf16_int(next()); + } + // i64 reference on the decoded integers + let dec = |u: u16| f32::from_bits((u as u32) << 16) as i64; + let mut c_ref = vec![0.0f32; 256]; + for i in 0..16 { + for j in 0..16 { + let mut acc = 0i64; + for kk in 0..k { + acc += dec(a[i * k + kk]) * dec(b[kk * 16 + j]); + } + c_ref[i * 16 + j] = acc as f32; // |acc| ≤ k·49 ≪ 2^24 → exact + } + } + (a, b, c_ref) + } + + #[test] + fn vnni_index_matches_vnni_pack_bf16() { + let k = 64; + let (_, b, _) = integer_case(k); + let packed = PackedBf16B::pack(&b, k); + for row in 0..k { + for col in 0..16 { + assert_eq!( + packed.data()[PackedBf16B::vnni_index(row, col)], + b[row * 16 + col], + "vnni_index disagrees with vnni_pack_bf16 at [{row},{col}]" + ); + } + } + } + + #[test] + fn packed_and_unpacked_exact_on_integers() { + // Whatever tier this host dispatches to, integer operands must come + // out EXACT and identical between the packed and unpacked entries. 
+ let k = 64; + let (a, b, c_ref) = integer_case(k); + let mut c_unpacked = vec![0.0f32; 256]; + bf16_tile_gemm_16x16(&a, &b, &mut c_unpacked, k); + let packed = PackedBf16B::pack(&b, k); + let mut c_packed = vec![0.0f32; 256]; + bf16_tile_gemm_16x16_packed(&a, &packed, &mut c_packed); + assert_eq!(c_unpacked, c_ref, "unpacked tier [{}] not exact", bf16_tile_gemm_tier()); + assert_eq!(c_packed, c_ref, "packed tier [{}] not exact", bf16_tile_gemm_tier()); + } + + #[test] + fn le_byte_view_roundtrips_and_is_truly_le() { + let k = 64; + let (_, b, _) = integer_case(k); + let packed = PackedBf16B::pack(&b, k); + + // Contract: byte 2i is the LOW byte of lane i (little-endian), + // independent of how the platform happens to lay out u16. + let bytes = packed.as_le_bytes(); + assert_eq!(bytes.len(), 2 * k * 16); + for (i, &lane) in packed.data().iter().enumerate() { + assert_eq!([bytes[2 * i], bytes[2 * i + 1]], lane.to_le_bytes(), "lane {i} not LE in byte view"); + } + + // Round-trip: bytes → PackedBf16B → identical lanes AND identical + // GEMM result (the envelope face and the compute face agree). + let rebuilt = PackedBf16B::from_le_bytes(bytes, k); + assert_eq!(rebuilt.data(), packed.data()); + let (a, _, c_ref) = integer_case(k); + let mut c = vec![0.0f32; 256]; + bf16_tile_gemm_16x16_packed(&a, &rebuilt, &mut c); + assert_eq!(c, c_ref, "GEMM over LE-roundtripped buffer not exact"); + } + + #[test] + fn packed_entry_accumulates() { + // C += A·B semantics: pre-existing C values must survive. 
+ let k = 32; + let (a, b, c_ref) = integer_case(k); + let packed = PackedBf16B::pack(&b, k); + let mut c = vec![5.0f32; 256]; + bf16_tile_gemm_16x16_packed(&a, &packed, &mut c); + for i in 0..256 { + assert_eq!(c[i], c_ref[i] + 5.0, "accumulate semantics broken at {i}"); + } + } + + #[test] + fn avx512bf16_path_matches_fallback() { + if !is_x86_feature_detected!("avx512bf16") { + eprintln!("SKIP (honestly): no avx512bf16 on this host — Gotcha 9: a skipped test is not a passing test"); + return; + } + let k = 64; + // Exact-integer parity: bit-for-bit. + let (a, b, c_ref) = integer_case(k); + let packed = PackedBf16B::pack(&b, k); + let mut c = vec![0.0f32; 256]; + // SAFETY: detection checked above. + unsafe { avx512bf16_path(&a, packed.data(), &mut c, k) }; + assert_eq!(c, c_ref, "VDPBF16PS not exact on integer operands"); + + // Float parity vs the decode+FMA fallback: same products (exact), + // different accumulation order → tolerance-compare. + let mut a_f = vec![0.0f32; 16 * k]; + let mut b_f = vec![0.0f32; k * 16]; + for (i, v) in a_f.iter_mut().enumerate() { + *v = ((i as f32) * 0.37).sin(); + } + for (i, v) in b_f.iter_mut().enumerate() { + *v = ((i as f32) * 0.73).cos(); + } + let mut a_bf = vec![0u16; a_f.len()]; + let mut b_bf = vec![0u16; b_f.len()]; + f32_to_bf16_batch(&a_f, &mut a_bf); + f32_to_bf16_batch(&b_f, &mut b_bf); + let packed_f = PackedBf16B::pack(&b_bf, k); + let mut c_v = vec![0.0f32; 256]; + // SAFETY: detection checked above. + unsafe { avx512bf16_path(&a_bf, packed_f.data(), &mut c_v, k) }; + let mut c_fb = vec![0.0f32; 256]; + fallback_path(&a_bf, &b_bf, &mut c_fb, k); + for i in 0..256 { + assert!((c_v[i] - c_fb[i]).abs() < 1e-3, "VDPBF16PS vs fallback at {i}: {} vs {}", c_v[i], c_fb[i]); + } + } + + /// Gotcha 14 parity-under-load gate. IGNORED by default: on an + /// oversubscribed VM this FAILS BY DESIGN (AMX tile state silently + /// corrupts under host CPU contention — see `.claude/AMX_GOTCHAS.md` + /// § Gotcha 14). 
Run explicitly on dedicated silicon to certify: + /// cargo test -p ndarray --release bf16_tile_gemm -- --ignored + #[test] + #[ignore = "Gotcha 14: fails on oversubscribed VMs by design — run on dedicated silicon"] + fn tile_parity_under_cpu_contention() { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + if !amx_available() && !is_x86_feature_detected!("avx512bf16") { + eprintln!("SKIP (honestly): no tile/vector-bf16 tier on this host"); + return; + } + let stop = Arc::new(AtomicBool::new(false)); + let busy: Vec<_> = (0..3) + .map(|_| { + let s = Arc::clone(&stop); + std::thread::spawn(move || { + let mut x = 0u64; + while !s.load(Ordering::Relaxed) { + x = x.wrapping_mul(6364136223846793005).wrapping_add(1); + std::hint::black_box(x); + } + }) + }) + .collect(); + // Large K on purpose: long tile residency (K=4096 holds tmm0 for + // 128 TDPBF16PS iterations) maximizes the Gotcha-14 exposure window. + let k = 4096; + let (a, b, c_ref) = integer_case(k); + let packed = PackedBf16B::pack(&b, k); + let mut bad = 0usize; + for _ in 0..200 { + let mut c = vec![0.0f32; 256]; + bf16_tile_gemm_16x16_packed(&a, &packed, &mut c); + if c != c_ref { + bad += 1; + } + } + stop.store(true, Ordering::Relaxed); + for h in busy { + let _ = h.join(); + } + assert_eq!(bad, 0, "{bad}/200 GEMMs corrupted under load on tier [{}]", bf16_tile_gemm_tier()); + } } diff --git a/src/simd.rs b/src/simd.rs index 30a37dff..4bec0d94 100644 --- a/src/simd.rs +++ b/src/simd.rs @@ -614,6 +614,20 @@ pub use crate::hpc::heel_f64x8::cosine_f32_to_f64_simd; // `backend::gemm_bf16` (portable scalar / NEON / wasm-SIMD paths). #[cfg(all(feature = "std", target_arch = "x86_64"))] pub use crate::hpc::amx_matmul::{amx_available, matmul_i8_to_i32}; +// Tile-dispatching sibling of the polyfill `bf16_tile_gemm_16x16` below: +// AMX TDPBF16PS → AVX-512 VDPBF16PS → the same FMA polyfill kernel, selected +// at runtime. 
Same W1a rationale as `matmul_i8_to_i32` — consumers reach the +// tile ladder through `ndarray::simd::*`; the `_amx` suffix keeps the +// pure-polyfill kernel and the tile-dispatching wrapper distinguishable at +// the call site. `bf16_tile_gemm_16x16_packed` + `PackedBf16B` hoist the +// VNNI pack (and its allocation) out of hot loops — `PackedBf16B::vnni_index` +// additionally supports staging B directly in VNNI layout (zero pack cost). +// `bf16_tile_gemm_tier()` names the tier that will run, for Gotcha-9-style +// run reports. +#[cfg(all(feature = "std", target_arch = "x86_64"))] +pub use crate::hpc::bf16_tile_gemm::{ + bf16_tile_gemm_16x16 as bf16_tile_gemm_16x16_amx, bf16_tile_gemm_16x16_packed, bf16_tile_gemm_tier, PackedBf16B, +}; // CPU-generation detection (cached): SPR / EMR / GNR / Sierra Forest. Lets a // consumer report which silicon a run landed on and distinguish "no AMX // silicon" from "AMX present but not OS-enabled" — both surface via `amx_report`.
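Reviewer note on the VNNI pair layout that `PackedBf16B::vnni_index` encodes. The block below is a minimal scalar sketch for checking the index arithmetic on any host, with no AMX or AVX-512 required. It copies the index formula stated in the diff; `pack_vnni_16` and `gemm_ref_vnni` are hypothetical helpers written for this illustration and are assumptions, not the crate's `vnni_pack_bf16` or any of the three dispatch tiers.

```rust
/// Flat index of logical B[row, col] in the VNNI pair layout, using the
/// formula stated in the diff: (row/2)*32 + 2*col + (row & 1).
const fn vnni_index(row: usize, col: usize) -> usize {
    (row / 2) * 32 + 2 * col + (row & 1)
}

/// HYPOTHETICAL scalar packer (N fixed at 16): an assumed stand-in for the
/// crate's `vnni_pack_bf16`, whose implementation is not shown in the diff.
fn pack_vnni_16(b_row_major: &[u16], k: usize) -> Vec<u16> {
    assert_eq!(k % 2, 0, "K must be even to form bf16 pairs");
    assert_eq!(b_row_major.len(), k * 16);
    let mut out = vec![0u16; k * 16];
    for row in 0..k {
        for col in 0..16 {
            out[vnni_index(row, col)] = b_row_major[row * 16 + col];
        }
    }
    out
}

/// bf16 -> f32 decode: a bf16 value is the high 16 bits of an f32.
fn bf16_to_f32(u: u16) -> f32 {
    f32::from_bits((u as u32) << 16)
}

/// HYPOTHETICAL scalar reference: C[16,16] += A[16,K] * B[K,16], reading B
/// through `vnni_index`. It is not one of the dispatch tiers.
fn gemm_ref_vnni(a: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) {
    assert_eq!(a.len(), 16 * k);
    assert_eq!(b_vnni.len(), k * 16);
    assert_eq!(c.len(), 256);
    for i in 0..16 {
        for j in 0..16 {
            let mut acc = c[i * 16 + j]; // accumulate semantics: start from C
            for kk in 0..k {
                acc += bf16_to_f32(a[i * k + kk]) * bf16_to_f32(b_vnni[vnni_index(kk, j)]);
            }
            c[i * 16 + j] = acc;
        }
    }
}

fn main() {
    let k = 32;
    // Stage a one-hot B[5][3] = 1.0 directly in VNNI layout (0x3F80 = bf16 1.0).
    let mut b_vnni = vec![0u16; k * 16];
    b_vnni[vnni_index(5, 3)] = 0x3F80;
    let a = vec![0x3F80u16; 16 * k]; // A = all 1.0
    let mut c = vec![0.0f32; 256];
    gemm_ref_vnni(&a, &b_vnni, &mut c, k);
    assert_eq!(c[3], 1.0); // C[0][3] = A[0][5] * B[5][3]

    // Packing a row-major B with the same single entry yields the same buffer.
    let mut b_rm = vec![0u16; k * 16];
    b_rm[5 * 16 + 3] = 0x3F80;
    assert_eq!(pack_vnni_16(&b_rm, k), b_vnni);
    println!("ok");
}
```

The sketch mirrors the doc-test on `PackedBf16B` (same one-hot entry, same expected `c[3]`), so it can serve as a host-independent cross-check of the layout when neither tile tier is available.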