diff --git a/demo/README.md b/demo/README.md new file mode 100644 index 0000000..7da17ba --- /dev/null +++ b/demo/README.md @@ -0,0 +1,56 @@ +# GPT-2 Demo + +Text generation with GPT-2-124M using a WarpForth-compiled attention kernel. The stock Hugging Face model is loaded normally, then `eager_attention_forward` is monkey-patched to route scaled dot-product attention through a WarpForth kernel compiled to PTX. PyCUDA shares PyTorch's CUDA context via `autoprimaryctx`, so device pointers pass directly between the two — no copies, no CPU roundtrips. + +## Prerequisites + +- WarpForth built locally (`cmake --build build`) +- A Vast.ai GPU instance with a PyTorch image (e.g. `pytorch/pytorch:2.6.0-cuda12.6-cudnn9-runtime`) + +## Step 1: Compile the Kernel (Local) + +```bash +./build/bin/warpforthc demo/attention.forth > demo/attention.ptx +``` + +A pre-compiled `attention.ptx` is included in this directory. + +## Step 2: Upload to GPU Instance + +```bash +scp demo/attention.ptx demo/warpforth.py demo/gpt2_generate.py root@HOST:/workspace +``` + +## Step 3: Install Dependencies (Remote) + +```bash +pip install pycuda transformers +``` + +## Step 4: Generate Text (Remote) + +```bash +python /workspace/gpt2_generate.py --ptx /workspace/attention.ptx --prompt "The meaning of life is" +``` + +| Flag | Default | Description | +|------|---------|-------------| +| `--ptx` | (required) | Path to compiled `attention.ptx` | +| `--prompt` | `"The meaning of life is"` | Input text prompt | +| `--max-tokens` | `100` | Maximum new tokens to generate | + +## Limitations + +- **Batch size 1** — the kernel processes one sequence at a time +- **No KV cache** — all positions are recomputed each step (`use_cache=False`) +- **Max sequence length 1024** — limited by shared memory allocation +- **12 kernel launches per layer** — one per attention head + +## Files + +| File | Description | +|------|-------------| +| `attention.forth` | Attention kernel source (f32 global, f64 shared) | +| `attention.ptx` | Pre-compiled 
PTX | +| `warpforth.py` | PyCUDA wrapper for loading and launching the kernel | +| `gpt2_generate.py` | Loads GPT-2, patches attention, generates text | diff --git a/demo/attention.forth b/demo/attention.forth new file mode 100644 index 0000000..0411db5 --- /dev/null +++ b/demo/attention.forth @@ -0,0 +1,81 @@ +\ GPT-2 attention kernel with f32 global memory, f64 shared memory for softmax. +\ Adapted from test/Pipeline/attention.forth — 4 lines changed for f32 access. +\ +\ Q/K/V/O are f32 arrays passed as raw byte buffers (i64[]). +\ Global loads/stores use F32@/F32! with 4* byte addressing (f32 = 4 bytes). +\ Shared memory stays f64 for softmax precision, using SF@/SF! with CELLS. + +\! kernel attention +\! param Q i64[32768] +\! param K i64[32768] +\! param V i64[32768] +\! param O i64[32768] +\! param SEQ_LEN i64 +\! param HEAD_DIM i64 +\! shared SCORES f64[1024] +\! shared SCRATCH f64[1024] + +\ row = BID-X, t = TID-X +BID-X +TID-X + +\ --- Dot product: Q[row,:] . K[t,:] --- +0.0 +HEAD_DIM 0 DO + 2 PICK HEAD_DIM * I + 4 * Q + F32@ + 2 PICK HEAD_DIM * I + 4 * K + F32@ + F* F+ +LOOP +HEAD_DIM S>F FSQRT F/ + +\ --- Causal mask: if t > row, score = -inf --- +OVER 3 PICK > +IF DROP -1.0e30 THEN + +\ --- Store score to shared memory --- +OVER CELLS SCORES + SF! +BARRIER + +\ --- Softmax: max reduction (thread 0) --- +TID-X 0= IF + 0 CELLS SCORES + SF@ + SEQ_LEN 1 DO I CELLS SCORES + SF@ FMAX LOOP + 0 CELLS SCRATCH + SF! +THEN +BARRIER + +\ --- Softmax: exp(score - max) --- +DUP CELLS SCORES + SF@ +0 CELLS SCRATCH + SF@ +F- FEXP +OVER CELLS SCORES + SF! +BARRIER + +\ --- Softmax: sum reduction (thread 0) --- +TID-X 0= IF + 0.0 + SEQ_LEN 0 DO I CELLS SCORES + SF@ F+ LOOP + 0 CELLS SCRATCH + SF! +THEN +BARRIER + +\ --- Softmax: normalize --- +DUP CELLS SCORES + SF@ +0 CELLS SCRATCH + SF@ +F/ +OVER CELLS SCORES + SF! +BARRIER + +\ --- V accumulation: O[row,col] = sum_j SCORES[j] * V[j*HD + col] --- +\ Stride over head_dim columns: col = t, t+BDIM-X, t+2*BDIM-X, ... 
+DUP BEGIN DUP HEAD_DIM < WHILE + 0.0 + SEQ_LEN 0 DO + I CELLS SCORES + SF@ + I HEAD_DIM * 3 PICK + 4 * V + F32@ + F* F+ + LOOP + OVER 4 PICK HEAD_DIM * + 4 * O + F32! + BDIM-X + +REPEAT +DROP DROP DROP diff --git a/demo/attention.ptx b/demo/attention.ptx new file mode 100644 index 0000000..8bb4cb4 --- /dev/null +++ b/demo/attention.ptx @@ -0,0 +1,224 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 6.0 +.target sm_70 +.address_size 64 + + // .globl attention +// __wg_attention_0 has been demoted +// __wg_attention_1_$_0 has been demoted + +.visible .entry attention( + .param .u64 .ptr .align 1 attention_param_0, + .param .u64 .ptr .align 1 attention_param_1, + .param .u64 .ptr .align 1 attention_param_2, + .param .u64 .ptr .align 1 attention_param_3, + .param .u64 attention_param_4, + .param .u64 attention_param_5 +) +{ + .reg .pred %p<17>; + .reg .b32 %r<21>; + .reg .f32 %f<7>; + .reg .b64 %rd<76>; + .reg .f64 %fd<63>; + // demoted variable + .shared .align 8 .b8 __wg_attention_0[8192]; + // demoted variable + .shared .align 8 .f64 __wg_attention_1_$_0; + ld.param.u64 %rd41, [attention_param_5]; + ld.param.u64 %rd40, [attention_param_4]; + ld.param.u64 %rd39, [attention_param_3]; + ld.param.u64 %rd38, [attention_param_2]; + ld.param.u64 %rd42, [attention_param_0]; + mov.u32 %r1, %ctaid.x; + ld.param.u64 %rd43, [attention_param_1]; + cvt.u64.u32 %rd44, %r1; + mov.u32 %r5, %tid.x; + cvt.u64.u32 %rd72, %r5; + mul.lo.s64 %rd2, %rd41, %rd44; + mul.lo.s64 %rd45, %rd41, %rd72; + neg.s64 %rd66, %rd41; + shl.b64 %rd46, %rd45, 2; + add.s64 %rd65, %rd43, %rd46; + shl.b64 %rd47, %rd2, 2; + add.s64 %rd64, %rd42, %rd47; + mov.f64 %fd58, 0d0000000000000000; +$L__BB0_1: + ld.f32 %f2, [%rd64]; + cvt.f64.f32 %fd16, %f2; + ld.f32 %f3, [%rd65]; + cvt.f64.f32 %fd17, %f3; + mul.rn.f64 %fd18, %fd16, %fd17; + add.rn.f64 %fd58, %fd58, %fd18; + add.s64 %rd9, %rd66, 1; + xor.b64 %rd48, %rd9, %rd66; + setp.gt.s64 %p1, %rd48, -1; + add.s64 %rd65, %rd65, 4; + add.s64 %rd64, %rd64, 
4; + mov.u64 %rd66, %rd9; + @%p1 bra $L__BB0_1; + cvt.u32.u64 %r6, %rd72; + cvt.rn.f64.s64 %fd19, %rd41; + sqrt.rn.f64 %fd20, %fd19; + div.rn.f64 %fd21, %fd58, %fd20; + setp.gt.u32 %p2, %r6, %r1; + shl.b64 %rd49, %rd72, 3; + mov.u64 %rd50, __wg_attention_0; + add.s64 %rd12, %rd49, %rd50; + selp.f64 %fd22, 0dC6293E5939A08CEA, %fd21, %p2; + st.shared.f64 [%rd12], %fd22; + bar.sync 0; + setp.ne.s32 %p3, %r6, 0; + @%p3 bra $L__BB0_6; + ld.shared.f64 %fd60, [__wg_attention_0]; + mov.u64 %rd52, __wg_attention_0; + add.s64 %rd68, %rd52, 8; + sub.s64 %rd67, 1, %rd40; +$L__BB0_4: + ld.shared.f64 %fd23, [%rd68]; + mov.b64 %rd53, %fd23; + setp.nan.f64 %p4, %fd60, %fd23; + max.f64 %fd24, %fd60, %fd23; + selp.f64 %fd25, 0d7FF8000000000000, %fd24, %p4; + mov.b64 %rd54, %fd60; + setp.eq.s64 %p5, %rd54, 0; + selp.f64 %fd26, %fd60, %fd25, %p5; + setp.eq.s64 %p6, %rd53, 0; + selp.f64 %fd27, %fd23, %fd26, %p6; + setp.eq.f64 %p7, %fd25, 0d0000000000000000; + selp.f64 %fd60, %fd27, %fd25, %p7; + add.s64 %rd17, %rd67, 1; + xor.b64 %rd55, %rd17, %rd67; + setp.gt.s64 %p8, %rd55, -1; + add.s64 %rd68, %rd68, 8; + mov.u64 %rd67, %rd17; + @%p8 bra $L__BB0_4; + st.shared.f64 [__wg_attention_1_$_0], %fd60; +$L__BB0_6: + bar.sync 0; + ld.shared.f64 %fd28, [%rd12]; + ld.shared.f64 %fd29, [__wg_attention_1_$_0]; + sub.rn.f64 %fd4, %fd28, %fd29; + fma.rn.f64 %fd30, %fd4, 0d3FF71547652B82FE, 0d4338000000000000; + { + .reg .b32 %temp; + mov.b64 {%r2, %temp}, %fd30; + } + mov.f64 %fd31, 0dC338000000000000; + add.rn.f64 %fd32, %fd30, %fd31; + fma.rn.f64 %fd33, %fd32, 0dBFE62E42FEFA39EF, %fd4; + fma.rn.f64 %fd34, %fd32, 0dBC7ABC9E3B39803F, %fd33; + fma.rn.f64 %fd35, %fd34, 0d3E5ADE1569CE2BDF, 0d3E928AF3FCA213EA; + fma.rn.f64 %fd36, %fd35, %fd34, 0d3EC71DEE62401315; + fma.rn.f64 %fd37, %fd36, %fd34, 0d3EFA01997C89EB71; + fma.rn.f64 %fd38, %fd37, %fd34, 0d3F2A01A014761F65; + fma.rn.f64 %fd39, %fd38, %fd34, 0d3F56C16C1852B7AF; + fma.rn.f64 %fd40, %fd39, %fd34, 0d3F81111111122322; + fma.rn.f64 %fd41, %fd40, 
%fd34, 0d3FA55555555502A1; + fma.rn.f64 %fd42, %fd41, %fd34, 0d3FC5555555555511; + fma.rn.f64 %fd43, %fd42, %fd34, 0d3FE000000000000B; + fma.rn.f64 %fd44, %fd43, %fd34, 0d3FF0000000000000; + fma.rn.f64 %fd45, %fd44, %fd34, 0d3FF0000000000000; + { + .reg .b32 %temp; + mov.b64 {%r3, %temp}, %fd45; + } + { + .reg .b32 %temp; + mov.b64 {%temp, %r4}, %fd45; + } + shl.b32 %r7, %r2, 20; + add.s32 %r8, %r4, %r7; + mov.b64 %fd59, {%r3, %r8}; + { + .reg .b32 %temp; + mov.b64 {%temp, %r9}, %fd4; + } + mov.b32 %f4, %r9; + abs.f32 %f1, %f4; + setp.lt.f32 %p9, %f1, 0f4086232B; + @%p9 bra $L__BB0_9; + setp.lt.f64 %p10, %fd4, 0d0000000000000000; + add.rn.f64 %fd46, %fd4, 0d7FF0000000000000; + selp.f64 %fd59, 0d0000000000000000, %fd46, %p10; + setp.geu.f32 %p11, %f1, 0f40874800; + @%p11 bra $L__BB0_9; + shr.u32 %r10, %r2, 31; + add.s32 %r11, %r2, %r10; + shr.s32 %r12, %r11, 1; + shl.b32 %r13, %r12, 20; + add.s32 %r14, %r4, %r13; + mov.b64 %fd47, {%r3, %r14}; + sub.s32 %r15, %r2, %r12; + shl.b32 %r16, %r15, 20; + add.s32 %r17, %r16, 1072693248; + mov.b32 %r18, 0; + mov.b64 %fd48, {%r18, %r17}; + mul.rn.f64 %fd59, %fd48, %fd47; +$L__BB0_9: + st.shared.f64 [%rd12], %fd59; + bar.sync 0; + @%p3 bra $L__BB0_13; + neg.s64 %rd69, %rd40; + mov.f64 %fd61, 0d0000000000000000; + mov.u64 %rd70, __wg_attention_0; +$L__BB0_11: + ld.shared.f64 %fd50, [%rd70]; + add.rn.f64 %fd61, %fd61, %fd50; + add.s64 %rd26, %rd69, 1; + xor.b64 %rd57, %rd26, %rd69; + setp.gt.s64 %p13, %rd57, -1; + add.s64 %rd70, %rd70, 8; + mov.u64 %rd69, %rd26; + @%p13 bra $L__BB0_11; + st.shared.f64 [__wg_attention_1_$_0], %fd61; +$L__BB0_13: + bar.sync 0; + ld.shared.f64 %fd51, [%rd12]; + ld.shared.f64 %fd52, [__wg_attention_1_$_0]; + div.rn.f64 %fd53, %fd51, %fd52; + st.shared.f64 [%rd12], %fd53; + bar.sync 0; + setp.le.s64 %p14, %rd41, %rd72; + @%p14 bra $L__BB0_18; + mov.u32 %r20, %ntid.x; + cvt.u64.u32 %rd19, %r20; + shl.b64 %rd58, %rd72, 2; + add.s64 %rd71, %rd38, %rd58; + mul.wide.u32 %rd21, %r20, 4; + shl.b64 %rd22, 
%rd41, 2; + neg.s64 %rd23, %rd40; +$L__BB0_15: + mov.f64 %fd62, 0d0000000000000000; + mov.u64 %rd73, %rd23; + mov.u64 %rd74, %rd71; + mov.u64 %rd75, %rd50; +$L__BB0_16: + ld.shared.f64 %fd55, [%rd75]; + ld.f32 %f5, [%rd74]; + cvt.f64.f32 %fd56, %f5; + mul.rn.f64 %fd57, %fd55, %fd56; + add.rn.f64 %fd62, %fd62, %fd57; + add.s64 %rd33, %rd73, 1; + xor.b64 %rd60, %rd33, %rd73; + setp.gt.s64 %p15, %rd60, -1; + add.s64 %rd75, %rd75, 8; + add.s64 %rd74, %rd74, %rd22; + mov.u64 %rd73, %rd33; + @%p15 bra $L__BB0_16; + add.s64 %rd61, %rd72, %rd2; + shl.b64 %rd62, %rd61, 2; + add.s64 %rd63, %rd62, %rd39; + cvt.rn.f32.f64 %f6, %fd62; + st.f32 [%rd63], %f6; + add.s64 %rd72, %rd72, %rd19; + add.s64 %rd71, %rd71, %rd21; + setp.lt.s64 %p16, %rd72, %rd41; + @%p16 bra $L__BB0_15; +$L__BB0_18: + ret; + +} diff --git a/demo/gpt2_generate.py b/demo/gpt2_generate.py new file mode 100644 index 0000000..05b1838 --- /dev/null +++ b/demo/gpt2_generate.py @@ -0,0 +1,100 @@ +"""GPT-2 text generation with WarpForth-compiled attention kernel. + +Loads GPT-2-124M from Hugging Face, monkey-patches the attention mechanism +to use a WarpForth CUDA kernel, and generates text from a prompt. + +Usage: + # Local: compile kernel to PTX + ./build/bin/warpforthc demo/attention.forth > demo/attention.ptx + + # Remote (Vast.ai GPU instance): + pip install pycuda transformers torch + python gpt2_generate.py --ptx attention.ptx --prompt "The meaning of life is" +""" + +from __future__ import annotations + +import argparse + +import torch +import transformers.models.gpt2.modeling_gpt2 as gpt2_module +from transformers import GPT2LMHeadModel, GPT2Tokenizer +from warpforth import AttentionKernel + + +def make_warpforth_eager_attn(attn_kernel: AttentionKernel): + """Create a replacement for eager_attention_forward using the WarpForth kernel. 
+ + The transformers eager_attention_forward signature is: + (module, query, key, value, attention_mask, **kwargs) + query/key/value: (batch, n_heads, seq_len, head_dim) float32 CUDA. + Returns (attn_output, attn_weights) with attn_output transposed to + (batch, seq_len, n_heads, head_dim). + """ + + def warpforth_eager_attn(module, query, key, value, attention_mask=None, **kwargs): + _batch, n_heads, seq_len, head_dim = query.shape + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + attn_output = torch.zeros_like(query) + + for h in range(n_heads): + attn_kernel( + query[0, h], + key[0, h], + value[0, h], + attn_output[0, h], + seq_len, + head_dim, + ) + + return attn_output.transpose(1, 2), None + + return warpforth_eager_attn + + +def main(): + parser = argparse.ArgumentParser(description="GPT-2 generation with WarpForth attention") + parser.add_argument("--ptx", required=True, help="Path to compiled attention.ptx") + parser.add_argument("--prompt", default="The meaning of life is", help="Input prompt") + parser.add_argument("--max-tokens", type=int, default=100, help="Max new tokens to generate") + args = parser.parse_args() + + print("Loading GPT-2 model...") + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + tokenizer.pad_token = tokenizer.eos_token + model = GPT2LMHeadModel.from_pretrained("gpt2", attn_implementation="eager").cuda().float() + model.eval() + + print(f"Loading WarpForth attention kernel from {args.ptx}") + attn_kernel = AttentionKernel(args.ptx) + + inputs = tokenizer(args.prompt, return_tensors="pt").to("cuda") + print(f"Prompt: {args.prompt!r} ({inputs['input_ids'].shape[1]} tokens)") + + gpt2_module.eager_attention_forward = make_warpforth_eager_attn(attn_kernel) + + print(f"\nGenerating (max {args.max_tokens} new tokens)...") + with torch.no_grad(): + output_ids = model.generate( + inputs["input_ids"], + attention_mask=inputs["attention_mask"], + max_new_tokens=args.max_tokens, + use_cache=False, + 
do_sample=True, + temperature=0.6, + top_p=0.9, + repetition_penalty=1.1, + pad_token_id=tokenizer.eos_token_id, + ) + + output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + print(f"\n{'=' * 60}") + print(output_text) + print(f"{'=' * 60}") + print(f"({output_ids.shape[1]} tokens total)") + + +if __name__ == "__main__": + main() diff --git a/demo/warpforth.py b/demo/warpforth.py new file mode 100644 index 0000000..c72e479 --- /dev/null +++ b/demo/warpforth.py @@ -0,0 +1,49 @@ +"""PyCUDA wrapper for the WarpForth attention kernel. + +Uses pycuda.autoprimaryctx to share PyTorch's CUDA context — device pointers +from torch tensors can be passed directly to kernel launches (zero-copy). +""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pycuda.autoprimaryctx # noqa: F401 — activates PyTorch's primary context +import pycuda.driver as cuda + + +class AttentionKernel: + """Loads and launches the WarpForth attention kernel. + + Accepts contiguous float32 CUDA tensors and passes device pointers + directly (zero-copy). Launches one kernel invocation per call. + + The kernel computes: O = softmax(Q @ K^T / sqrt(head_dim)) @ V + with a causal mask, using f64 shared memory for softmax precision. 
+ """ + + def __init__(self, ptx_path: str | Path) -> None: + ptx_bytes = Path(ptx_path).read_bytes() + module = cuda.module_from_buffer(ptx_bytes) + self._function = module.get_function("attention") + + def __call__( + self, + q: object, # torch.Tensor (seq_len, head_dim) float32 CUDA + k: object, # torch.Tensor (seq_len, head_dim) float32 CUDA + v: object, # torch.Tensor (seq_len, head_dim) float32 CUDA + o: object, # torch.Tensor (seq_len, head_dim) float32 CUDA + seq_len: int, + head_dim: int, + ) -> None: + self._function( + np.intp(q.data_ptr()), + np.intp(k.data_ptr()), + np.intp(v.data_ptr()), + np.intp(o.data_ptr()), + np.int64(seq_len), + np.int64(head_dim), + block=(seq_len, 1, 1), + grid=(seq_len, 1, 1), + ) diff --git a/gpu_test/test_kernels.py b/gpu_test/test_kernels.py index e6e9376..96a08f3 100644 --- a/gpu_test/test_kernels.py +++ b/gpu_test/test_kernels.py @@ -2,6 +2,7 @@ from __future__ import annotations +import struct from typing import TYPE_CHECKING import numpy as np @@ -738,3 +739,137 @@ def test_naive_attention_f64_16x64(kernel_runner: KernelRunner) -> None: output_count=n, ) assert result == [pytest.approx(v) for v in expected] + + +# --- Attention (f32 global memory) --- + +_ATTENTION_F32_KERNEL = """\ +\\! kernel attention +\\! param Q i64[{n_i64}] +\\! param K i64[{n_i64}] +\\! param V i64[{n_i64}] +\\! param O i64[{n_i64}] +\\! param SEQ_LEN i64 +\\! param HEAD_DIM i64 +\\! shared SCORES f64[{seq_len}] +\\! shared SCRATCH f64[{seq_len}] +BID-X +TID-X +0.0 +HEAD_DIM 0 DO + 2 PICK HEAD_DIM * I + 4 * Q + F32@ + 2 PICK HEAD_DIM * I + 4 * K + F32@ + F* F+ +LOOP +HEAD_DIM S>F FSQRT F/ +OVER 3 PICK > +IF DROP -1.0e30 THEN +OVER CELLS SCORES + SF! +BARRIER +TID-X 0= IF + 0 CELLS SCORES + SF@ + SEQ_LEN 1 DO I CELLS SCORES + SF@ FMAX LOOP + 0 CELLS SCRATCH + SF! +THEN +BARRIER +DUP CELLS SCORES + SF@ +0 CELLS SCRATCH + SF@ +F- FEXP +OVER CELLS SCORES + SF! 
+BARRIER +TID-X 0= IF + 0.0 + SEQ_LEN 0 DO I CELLS SCORES + SF@ F+ LOOP + 0 CELLS SCRATCH + SF! +THEN +BARRIER +DUP CELLS SCORES + SF@ +0 CELLS SCRATCH + SF@ +F/ +OVER CELLS SCORES + SF! +BARRIER +DUP BEGIN DUP HEAD_DIM < WHILE + 0.0 + SEQ_LEN 0 DO + I CELLS SCORES + SF@ + I HEAD_DIM * 3 PICK + 4 * V + F32@ + F* F+ + LOOP + OVER 4 PICK HEAD_DIM * + 4 * O + F32! + BDIM-X + +REPEAT +DROP DROP DROP +""" + + +def _pack_f32_to_i64(data: np.ndarray) -> list[int]: + """Pack a flat f64 array as f32 values into i64 list (reinterpret bytes).""" + f32_bytes = data.astype(np.float32).tobytes() + # Pad to 8-byte boundary + pad_len = (8 - len(f32_bytes) % 8) % 8 + f32_bytes += b"\x00" * pad_len + return [struct.unpack(" list[float]: + """Unpack i64 list back to f32 floats.""" + raw = b"".join(struct.pack(" None: + """Naive attention with f32 global memory, f64 shared memory for softmax. + + Same test data as test_naive_attention_f64 but using F32@/F32! for global + memory access. This isolates whether reduced-width float loads/stores work + correctly in the attention kernel. 
+ """ + seq_len, head_dim = 4, 4 + + q = np.array( + [ + [1.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0], + ] + ) + k = np.array( + [ + [1.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0], + ] + ) + v = np.array( + [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0], + ] + ) + + expected = _attention_reference(q, k, v, seq_len) + n = seq_len * head_dim + n_i64 = (n + 1) // 2 # 2 f32 values per i64 + + result = kernel_runner.run( + forth_source=_ATTENTION_F32_KERNEL.format(n_i64=n_i64, seq_len=seq_len), + params={ + "Q": _pack_f32_to_i64(q.flatten()), + "K": _pack_f32_to_i64(k.flatten()), + "V": _pack_f32_to_i64(v.flatten()), + "SEQ_LEN": seq_len, + "HEAD_DIM": head_dim, + }, + grid=(seq_len, 1, 1), + block=(seq_len, 1, 1), + output_param=3, + output_count=n_i64, + ) + + output = _unpack_i64_to_f32(result, n) + assert output == [pytest.approx(v, rel=1e-5) for v in expected] diff --git a/pyproject.toml b/pyproject.toml index ef4c35d..c3b2f2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,3 +28,4 @@ ignore = ["D", "COM812", "ISC001"] [tool.ruff.lint.per-file-ignores] "gpu_test/test_*.py" = ["S101", "PLR2004"] "gpu_test/conftest.py" = ["S603", "S607", "PLR0913", "PLR2004"] +"demo/*.py" = ["T201", "INP001", "SLF001", "ANN", "ARG001", "PLR0913", "PLR2004"]