Closed
13 changes: 7 additions & 6 deletions CLAUDE.md
@@ -49,7 +49,7 @@ Requires MLIR/LLVM with `MLIR_DIR` and `LLVM_DIR` configured in CMake.
./build/bin/warpforth-translate --mlir-to-ptx > kernel.ptx

# Execute PTX on GPU
-./warpforth-runner kernel.ptx --param i64[]:1,2,3 --param i64:42 --output-param 0 --output-count 3
+./warpforth-runner kernel.ptx --param i32[]:1,2,3 --param i32:42 --output-param 0 --output-count 3
```

## Adding New Operations
@@ -87,11 +87,12 @@ uv run ruff format gpu_test/

- **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety
- **Operations**: All take stack as input and produce stack as output (except `forth.stack`)
-- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `FEXP FSQRT FLOG FABS FNEG` (float math intrinsics), `FMAX FMIN` (float min/max), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `{ a b -- }` (local variables in word definitions), `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
-- **Float Literals**: Numbers containing `.` or `e`/`E` are parsed as f64 (e.g. `3.14`, `-2.0`, `1.0e-5`, `1e3`). Stored on the stack as i64 bit patterns; F-prefixed words perform bitcast before/after operations.
-- **Kernel Parameters**: Declared in the `\!` header. `\! kernel <name>` is required and must appear first. `\! param <name> i64[<N>]` becomes a `memref<Nxi64>` argument; `\! param <name> i64` becomes an `i64` argument. `\! param <name> f64[<N>]` becomes a `memref<Nxf64>` argument; `\! param <name> f64` becomes an `f64` argument (bitcast to i64 when pushed to stack). Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value).
-- **Shared Memory**: `\! shared <name> i64[<N>]` or `\! shared <name> f64[<N>]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution. Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for i64 or `SF@`/`SF!` for f64 shared accesses. Cannot be referenced inside word definitions.
-- **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer
+- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `FEXP FSQRT FLOG FABS FNEG` (float math intrinsics), `FMAX FMIN` (float min/max), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global i32 memory), `F@ F!` (global f32 memory), `S@ S!` (shared i32 memory), `SF@ SF!` (shared f32 memory), `HF@ HF!` (global f16 memory), `BF@ BF!` (global bf16 memory), `I8@ I8!` (global i8 memory), `I16@ I16!` (global i16 memory), `SHF@ SHF!` (shared f16 memory), `SBF@ SBF!` (shared bf16 memory), `SI8@ SI8!` (shared i8 memory), `SI16@ SI16!` (shared i16 memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `{ a b -- }` (local variables in word definitions), `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
+- **Cell Size**: 32-bit arithmetic (i32/f32). CELLS = 4 bytes. The stack is `memref<256xi64>` because GPU pointers are 64-bit; arithmetic values are truncated to i32 at stack-load boundaries and sign-extended back to i64 at stack-store boundaries. LLVM optimization eliminates this overhead.
+- **Float Literals**: Numbers containing `.` or `e`/`E` are parsed as f32 (e.g. `3.14`, `-2.0`, `1.0e-5`, `1e3`). Stored on the stack as i32 bit patterns (sign-extended to i64); F-prefixed words perform bitcast before/after operations.
+- **Kernel Parameters**: Declared in the `\!` header. `\! kernel <name>` is required and must appear first. `\! param <name> i32[<N>]` becomes a `memref<Nxi32>` argument; `\! param <name> i32` becomes an `i32` argument. `\! param <name> f32[<N>]` becomes a `memref<Nxf32>` argument; `\! param <name> f32` becomes an `f32` argument (bitcast to i32, sign-extended to i64 when pushed to stack). Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value).
+- **Shared Memory**: `\! shared <name> i32[<N>]` or `\! shared <name> f32[<N>]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution. Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for i32 or `SF@`/`SF!` for f32 shared accesses. Cannot be referenced inside word definitions.
+- **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer. Arithmetic uses i32/f32 with trunci/extsi at stack boundaries. Narrow-type load/store words (f16, bf16, i8, i16) widen through i32/f32.
- **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion
- **Local Variables**: `{ a b c -- }` at the start of a word definition binds read-only locals. Pops values from the stack in reverse name order (c, b, a) using `forth.pop`, stores SSA values. Referencing a local emits `forth.push_value`. SSA values from the entry block dominate all control flow, so locals work across IF/ELSE/THEN, loops, etc. On GPU, locals map directly to registers.
- **User-defined Words**: Modeled as `func.func` with signature `(!forth.stack) -> !forth.stack`, called via `func.call`
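
The cell-size rules in the updated CLAUDE.md text can be modeled in a few lines of Python. This is only a sketch of the described behavior, not WarpForth code: the helper names (`trunc_i32`, `sext_i64`, `f32_to_slot`, `slot_to_f32`, `cells`) are hypothetical, and it assumes a 64-bit stack slot holds the sign-extended i32 bit pattern, as the new "Cell Size" and "Float Literals" bullets state.

```python
import struct

MASK32 = 0xFFFFFFFF
MASK64 = 0xFFFFFFFFFFFFFFFF
CELL_BYTES = 4  # CELLS = 4 bytes per the updated docs


def trunc_i32(slot: int) -> int:
    """Stack load: keep the low 32 bits of a slot, read as a signed i32."""
    low = slot & MASK32
    return low - (1 << 32) if low & 0x80000000 else low


def sext_i64(value: int) -> int:
    """Stack store: sign-extend a signed i32 to the 64-bit slot pattern."""
    return value & MASK64


def f32_to_slot(x: float) -> int:
    """Bitcast f32 to i32, then sign-extend to the 64-bit slot pattern."""
    (bits,) = struct.unpack("<i", struct.pack("<f", x))
    return sext_i64(bits)


def slot_to_f32(slot: int) -> float:
    """Truncate a slot to i32, then bitcast the bits back to f32."""
    (x,) = struct.unpack("<f", struct.pack("<i", trunc_i32(slot)))
    return x


def cells(n: int) -> int:
    """Model of CELLS: convert a cell count to a byte offset."""
    return n * CELL_BYTES
```

In this model, integer arithmetic wraps at 32 bits because every value passes through `trunc_i32` when it is loaded, and a negative float such as `-2.0` occupies a slot whose upper 32 bits are all ones.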
12 changes: 6 additions & 6 deletions gpu_test/conftest.py
@@ -39,7 +39,7 @@ class ParamDecl:
name: str
is_array: bool
size: int # 0 for scalars
-base_type: str = "i64" # "i64" or "f64"
+base_type: str = "i32" # "i32" or "f32"


class CompileError(Exception):
@@ -320,7 +320,7 @@ def _parse_array_type(type_spec: str) -> tuple[str, int]:
raise ValueError(msg)
base, size_str = type_spec[:-1].split("[", 1)
base_lower = base.lower()
-if base_lower not in ("i64", "f64"):
+if base_lower not in ("i32", "f32"):
msg = f"Unsupported base type: {base}"
raise ValueError(msg)
return base_lower, int(size_str)
@@ -377,7 +377,7 @@ def _parse_param_declarations(forth_source: str) -> list[ParamDecl]:
decls.append(ParamDecl(name=name, is_array=True, size=size, base_type=base_type))
else:
base_type = type_spec.lower()
-if base_type not in ("i64", "f64"):
+if base_type not in ("i32", "f32"):
msg = f"Unsupported scalar type: {type_spec}"
raise ValueError(msg)
decls.append(ParamDecl(name=name, is_array=False, size=0, base_type=base_type))
@@ -453,13 +453,13 @@ def run(
if not isinstance(values, list):
msg = f"Array param '{decl.name}' expects a list, got {type(values).__name__}"
raise TypeError(msg)
-zero = 0.0 if decl.base_type == "f64" else 0
+zero = 0.0 if decl.base_type == "f32" else 0
buf = [zero] * decl.size
for i, v in enumerate(values):
buf[i] = v
cmd_parts.extend(["--param", f"{decl.base_type}[]:{','.join(str(v) for v in buf)}"])
else:
-value = params.get(decl.name, 0.0 if decl.base_type == "f64" else 0)
+value = params.get(decl.name, 0.0 if decl.base_type == "f32" else 0)
if isinstance(value, list):
msg = f"Scalar param '{decl.name}' expects a scalar, got list"
raise TypeError(msg)
@@ -484,7 +484,7 @@ def run(

# Parse CSV output — type depends on the output param
out_type = decls[output_param].base_type
-parse = float if out_type == "f64" else int
+parse = float if out_type == "f32" else int
return [parse(v) for v in stdout.strip().split(",")]
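
For readers who want to try the post-change behavior outside the test harness, the visible pieces of these hunks can be put together roughly as follows. This is a standalone sketch, not the actual `conftest.py`: the function names `parse_array_type`, `array_param_args`, and `parse_output` are hypothetical, and the first validation in `parse_array_type` is an assumption, since the diff elides the lines before `raise ValueError(msg)`.

```python
def parse_array_type(type_spec: str) -> tuple[str, int]:
    """Split a spec such as 'i32[4]' into ('i32', 4)."""
    # Assumed check: the real condition is not visible in the diff.
    if not type_spec.endswith("]") or "[" not in type_spec:
        msg = f"Not an array type: {type_spec}"
        raise ValueError(msg)
    base, size_str = type_spec[:-1].split("[", 1)
    base_lower = base.lower()
    if base_lower not in ("i32", "f32"):
        msg = f"Unsupported base type: {base}"
        raise ValueError(msg)
    return base_lower, int(size_str)


def array_param_args(base_type: str, size: int, values: list) -> list[str]:
    """Zero-pad values to `size` and format a runner `--param` argument."""
    zero = 0.0 if base_type == "f32" else 0
    buf = [zero] * size
    for i, v in enumerate(values):
        buf[i] = v
    return ["--param", f"{base_type}[]:{','.join(str(v) for v in buf)}"]


def parse_output(stdout: str, out_type: str) -> list:
    """Parse the runner's CSV output as floats for f32, ints otherwise."""
    parse = float if out_type == "f32" else int
    return [parse(v) for v in stdout.strip().split(",")]
```

With these definitions, a spec using the old `i64` or `f64` base types is rejected with `ValueError`, which matches the intent of the changed membership checks.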

