Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions demo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# GPT-2 Demo

Text generation with GPT-2-124M using a WarpForth-compiled attention kernel. The stock Hugging Face model is loaded normally, then `eager_attention_forward` is monkey-patched to route scaled dot-product attention through a WarpForth kernel compiled to PTX. PyCUDA shares PyTorch's CUDA context via `autoprimaryctx`, so device pointers pass directly between the two — no copies, no CPU roundtrips.

## Prerequisites

- WarpForth built locally (`cmake --build build`)
- A Vast.ai GPU instance with a PyTorch image (e.g. `pytorch/pytorch:2.6.0-cuda12.6-cudnn9-runtime`)

## Step 1: Compile the Kernel (Local)

```bash
./build/bin/warpforthc demo/attention.forth > demo/attention.ptx
```

A pre-compiled `attention.ptx` is included in this directory.

## Step 2: Upload to GPU Instance

```bash
scp demo/attention.ptx demo/warpforth.py demo/gpt2_generate.py root@HOST:/workspace
```

## Step 3: Install Dependencies (Remote)

```bash
pip install pycuda transformers
```

## Step 4: Generate Text (Remote)

```bash
python /workspace/gpt2_generate.py --ptx /workspace/attention.ptx --prompt "The meaning of life is"
```

| Flag | Default | Description |
|------|---------|-------------|
| `--ptx` | (required) | Path to compiled `attention.ptx` |
| `--prompt` | `"The meaning of life is"` | Input text prompt |
| `--max-tokens` | `100` | Maximum new tokens to generate |

## Limitations

- **Batch size 1** — the kernel processes one sequence at a time
- **No KV cache** — all positions are recomputed each step (`use_cache=False`)
- **Max sequence length 1024** — limited by shared memory allocation
- **12 kernel launches per layer** — one per attention head

## Files

| File | Description |
|------|-------------|
| `attention.forth` | Attention kernel source (f32 global, f64 shared) |
| `attention.ptx` | Pre-compiled PTX |
| `warpforth.py` | PyCUDA wrapper for loading and launching the kernel |
| `gpt2_generate.py` | Loads GPT-2, patches attention, generates text |
81 changes: 81 additions & 0 deletions demo/attention.forth
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
\ GPT-2 attention kernel with f32 global memory, f64 shared memory for softmax.
\ Adapted from test/Pipeline/attention.forth — 4 lines changed for f32 access.
\
\ Q/K/V/O are f32 arrays passed as raw byte buffers (i64[]).
\ Global loads/stores use F32@/F32! with 4* byte addressing (f32 = 4 bytes).
\ Shared memory stays f64 for softmax precision, using SF@/SF! with CELLS.
\
\ One block computes one output row: row = BID-X selects the query position
\ and each thread t = TID-X owns one key position's score; the final phase
\ then strides threads across the HEAD_DIM output columns.
\ NOTE(review): SCORES/SCRATCH are sized for SEQ_LEN <= 1024 — the grid and
\ block dimensions are set by the host wrapper; confirm them there.

\! kernel attention
\! param Q i64[32768]
\! param K i64[32768]
\! param V i64[32768]
\! param O i64[32768]
\! param SEQ_LEN i64
\! param HEAD_DIM i64
\! shared SCORES f64[1024]
\! shared SCRATCH f64[1024]

\ row = BID-X, t = TID-X                 ( -- row t )
BID-X
TID-X

\ --- Dot product: Q[row,:] . K[t,:] ---
\ ( row t -- row t score )   score = (Q[row]·K[t]) / sqrt(HEAD_DIM)
0.0
HEAD_DIM 0 DO
2 PICK HEAD_DIM * I + 4 * Q + F32@     \ Q[row*HEAD_DIM + I]  (4 * = f32 byte offset)
2 PICK HEAD_DIM * I + 4 * K + F32@     \ K[t*HEAD_DIM + I]
F* F+                                  \ accumulate the dot product
LOOP
HEAD_DIM S>F FSQRT F/                  \ scale by 1/sqrt(HEAD_DIM)

