From 42a6f69c3b2e14e7d402d42a95acc65dcbfba6f1 Mon Sep 17 00:00:00 2001 From: Alex Cameron Date: Sat, 21 Feb 2026 18:14:31 +0900 Subject: [PATCH 1/2] feat: add float math intrinsics (FEXP, FSQRT, FLOG, FABS, FNEG, FMAX, FMIN) --- CLAUDE.md | 2 +- include/warpforth/Conversion/Passes.td | 3 +- include/warpforth/Dialect/Forth/ForthOps.td | 65 +++++++++++++++++ lib/Conversion/CMakeLists.txt | 1 + lib/Conversion/ForthToMemRef/CMakeLists.txt | 1 + .../ForthToMemRef/ForthToMemRef.cpp | 61 +++++++++++++++- lib/Conversion/Passes.cpp | 8 ++- lib/Translation/ForthToMLIR/ForthToMLIR.cpp | 21 ++++++ .../ForthToMemRef/float-math-intrinsics.mlir | 70 +++++++++++++++++++ test/Pipeline/float-math-intrinsics.forth | 12 ++++ .../Forth/float-math-intrinsics.forth | 21 ++++++ 11 files changed, 260 insertions(+), 5 deletions(-) create mode 100644 test/Conversion/ForthToMemRef/float-math-intrinsics.mlir create mode 100644 test/Pipeline/float-math-intrinsics.forth create mode 100644 test/Translation/Forth/float-math-intrinsics.forth diff --git a/CLAUDE.md b/CLAUDE.md index 1dc6a79..2472d89 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -87,7 +87,7 @@ uv run ruff format gpu_test/ - **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety - **Operations**: All take stack as input and produce stack as output (except `forth.stack`) -- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `{ a b -- }` (local variables in word definitions), `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). +- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `FEXP FSQRT FLOG FABS FNEG` (float math intrinsics), `FMAX FMIN` (float min/max), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `{ a b -- }` (local variables in word definitions), `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). - **Float Literals**: Numbers containing `.` or `e`/`E` are parsed as f64 (e.g. `3.14`, `-2.0`, `1.0e-5`, `1e3`). Stored on the stack as i64 bit patterns; F-prefixed words perform bitcast before/after operations. - **Kernel Parameters**: Declared in the `\!` header. `\! kernel ` is required and must appear first. `\! param i64[]` becomes a `memref` argument; `\! param i64` becomes an `i64` argument. `\! param f64[]` becomes a `memref` argument; `\! param f64` becomes an `f64` argument (bitcast to i64 when pushed to stack). Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value). - **Shared Memory**: `\! shared i64[]` or `\! shared f64[]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution. Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for i64 or `SF@`/`SF!` for f64 shared accesses. Cannot be referenced inside word definitions. diff --git a/include/warpforth/Conversion/Passes.td b/include/warpforth/Conversion/Passes.td index 38342a1..be86499 100644 --- a/include/warpforth/Conversion/Passes.td +++ b/include/warpforth/Conversion/Passes.td @@ -28,7 +28,8 @@ def ConvertForthToMemRef let dependentDialects = ["mlir::memref::MemRefDialect", "mlir::arith::ArithDialect", "mlir::LLVM::LLVMDialect", - "mlir::cf::ControlFlowDialect"]; + "mlir::cf::ControlFlowDialect", + "mlir::math::MathDialect"]; } def ConvertForthToGPU : Pass<"convert-forth-to-gpu", "mlir::ModuleOp"> { diff --git a/include/warpforth/Dialect/Forth/ForthOps.td b/include/warpforth/Dialect/Forth/ForthOps.td index 1b12ec4..e1b0fa1 100644 --- a/include/warpforth/Dialect/Forth/ForthOps.td +++ b/include/warpforth/Dialect/Forth/ForthOps.td @@ -212,6 +212,71 @@ def Forth_DivFOp : Forth_StackOpBase<"divf"> { }]; } +//===----------------------------------------------------------------------===// +// Float math intrinsic operations. +//===----------------------------------------------------------------------===// + +def Forth_ExpFOp : Forth_StackOpBase<"expf"> { + let summary = "Exponential of top stack element (float)"; + let description = [{ + Pops an i64 (f64 bit pattern), bitcasts to f64, computes e^x, + bitcasts result back to i64. + Forth semantics: ( f -- exp(f) ) + }]; +} + +def Forth_SqrtFOp : Forth_StackOpBase<"sqrtf"> { + let summary = "Square root of top stack element (float)"; + let description = [{ + Pops an i64 (f64 bit pattern), bitcasts to f64, computes sqrt(x), + bitcasts result back to i64. + Forth semantics: ( f -- sqrt(f) ) + }]; +} + +def Forth_LogFOp : Forth_StackOpBase<"logf"> { + let summary = "Natural logarithm of top stack element (float)"; + let description = [{ + Pops an i64 (f64 bit pattern), bitcasts to f64, computes ln(x), + bitcasts result back to i64. + Forth semantics: ( f -- log(f) ) + }]; +} + +def Forth_AbsFOp : Forth_StackOpBase<"absf"> { + let summary = "Absolute value of top stack element (float)"; + let description = [{ + Pops an i64 (f64 bit pattern), bitcasts to f64, computes |x|, + bitcasts result back to i64. + Forth semantics: ( f -- |f| ) + }]; +} + +def Forth_NegFOp : Forth_StackOpBase<"negf"> { + let summary = "Negate top stack element (float)"; + let description = [{ + Pops an i64 (f64 bit pattern), bitcasts to f64, negates, + bitcasts result back to i64. + Forth semantics: ( f -- -f ) + }]; +} + +def Forth_MaxFOp : Forth_StackOpBase<"maxf"> { + let summary = "Maximum of top two stack elements (float)"; + let description = [{ + Pops two i64 values, bitcasts to f64, computes max, bitcasts result back to i64. + Forth semantics: ( f1 f2 -- max(f1,f2) ) + }]; +} + +def Forth_MinFOp : Forth_StackOpBase<"minf"> { + let summary = "Minimum of top two stack elements (float)"; + let description = [{ + Pops two i64 values, bitcasts to f64, computes min, bitcasts result back to i64. + Forth semantics: ( f1 f2 -- min(f1,f2) ) + }]; +} + //===----------------------------------------------------------------------===// // Bitwise operations. //===----------------------------------------------------------------------===// diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 329ecf0..646cdb2 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_library(MLIRConversionPasses MLIRGPUDialect MLIRGPUToNVVMTransforms MLIRGPUTransforms + MLIRMathToLLVM MLIRReconcileUnrealizedCasts MLIRTransforms ) diff --git a/lib/Conversion/ForthToMemRef/CMakeLists.txt b/lib/Conversion/ForthToMemRef/CMakeLists.txt index ff651f1..1f85b95 100644 --- a/lib/Conversion/ForthToMemRef/CMakeLists.txt +++ b/lib/Conversion/ForthToMemRef/CMakeLists.txt @@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRForthToMemRefConversion MLIRLLVMDialect MLIRFuncDialect MLIRControlFlowDialect + MLIRMathDialect MLIRForth ) diff --git a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp index e8bf21d..550b77e 100644 --- a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp +++ b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" @@ -514,6 +515,60 @@ using MulFOpConversion = using DivFOpConversion = BinaryArithOpConversion; +// Float binary intrinsics (max/min) +using MaxFOpConversion = + BinaryArithOpConversion; +using MinFOpConversion = + BinaryArithOpConversion; + +/// Base template for unary float operations. +/// Pops one value, applies operation, pushes result: (f -- result) +/// Bitcasts i64->f64 before the op and f64->i64 after. +template +struct UnaryFloatOpConversion : public OpConversionPattern { + UnaryFloatOpConversion(const TypeConverter &typeConverter, + MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + using OneToNOpAdaptor = + typename OpConversionPattern::OneToNOpAdaptor; + + LogicalResult + matchAndRewrite(ForthOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + ValueRange inputStack = adaptor.getOperands()[0]; + Value memref = inputStack[0]; + Value stackPtr = inputStack[1]; + + // Load value from top of stack + Value a = rewriter.create(loc, memref, stackPtr); + + // Bitcast i64 -> f64 + auto f64Type = rewriter.getF64Type(); + Value aF = rewriter.create(loc, f64Type, a); + + // Apply math/arith op + Value resF = rewriter.create(loc, aF); + + // Bitcast f64 -> i64 + Value result = + rewriter.create(loc, rewriter.getI64Type(), resF); + + // Store result at same position (SP unchanged — unary op) + rewriter.create(loc, result, memref, stackPtr); + + rewriter.replaceOpWithMultiple(op, {{memref, stackPtr}}); + return success(); + } +}; + +// Float unary intrinsics +using ExpFOpConversion = UnaryFloatOpConversion; +using SqrtFOpConversion = UnaryFloatOpConversion; +using LogFOpConversion = UnaryFloatOpConversion; +using AbsFOpConversion = UnaryFloatOpConversion; +using NegFOpConversion = UnaryFloatOpConversion; + /// Base template for binary comparison operations. /// Pops two values, compares, pushes -1 (true) or 0 (false): (a b -- flag) /// When IsFloat=true, bitcasts i64->f64 before comparing. @@ -1153,7 +1208,8 @@ struct ConvertForthToMemRefPass // Mark MemRef, Arith, LLVM, and CF dialects as legal target.addLegalDialect(); + LLVM::LLVMDialect, cf::ControlFlowDialect, + math::MathDialect>(); // Mark IntrinsicOp and BarrierOp as legal (to be lowered later) target.addLegalOp(); @@ -1205,6 +1261,9 @@ struct ConvertForthToMemRefPass ModOpConversion, // Float arithmetic AddFOpConversion, SubFOpConversion, MulFOpConversion, DivFOpConversion, + // Float math intrinsics + ExpFOpConversion, SqrtFOpConversion, LogFOpConversion, AbsFOpConversion, + NegFOpConversion, MaxFOpConversion, MinFOpConversion, // Bitwise AndOpConversion, OrOpConversion, XorOpConversion, NotOpConversion, LshiftOpConversion, RshiftOpConversion, diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 5e59eb9..d44a892 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -6,6 +6,7 @@ #include "warpforth/Conversion/Passes.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -40,10 +41,13 @@ void buildWarpForthPipeline(OpPassManager &pm) { pm.addNestedPass( createConvertGpuOpsToNVVMOps(gpuToNVVMOptions)); - // Stage 6: Lower NVVM to LLVM + // Stage 6: Lower math ops to LLVM intrinsics inside GPU module + pm.addNestedPass(createConvertMathToLLVMPass()); + + // Stage 7: Lower NVVM to LLVM pm.addPass(createConvertNVVMToLLVMPass()); - // Stage 7: Reconcile type conversions + // Stage 8: Reconcile type conversions pm.addPass(createReconcileUnrealizedCastsPass()); // Stage 8: Compile GPU module to PTX binary diff --git a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp index 49edced..409e737 100644 --- a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp +++ b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp @@ -531,6 +531,27 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack, } else if (word == "F/") { return builder.create(loc, stackType, inputStack) .getResult(); + } else if (word == "FEXP") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "FSQRT") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "FLOG") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "FABS") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "FNEG") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "FMAX") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "FMIN") { + return builder.create(loc, stackType, inputStack) + .getResult(); } else if (word == "MOD") { return builder.create(loc, stackType, inputStack).getResult(); } else if (word == "AND") { diff --git a/test/Conversion/ForthToMemRef/float-math-intrinsics.mlir b/test/Conversion/ForthToMemRef/float-math-intrinsics.mlir new file mode 100644 index 0000000..3153538 --- /dev/null +++ b/test/Conversion/ForthToMemRef/float-math-intrinsics.mlir @@ -0,0 +1,70 @@ +// RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s + +// CHECK-LABEL: func.func private @main + +// expf: load, bitcast i64->f64, math.exp, bitcast f64->i64, store (SP unchanged) +// CHECK: memref.load +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: math.exp %{{.*}} : f64 +// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: memref.store + +// sqrtf: load, bitcast, math.sqrt, bitcast, store +// CHECK: memref.load +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: math.sqrt %{{.*}} : f64 +// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: memref.store + +// logf: load, bitcast, math.log, bitcast, store +// CHECK: memref.load +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: math.log %{{.*}} : f64 +// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: memref.store + +// absf: load, bitcast, math.absf, bitcast, store +// CHECK: memref.load +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: math.absf %{{.*}} : f64 +// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: memref.store + +// negf: load, bitcast, arith.negf, bitcast, store +// CHECK: memref.load +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: arith.negf %{{.*}} : f64 +// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: memref.store + +// maxf: binary — pop two, bitcast, arith.maximumf, bitcast, store +// CHECK: memref.load +// CHECK: arith.subi +// CHECK: memref.load +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: arith.maximumf %{{.*}}, %{{.*}} : f64 +// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: memref.store + +// minf: binary — pop two, bitcast, arith.minimumf, bitcast, store +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: arith.minimumf %{{.*}}, %{{.*}} : f64 +// CHECK: arith.bitcast %{{.*}} : f64 to i64 + +module { + func.func private @main() { + %0 = forth.stack !forth.stack + %1 = forth.constant %0(1.000000e+00 : f64) : !forth.stack -> !forth.stack + %2 = forth.expf %1 : !forth.stack -> !forth.stack + %3 = forth.sqrtf %2 : !forth.stack -> !forth.stack + %4 = forth.logf %3 : !forth.stack -> !forth.stack + %5 = forth.absf %4 : !forth.stack -> !forth.stack + %6 = forth.negf %5 : !forth.stack -> !forth.stack + %7 = forth.constant %6(2.000000e+00 : f64) : !forth.stack -> !forth.stack + %8 = forth.maxf %7 : !forth.stack -> !forth.stack + %9 = forth.minf %8 : !forth.stack -> !forth.stack + return + } +} diff --git a/test/Pipeline/float-math-intrinsics.forth b/test/Pipeline/float-math-intrinsics.forth new file mode 100644 index 0000000..15d4e7f --- /dev/null +++ b/test/Pipeline/float-math-intrinsics.forth @@ -0,0 +1,12 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s + +\ Verify float math intrinsics lower through the full pipeline to gpu.binary +\ CHECK: gpu.binary @warpforth_module + +\! kernel main +\! param data f64[256] +GLOBAL-ID CELLS data + F@ +FABS FEXP FSQRT FLOG FNEG +GLOBAL-ID CELLS data + F@ +FMAX FMIN +GLOBAL-ID CELLS data + F! diff --git a/test/Translation/Forth/float-math-intrinsics.forth b/test/Translation/Forth/float-math-intrinsics.forth new file mode 100644 index 0000000..bdf1f6b --- /dev/null +++ b/test/Translation/Forth/float-math-intrinsics.forth @@ -0,0 +1,21 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s + +\ Verify float math intrinsic ops parse correctly + +\ Unary ops +\ CHECK: %[[S0:.*]] = forth.stack +\ CHECK: %[[S1:.*]] = forth.constant %[[S0]] +\ CHECK: %[[S2:.*]] = forth.expf %[[S1]] +\ CHECK: %[[S3:.*]] = forth.sqrtf %[[S2]] +\ CHECK: %[[S4:.*]] = forth.logf %[[S3]] +\ CHECK: %[[S5:.*]] = forth.absf %[[S4]] +\ CHECK: %[[S6:.*]] = forth.negf %[[S5]] + +\ Binary ops +\ CHECK: %[[S7:.*]] = forth.constant %[[S6]] +\ CHECK: %[[S8:.*]] = forth.maxf %[[S7]] +\ CHECK: %[[S9:.*]] = forth.minf %[[S8]] + +\! kernel main +1.0 FEXP FSQRT FLOG FABS FNEG +2.0 FMAX FMIN From 32eef876c555e3963c3c9aa1351328c25410fb1f Mon Sep 17 00:00:00 2001 From: Alex Cameron Date: Sat, 21 Feb 2026 21:20:12 +0900 Subject: [PATCH 2/2] fix(pipeline): correct stage numbering comment --- lib/Conversion/Passes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index d44a892..ad944eb 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -50,7 +50,7 @@ void buildWarpForthPipeline(OpPassManager &pm) { // Stage 8: Reconcile type conversions pm.addPass(createReconcileUnrealizedCastsPass()); - // Stage 8: Compile GPU module to PTX binary + // Stage 9: Compile GPU module to PTX binary GpuModuleToBinaryPassOptions binaryOptions; binaryOptions.compilationTarget = "isa"; // Output PTX assembly pm.addPass(createGpuModuleToBinaryPass(binaryOptions));