@@ -21,7 +21,6 @@ use crate::hpc::amx_matmul::{
2121 amx_available, tile_dpbf16ps, tile_load, tile_loadconfig, tile_release, tile_store, tile_zero, vnni_pack_bf16,
2222 TileConfig ,
2323} ;
24- use crate :: simd:: { bf16_to_f32_batch, F32x16 } ;
2524
2625// ═════════════════════════════════════════════════════════════════════
2726// Public API — safe dispatching wrapper
@@ -104,39 +103,12 @@ unsafe fn amx_path(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) {
104103// AVX-512 fallback (F32x16 + mul_add FMA)
105104// ═════════════════════════════════════════════════════════════════════
106105
107- /// Fallback: decode BF16→f32 and run a tight F32x16 GEMM with mul_add FMA.
108- /// When AVX-512 is the compile-time baseline, this uses native __m512 FMA;
109- /// on AVX2 it uses the emulated F32x16 = (F32x8, F32x8) pair — same logic.
106+ /// Fallback: delegate to the single source-of-truth SIMD-polyfill kernel
107+ /// [`crate::simd::bf16_tile_gemm_16x16`] (BF16→f32 decode + `F32x16` FMA). The
108+ /// `F32x16` wrapper owns the AVX-512 / AVX2 / NEON / scalar dispatch, so this
109+ /// AMX wrapper only adds the TDPBF16PS tile path on top of the same kernel.
110110fn fallback_path ( a_bf16 : & [ u16 ] , b_bf16 : & [ u16 ] , c : & mut [ f32 ] , k : usize ) {
111- // Decode BF16 → f32 (batch via SIMD when avx512bf16 / avx2 available)
112- let mut a_f32 = vec ! [ 0.0f32 ; a_bf16. len( ) ] ;
113- let mut b_f32 = vec ! [ 0.0f32 ; b_bf16. len( ) ] ;
114- bf16_to_f32_batch ( a_bf16, & mut a_f32) ;
115- bf16_to_f32_batch ( b_bf16, & mut b_f32) ;
116-
117- // Tight GEMM: for each output (i,j), dot row-of-A with col-of-B via F32x16+FMA.
118- // B is row-major [K, 16]; j-th column is b_f32[kk*16 + j] over kk=0..K.
119- // We gather the column into a stack-sized buffer once per (i,j) pair to hit
120- // the chunks_exact(16) + mul_add fast path on contiguous memory.
121- for i in 0 ..16 {
122- let a_row = & a_f32[ i * k..i * k + k] ;
123- for j in 0 ..16 {
124- // Stream the column into a contiguous buffer
125- let mut col = vec ! [ 0.0f32 ; k] ;
126- for kk in 0 ..k {
127- col[ kk] = b_f32[ kk * 16 + j] ;
128- }
129-
130- // Accumulate via F32x16::mul_add (FMA)
131- let mut acc = F32x16 :: splat ( 0.0 ) ;
132- for ( ra, rb) in a_row. chunks_exact ( 16 ) . zip ( col. chunks_exact ( 16 ) ) {
133- let av = F32x16 :: from_slice ( ra) ;
134- let bv = F32x16 :: from_slice ( rb) ;
135- acc = av. mul_add ( bv, acc) ;
136- }
137- c[ i * 16 + j] += acc. reduce_sum ( ) ;
138- }
139- }
111+ crate :: simd:: bf16_tile_gemm_16x16 ( a_bf16, b_bf16, c, k) ;
140112}
141113
142114// ═════════════════════════════════════════════════════════════════════
0 commit comments