diff --git a/CLAUDE.md b/CLAUDE.md index 373d1c9..1dc6a79 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -87,12 +87,13 @@ uv run ruff format gpu_test/ - **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety - **Operations**: All take stack as input and produce stack as output (except `forth.stack`) -- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). +- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `{ a b -- }` (local variables in word definitions), `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). - **Float Literals**: Numbers containing `.` or `e`/`E` are parsed as f64 (e.g. `3.14`, `-2.0`, `1.0e-5`, `1e3`). Stored on the stack as i64 bit patterns; F-prefixed words perform bitcast before/after operations. - **Kernel Parameters**: Declared in the `\!` header. `\! kernel ` is required and must appear first. `\! param i64[]` becomes a `memref` argument; `\! param i64` becomes an `i64` argument. `\! param f64[]` becomes a `memref` argument; `\! param f64` becomes an `f64` argument (bitcast to i64 when pushed to stack). Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value). - **Shared Memory**: `\! shared i64[]` or `\! shared f64[]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution. Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for i64 or `SF@`/`SF!` for f64 shared accesses. Cannot be referenced inside word definitions. - **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer - **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion +- **Local Variables**: `{ a b c -- }` at the start of a word definition binds read-only locals. Pops values from the stack in reverse name order (c, b, a) using `forth.pop`, stores SSA values. Referencing a local emits `forth.push_value`. SSA values from the entry block dominate all control flow, so locals work across IF/ELSE/THEN, loops, etc. On GPU, locals map directly to registers. - **User-defined Words**: Modeled as `func.func` with signature `(!forth.stack) -> !forth.stack`, called via `func.call` ## Conventions diff --git a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp index aebb7e8..e64dab8 100644 --- a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp +++ b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp @@ -409,6 +409,16 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack, Location loc) { Type stackType = forth::StackType::get(context); + // Check if word is a local variable (only valid inside word definitions) + if (inWordDefinition) { + auto it = localVars.find(word); + if (it != localVars.end()) { + return builder + .create(loc, stackType, inputStack, it->second) + .getOutputStack(); + } + } + // Check if word is a param name (only valid outside word definitions) if (!inWordDefinition) { for (const auto ¶m : paramDecls) { @@ -994,6 +1004,11 @@ LogicalResult ForthParser::parseBody(Value &stack) { Value step = popOp.getValue(); emitLoopEnd(loc, ctx, step, stack); + //=== { outside word definition === + } else if (currentToken.text == "{") { + return emitError( + "local variables can only be declared inside a word definition"); + //=== Normal word === } else { Value newStack = emitOperation(currentToken.text, stack, loc); @@ -1014,6 +1029,84 @@ LogicalResult ForthParser::parseBody(Value &stack) { // Word definition and top-level parsing. //===----------------------------------------------------------------------===// +LogicalResult ForthParser::parseLocals(Value &stack) { + // If current token is not '{', no locals to parse + if (currentToken.kind != Token::Kind::Word || currentToken.text != "{") + return success(); + + Location loc = getLoc(); + consume(); // consume '{' + + // Collect local names until '--' or '}' + SmallVector names; + while (currentToken.kind != Token::Kind::EndOfFile) { + if (currentToken.kind == Token::Kind::Word && currentToken.text == "--") + break; + if (currentToken.kind == Token::Kind::Word && currentToken.text == "}") + break; + + if (currentToken.kind != Token::Kind::Word) + return emitError("expected local variable name in { ... }"); + + std::string name = currentToken.text; // already uppercased by lexer + + // Check for duplicate local names + for (const auto &existing : names) { + if (existing == name) + return emitError("duplicate local variable name: " + name); + } + + // Check for conflicts with param names + for (const auto ¶m : paramDecls) { + if (param.name == name) + return emitError("local variable name '" + name + + "' conflicts with parameter name"); + } + + // Check for conflicts with shared names + for (const auto &shared : sharedDecls) { + if (shared.name == name) + return emitError("local variable name '" + name + + "' conflicts with shared memory name"); + } + + names.push_back(name); + consume(); + } + + // Skip '--' and output names until '}'. Per ANS Forth, output names are + // documentation-only and have no semantic effect; we intentionally ignore + // them. + if (currentToken.kind == Token::Kind::Word && currentToken.text == "--") { + consume(); // consume '--' + while (currentToken.kind != Token::Kind::EndOfFile) { + if (currentToken.kind == Token::Kind::Word && currentToken.text == "}") + break; + consume(); // skip output names (ignored) + } + } + + if (currentToken.kind != Token::Kind::Word || currentToken.text != "}") + return emitError("expected '}' to close local variable declaration"); + + consume(); // consume '}' + + if (names.empty()) + return success(); + + // Pop values in reverse order: { a b c -- } with stack ( 1 2 3 ) + // pops 3->c, 2->b, 1->a + Type i64Type = builder.getI64Type(); + Type stackType = forth::StackType::get(context); + for (int i = names.size() - 1; i >= 0; --i) { + auto popOp = builder.create(loc, stackType, i64Type, stack); + stack = popOp.getOutputStack(); + localVars[names[i]] = popOp.getValue(); + } + + return success(); +} + LogicalResult ForthParser::parseWordDefinition() { Location loc = getLoc(); auto savedInsertionPoint = builder.saveInsertionPoint(); @@ -1039,6 +1132,10 @@ LogicalResult ForthParser::parseWordDefinition() { Value resultStack = entryBlock->getArgument(0); builder.setInsertionPointToStart(entryBlock); + // Parse local variable declarations (if any) + if (failed(parseLocals(resultStack))) + return failure(); + // Parse word body until ';' if (failed(parseBody(resultStack))) return failure(); @@ -1057,6 +1154,7 @@ LogicalResult ForthParser::parseWordDefinition() { consume(); // consume ';' inWordDefinition = false; + localVars.clear(); // Restore insertion point builder.restoreInsertionPoint(savedInsertionPoint); diff --git a/lib/Translation/ForthToMLIR/ForthToMLIR.h b/lib/Translation/ForthToMLIR/ForthToMLIR.h index 6aa4809..4059e02 100644 --- a/lib/Translation/ForthToMLIR/ForthToMLIR.h +++ b/lib/Translation/ForthToMLIR/ForthToMLIR.h @@ -101,6 +101,7 @@ class ForthParser { std::vector paramDecls; std::vector sharedDecls; llvm::StringMap sharedAllocs; + llvm::StringMap localVars; std::string kernelName; const char *headerEndPtr = nullptr; bool inWordDefinition = false; @@ -159,6 +160,9 @@ class ForthParser { void emitLoopEnd(Location loc, const LoopContext &ctx, Value step, Value &stack); + /// Parse local variable declarations: { a b c -- } + LogicalResult parseLocals(Value &stack); + /// Parse a user-defined word definition. LogicalResult parseWordDefinition(); }; diff --git a/test/Pipeline/local-variables.forth b/test/Pipeline/local-variables.forth new file mode 100644 index 0000000..4e907c2 --- /dev/null +++ b/test/Pipeline/local-variables.forth @@ -0,0 +1,10 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s + +\ Verify that local variables compile through the full pipeline to gpu.binary. +\ CHECK: gpu.binary @warpforth_module + +\! kernel main +\! param DATA i64[256] +: ADD3 { a b c -- } a b + c + ; +1 2 3 ADD3 +GLOBAL-ID CELLS DATA + ! diff --git a/test/Translation/Forth/local-variables-control-flow.forth b/test/Translation/Forth/local-variables-control-flow.forth new file mode 100644 index 0000000..1cd259f --- /dev/null +++ b/test/Translation/Forth/local-variables-control-flow.forth @@ -0,0 +1,19 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s + +\ Test that locals work across IF/ELSE/THEN control flow. +\ SSA values defined in the entry block dominate all subsequent blocks. + +\ CHECK: func.func private @CLAMP(%arg0: !forth.stack) -> !forth.stack { +\ CHECK: forth.pop +\ CHECK: forth.pop +\ CHECK: forth.pop +\ CHECK: forth.push_value +\ CHECK: forth.push_value +\ CHECK: forth.push_value + +\! kernel main +: CLAMP { val lo hi -- } + val lo < IF lo ELSE + val hi > IF hi ELSE + val THEN THEN ; +0 10 5 CLAMP diff --git a/test/Translation/Forth/local-variables-error-outside-word.forth b/test/Translation/Forth/local-variables-error-outside-word.forth new file mode 100644 index 0000000..1fc955a --- /dev/null +++ b/test/Translation/Forth/local-variables-error-outside-word.forth @@ -0,0 +1,4 @@ +\ RUN: %not %warpforth-translate --forth-to-mlir %s 2>&1 | %FileCheck %s +\ CHECK: local variables can only be declared inside a word definition +\! kernel main +{ x y -- } diff --git a/test/Translation/Forth/local-variables-error-param-conflict.forth b/test/Translation/Forth/local-variables-error-param-conflict.forth new file mode 100644 index 0000000..edfb7dc --- /dev/null +++ b/test/Translation/Forth/local-variables-error-param-conflict.forth @@ -0,0 +1,6 @@ +\ RUN: %not %warpforth-translate --forth-to-mlir %s 2>&1 | %FileCheck %s +\ CHECK: local variable name 'DATA' conflicts with parameter name +\! kernel main +\! param DATA i64[256] +: BAD { data -- } data ; +BAD diff --git a/test/Translation/Forth/local-variables-error-shared-conflict.forth b/test/Translation/Forth/local-variables-error-shared-conflict.forth new file mode 100644 index 0000000..12b8fba --- /dev/null +++ b/test/Translation/Forth/local-variables-error-shared-conflict.forth @@ -0,0 +1,6 @@ +\ RUN: %not %warpforth-translate --forth-to-mlir %s 2>&1 | %FileCheck %s +\ CHECK: local variable name 'BUF' conflicts with shared memory name +\! kernel main +\! shared BUF i64[64] +: BAD { buf -- } buf ; +BAD diff --git a/test/Translation/Forth/local-variables-error.forth b/test/Translation/Forth/local-variables-error.forth new file mode 100644 index 0000000..f73121b --- /dev/null +++ b/test/Translation/Forth/local-variables-error.forth @@ -0,0 +1,5 @@ +\ RUN: %not %warpforth-translate --forth-to-mlir %s 2>&1 | %FileCheck %s +\ CHECK: duplicate local variable name: X +\! kernel main +: BAD { x y x -- } x ; +BAD diff --git a/test/Translation/Forth/local-variables.forth b/test/Translation/Forth/local-variables.forth new file mode 100644 index 0000000..470d9d8 --- /dev/null +++ b/test/Translation/Forth/local-variables.forth @@ -0,0 +1,27 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s + +\ Test basic local variable binding and reference. + +\ CHECK: func.func private @ADD3(%arg0: !forth.stack) -> !forth.stack { +\ CHECK: forth.pop %arg0 : !forth.stack -> !forth.stack, i64 +\ CHECK: forth.pop %{{.*}} : !forth.stack -> !forth.stack, i64 +\ CHECK: forth.pop %{{.*}} : !forth.stack -> !forth.stack, i64 +\ CHECK: forth.push_value %{{.*}}, %{{.*}} : !forth.stack, i64 -> !forth.stack +\ CHECK: forth.push_value %{{.*}}, %{{.*}} : !forth.stack, i64 -> !forth.stack +\ CHECK: forth.addi +\ CHECK: forth.push_value %{{.*}}, %{{.*}} : !forth.stack, i64 -> !forth.stack +\ CHECK: forth.addi +\ CHECK: return + +\ CHECK: func.func private @SWAP2(%arg0: !forth.stack) -> !forth.stack { +\ CHECK: forth.pop %arg0 : !forth.stack -> !forth.stack, i64 +\ CHECK: forth.pop %{{.*}} : !forth.stack -> !forth.stack, i64 +\ CHECK: forth.push_value %{{.*}}, %{{.*}} : !forth.stack, i64 -> !forth.stack +\ CHECK: forth.push_value %{{.*}}, %{{.*}} : !forth.stack, i64 -> !forth.stack +\ CHECK: return + +\! kernel main +: ADD3 { a b c -- } a b + c + ; +: SWAP2 { x y -- } y x ; +1 2 3 ADD3 +10 20 SWAP2