\ --- Causal mask: if t > row, score = -inf ---
\ -1.0e30 stands in for -inf: after the max shift, exp() drives it to ~0.
OVER 3 PICK >                          \ t > row ?  (future position)
IF DROP -1.0e30 THEN

\ --- Store score to shared memory ---
OVER CELLS SCORES + SF!                \ SCORES[t] = score   ( row t score -- row t )
BARRIER

\ --- Softmax: max reduction (thread 0) ---
\ Serial scan over all SEQ_LEN scores; result broadcast via SCRATCH[0].
TID-X 0= IF
0 CELLS SCORES + SF@
SEQ_LEN 1 DO I CELLS SCORES + SF@ FMAX LOOP
0 CELLS SCRATCH + SF!                  \ SCRATCH[0] = max score
THEN
BARRIER

\ --- Softmax: exp(score - max) ---
\ ( row t -- row t )  SCORES[t] = exp(SCORES[t] - max), all in f64.
DUP CELLS SCORES + SF@
0 CELLS SCRATCH + SF@
F- FEXP
OVER CELLS SCORES + SF!
BARRIER

\ --- Softmax: sum reduction (thread 0) ---
TID-X 0= IF
0.0
SEQ_LEN 0 DO I CELLS SCORES + SF@ F+ LOOP
0 CELLS SCRATCH + SF!                  \ SCRATCH[0] = sum of exps
THEN
BARRIER

\ --- Softmax: normalize ---
\ SCORES[t] = SCORES[t] / sum  ->  attention weight for key position t
DUP CELLS SCORES + SF@
0 CELLS SCRATCH + SF@
F/
OVER CELLS SCORES + SF!
BARRIER

\ --- V accumulation: O[row,col] = sum_j SCORES[j] * V[j*HD + col] ---
\ Stride over head_dim columns: col = t, t+BDIM-X, t+2*BDIM-X, ...
DUP BEGIN DUP HEAD_DIM < WHILE         \ ( row t col )  loop while col < HEAD_DIM
0.0                                    \ f64 accumulator for O[row,col]
SEQ_LEN 0 DO
I CELLS SCORES + SF@                   \ attention weight for key j = I
I HEAD_DIM * 3 PICK + 4 * V + F32@     \ V[I*HEAD_DIM + col]
F* F+
LOOP
OVER 4 PICK HEAD_DIM * + 4 * O + F32!  \ O[row*HEAD_DIM + col] = acc (narrowed to f32)
BDIM-X +                               \ advance to this thread's next column
REPEAT
DROP DROP DROP                         \ clear row, t, col
224 changes: 224 additions & 0 deletions demo/attention.ptx
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
//
// Generated by LLVM NVPTX Back-End
//

.version 6.0
.target sm_70
.address_size 64

// .globl attention
// __wg_attention_0 has been demoted
// __wg_attention_1_$_0 has been demoted

// ------------------------------------------------------------------
// WarpForth-compiled GPT-2 attention kernel (source: demo/attention.forth).
// Machine-generated by the LLVM NVPTX back-end — do not hand-edit;
// regenerate with `warpforthc demo/attention.forth`. Comments below map
// basic blocks back to the Forth phases; register roles inferred from use.
// Params (order matches the Forth \! param list): 0=Q, 1=K, 2=V, 3=O,
// 4=SEQ_LEN, 5=HEAD_DIM. Q/K/V/O hold f32 data addressed in bytes.
// ------------------------------------------------------------------
.visible .entry attention(
.param .u64 .ptr .align 1 attention_param_0,
.param .u64 .ptr .align 1 attention_param_1,
.param .u64 .ptr .align 1 attention_param_2,
.param .u64 .ptr .align 1 attention_param_3,
.param .u64 attention_param_4,
.param .u64 attention_param_5
)
{
.reg .pred %p<17>;
.reg .b32 %r<21>;
.reg .f32 %f<7>;
.reg .b64 %rd<76>;
.reg .f64 %fd<63>;
// demoted variable
// SCORES: 1024 x f64 shared scratch holding per-key attention scores.
.shared .align 8 .b8 __wg_attention_0[8192];
// demoted variable
// SCRATCH[0]: single f64 broadcast slot (softmax max, then sum of exps).
.shared .align 8 .f64 __wg_attention_1_$_0;
ld.param.u64 %rd41, [attention_param_5];   // HEAD_DIM
ld.param.u64 %rd40, [attention_param_4];   // SEQ_LEN
ld.param.u64 %rd39, [attention_param_3];   // O base
ld.param.u64 %rd38, [attention_param_2];   // V base
ld.param.u64 %rd42, [attention_param_0];   // Q base
mov.u32 %r1, %ctaid.x;                     // row = blockIdx.x
ld.param.u64 %rd43, [attention_param_1];   // K base
cvt.u64.u32 %rd44, %r1;
mov.u32 %r5, %tid.x;                       // t = threadIdx.x
cvt.u64.u32 %rd72, %r5;
mul.lo.s64 %rd2, %rd41, %rd44;             // row * HEAD_DIM
mul.lo.s64 %rd45, %rd41, %rd72;            // t * HEAD_DIM
neg.s64 %rd66, %rd41;                      // loop counter: -HEAD_DIM .. -1
shl.b64 %rd46, %rd45, 2;                   // *4: f32 byte offset
add.s64 %rd65, %rd43, %rd46;               // &K[t*HEAD_DIM]
shl.b64 %rd47, %rd2, 2;
add.s64 %rd64, %rd42, %rd47;               // &Q[row*HEAD_DIM]
mov.f64 %fd58, 0d0000000000000000;         // dot-product accumulator
// --- Phase 1: score = Q[row,:] . K[t,:], f32 loads widened to f64 ---
$L__BB0_1:
ld.f32 %f2, [%rd64];
cvt.f64.f32 %fd16, %f2;
ld.f32 %f3, [%rd65];
cvt.f64.f32 %fd17, %f3;
mul.rn.f64 %fd18, %fd16, %fd17;
add.rn.f64 %fd58, %fd58, %fd18;
add.s64 %rd9, %rd66, 1;
xor.b64 %rd48, %rd9, %rd66;                // sign change => counter reached 0
setp.gt.s64 %p1, %rd48, -1;
add.s64 %rd65, %rd65, 4;
add.s64 %rd64, %rd64, 4;
mov.u64 %rd66, %rd9;
@%p1 bra $L__BB0_1;
// --- scale by 1/sqrt(HEAD_DIM); causal mask: t > row => -1.0e30 ---
cvt.u32.u64 %r6, %rd72;
cvt.rn.f64.s64 %fd19, %rd41;
sqrt.rn.f64 %fd20, %fd19;
div.rn.f64 %fd21, %fd58, %fd20;
setp.gt.u32 %p2, %r6, %r1;                 // future position?
shl.b64 %rd49, %rd72, 3;
mov.u64 %rd50, __wg_attention_0;
add.s64 %rd12, %rd49, %rd50;               // &SCORES[t]
selp.f64 %fd22, 0dC6293E5939A08CEA, %fd21, %p2;   // 0xC629... == -1.0e30 sentinel
st.shared.f64 [%rd12], %fd22;
bar.sync 0;
// --- Phase 2 (thread 0 only): serial max over SCORES[0..SEQ_LEN) ---
setp.ne.s32 %p3, %r6, 0;
@%p3 bra $L__BB0_6;
ld.shared.f64 %fd60, [__wg_attention_0];
mov.u64 %rd52, __wg_attention_0;
add.s64 %rd68, %rd52, 8;
sub.s64 %rd67, 1, %rd40;                   // counter: 1-SEQ_LEN .. 0
$L__BB0_4:
ld.shared.f64 %fd23, [%rd68];
// max with explicit NaN and signed-zero handling (expanded FMAX)
mov.b64 %rd53, %fd23;
setp.nan.f64 %p4, %fd60, %fd23;
max.f64 %fd24, %fd60, %fd23;
selp.f64 %fd25, 0d7FF8000000000000, %fd24, %p4;
mov.b64 %rd54, %fd60;
setp.eq.s64 %p5, %rd54, 0;
selp.f64 %fd26, %fd60, %fd25, %p5;
setp.eq.s64 %p6, %rd53, 0;
selp.f64 %fd27, %fd23, %fd26, %p6;
setp.eq.f64 %p7, %fd25, 0d0000000000000000;
selp.f64 %fd60, %fd27, %fd25, %p7;
add.s64 %rd17, %rd67, 1;
xor.b64 %rd55, %rd17, %rd67;
setp.gt.s64 %p8, %rd55, -1;
add.s64 %rd68, %rd68, 8;
mov.u64 %rd67, %rd17;
@%p8 bra $L__BB0_4;
st.shared.f64 [__wg_attention_1_$_0], %fd60;   // broadcast max
$L__BB0_6:
bar.sync 0;
// --- Phase 3 (all threads): SCORES[t] = exp(SCORES[t] - max) ---
// Inlined f64 exp: range reduction via log2(e)/ln(2) constants, a
// polynomial approximation, then 2^k scaling through the exponent field.
ld.shared.f64 %fd28, [%rd12];
ld.shared.f64 %fd29, [__wg_attention_1_$_0];
sub.rn.f64 %fd4, %fd28, %fd29;
fma.rn.f64 %fd30, %fd4, 0d3FF71547652B82FE, 0d4338000000000000;   // x*log2(e) + shifter
{
.reg .b32 %temp;
mov.b64 {%r2, %temp}, %fd30;
}
mov.f64 %fd31, 0dC338000000000000;
add.rn.f64 %fd32, %fd30, %fd31;
fma.rn.f64 %fd33, %fd32, 0dBFE62E42FEFA39EF, %fd4;   // r = x - k*ln2 (hi)
fma.rn.f64 %fd34, %fd32, 0dBC7ABC9E3B39803F, %fd33;  //   ... - k*ln2 (lo)
fma.rn.f64 %fd35, %fd34, 0d3E5ADE1569CE2BDF, 0d3E928AF3FCA213EA;
fma.rn.f64 %fd36, %fd35, %fd34, 0d3EC71DEE62401315;
fma.rn.f64 %fd37, %fd36, %fd34, 0d3EFA01997C89EB71;
fma.rn.f64 %fd38, %fd37, %fd34, 0d3F2A01A014761F65;
fma.rn.f64 %fd39, %fd38, %fd34, 0d3F56C16C1852B7AF;
fma.rn.f64 %fd40, %fd39, %fd34, 0d3F81111111122322;
fma.rn.f64 %fd41, %fd40, %fd34, 0d3FA55555555502A1;
fma.rn.f64 %fd42, %fd41, %fd34, 0d3FC5555555555511;
fma.rn.f64 %fd43, %fd42, %fd34, 0d3FE000000000000B;
fma.rn.f64 %fd44, %fd43, %fd34, 0d3FF0000000000000;
fma.rn.f64 %fd45, %fd44, %fd34, 0d3FF0000000000000;
{
.reg .b32 %temp;
mov.b64 {%r3, %temp}, %fd45;
}
{
.reg .b32 %temp;
mov.b64 {%temp, %r4}, %fd45;
}
shl.b32 %r7, %r2, 20;                      // fold k into the exponent field
add.s32 %r8, %r4, %r7;
mov.b64 %fd59, {%r3, %r8};                 // poly * 2^k
{
.reg .b32 %temp;
mov.b64 {%temp, %r9}, %fd4;
}
mov.b32 %f4, %r9;
abs.f32 %f1, %f4;                          // |x| via the double's high word as f32
setp.lt.f32 %p9, %f1, 0f4086232B;          // small |x|: fast path result stands
@%p9 bra $L__BB0_9;
// large |x|: overflow/underflow handling (-big -> 0, +big -> inf)
setp.lt.f64 %p10, %fd4, 0d0000000000000000;
add.rn.f64 %fd46, %fd4, 0d7FF0000000000000;
selp.f64 %fd59, 0d0000000000000000, %fd46, %p10;
setp.geu.f32 %p11, %f1, 0f40874800;
@%p11 bra $L__BB0_9;
// borderline range: split 2^k into two factors to avoid overflow
shr.u32 %r10, %r2, 31;
add.s32 %r11, %r2, %r10;
shr.s32 %r12, %r11, 1;
shl.b32 %r13, %r12, 20;
add.s32 %r14, %r4, %r13;
mov.b64 %fd47, {%r3, %r14};
sub.s32 %r15, %r2, %r12;
shl.b32 %r16, %r15, 20;
add.s32 %r17, %r16, 1072693248;
mov.b32 %r18, 0;
mov.b64 %fd48, {%r18, %r17};
mul.rn.f64 %fd59, %fd48, %fd47;
$L__BB0_9:
st.shared.f64 [%rd12], %fd59;              // SCORES[t] = exp(score - max)
bar.sync 0;
// --- Phase 4 (thread 0 only): serial sum of exps ---
@%p3 bra $L__BB0_13;
neg.s64 %rd69, %rd40;                      // counter: -SEQ_LEN .. -1
mov.f64 %fd61, 0d0000000000000000;
mov.u64 %rd70, __wg_attention_0;
$L__BB0_11:
ld.shared.f64 %fd50, [%rd70];
add.rn.f64 %fd61, %fd61, %fd50;
add.s64 %rd26, %rd69, 1;
xor.b64 %rd57, %rd26, %rd69;
setp.gt.s64 %p13, %rd57, -1;
add.s64 %rd70, %rd70, 8;
mov.u64 %rd69, %rd26;
@%p13 bra $L__BB0_11;
st.shared.f64 [__wg_attention_1_$_0], %fd61;   // broadcast sum
$L__BB0_13:
bar.sync 0;
// --- Phase 5 (all threads): normalize SCORES[t] /= sum ---
ld.shared.f64 %fd51, [%rd12];
ld.shared.f64 %fd52, [__wg_attention_1_$_0];
div.rn.f64 %fd53, %fd51, %fd52;
st.shared.f64 [%rd12], %fd53;
bar.sync 0;
// --- Phase 6: O[row,col] = sum_j SCORES[j] * V[j*HEAD_DIM + col] ---
// Each thread covers columns col = t, t+ntid.x, t+2*ntid.x, ... < HEAD_DIM.
setp.le.s64 %p14, %rd41, %rd72;            // t >= HEAD_DIM: no column to write
@%p14 bra $L__BB0_18;
mov.u32 %r20, %ntid.x;
cvt.u64.u32 %rd19, %r20;
shl.b64 %rd58, %rd72, 2;
add.s64 %rd71, %rd38, %rd58;               // &V[col] (column walker)
mul.wide.u32 %rd21, %r20, 4;               // column stride in bytes
shl.b64 %rd22, %rd41, 2;                   // V row stride in bytes
neg.s64 %rd23, %rd40;                      // inner counter: -SEQ_LEN .. -1
$L__BB0_15:
mov.f64 %fd62, 0d0000000000000000;         // accumulator for O[row,col]
mov.u64 %rd73, %rd23;
mov.u64 %rd74, %rd71;
mov.u64 %rd75, %rd50;                      // %rd50 still holds SCORES base
$L__BB0_16:
ld.shared.f64 %fd55, [%rd75];
ld.f32 %f5, [%rd74];
cvt.f64.f32 %fd56, %f5;
mul.rn.f64 %fd57, %fd55, %fd56;
add.rn.f64 %fd62, %fd62, %fd57;
add.s64 %rd33, %rd73, 1;
xor.b64 %rd60, %rd33, %rd73;
setp.gt.s64 %p15, %rd60, -1;
add.s64 %rd75, %rd75, 8;
add.s64 %rd74, %rd74, %rd22;               // next row of V, same column
mov.u64 %rd73, %rd33;
@%p15 bra $L__BB0_16;
add.s64 %rd61, %rd72, %rd2;                // row*HEAD_DIM + col
shl.b64 %rd62, %rd61, 2;
add.s64 %rd63, %rd62, %rd39;
cvt.rn.f32.f64 %f6, %fd62;
st.f32 [%rd63], %f6;                       // O[row,col], narrowed to f32
add.s64 %rd72, %rd72, %rd19;               // col += ntid.x
add.s64 %rd71, %rd71, %rd21;
setp.lt.s64 %p16, %rd72, %rd41;
@%p16 bra $L__BB0_15;
$L__BB0_18:
ret;

}
Loading