diff --git a/CLAUDE.md b/CLAUDE.md index 2472d89..73a28cb 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -49,7 +49,7 @@ Requires MLIR/LLVM with `MLIR_DIR` and `LLVM_DIR` configured in CMake. ./build/bin/warpforth-translate --mlir-to-ptx > kernel.ptx # Execute PTX on GPU -./warpforth-runner kernel.ptx --param i64[]:1,2,3 --param i64:42 --output-param 0 --output-count 3 +./warpforth-runner kernel.ptx --param i32[]:1,2,3 --param i32:42 --output-param 0 --output-count 3 ``` ## Adding New Operations @@ -87,11 +87,12 @@ uv run ruff format gpu_test/ - **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety - **Operations**: All take stack as input and produce stack as output (except `forth.stack`) -- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `FEXP FSQRT FLOG FABS FNEG` (float math intrinsics), `FMAX FMIN` (float min/max), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `{ a b -- }` (local variables in word definitions), `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). -- **Float Literals**: Numbers containing `.` or `e`/`E` are parsed as f64 (e.g. `3.14`, `-2.0`, `1.0e-5`, `1e3`). Stored on the stack as i64 bit patterns; F-prefixed words perform bitcast before/after operations. -- **Kernel Parameters**: Declared in the `\!` header. `\! kernel ` is required and must appear first. `\! param i64[]` becomes a `memref` argument; `\! param i64` becomes an `i64` argument. `\! param f64[]` becomes a `memref` argument; `\! param f64` becomes an `f64` argument (bitcast to i64 when pushed to stack). 
Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value). -- **Shared Memory**: `\! shared i64[]` or `\! shared f64[]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution. Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for i64 or `SF@`/`SF!` for f64 shared accesses. Cannot be referenced inside word definitions. -- **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer +- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `FEXP FSQRT FLOG FABS FNEG` (float math intrinsics), `FMAX FMIN` (float min/max), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global i32 memory), `F@ F!` (global f32 memory), `S@ S!` (shared i32 memory), `SF@ SF!` (shared f32 memory), `HF@ HF!` (global f16 memory), `BF@ BF!` (global bf16 memory), `I8@ I8!` (global i8 memory), `I16@ I16!` (global i16 memory), `SHF@ SHF!` (shared f16 memory), `SBF@ SBF!` (shared bf16 memory), `SI8@ SI8!` (shared i8 memory), `SI16@ SI16!` (shared i16 memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `{ a b -- }` (local variables in word definitions), `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). +- **Cell Size**: 32-bit arithmetic (i32/f32). CELLS = 4 bytes. The stack is `memref<256xi64>` because GPU pointers are 64-bit; arithmetic values are truncated to i32 at stack-load boundaries and sign-extended back to i64 at stack-store boundaries. LLVM optimization eliminates this overhead. +- **Float Literals**: Numbers containing `.` or `e`/`E` are parsed as f32 (e.g. `3.14`, `-2.0`, `1.0e-5`, `1e3`). 
Stored on the stack as i32 bit patterns (sign-extended to i64); F-prefixed words perform bitcast before/after operations. +- **Kernel Parameters**: Declared in the `\!` header. `\! kernel <name>` is required and must appear first. `\! param <name> i32[<N>]` becomes a `memref` argument; `\! param <name> i32` becomes an `i32` argument. `\! param <name> f32[<N>]` becomes a `memref` argument; `\! param <name> f32` becomes an `f32` argument (bitcast to i32, sign-extended to i64 when pushed to stack). Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value). +- **Shared Memory**: `\! shared <name> i32[<N>]` or `\! shared <name> f32[<N>]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution. Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for i32 or `SF@`/`SF!` for f32 shared accesses. Cannot be referenced inside word definitions. +- **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer. Arithmetic uses i32/f32 with trunci/extsi at stack boundaries. Narrow-type load/store words (f16, bf16, i8, i16) widen through i32/f32. - **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion - **Local Variables**: `{ a b c -- }` at the start of a word definition binds read-only locals. Pops values from the stack in reverse name order (c, b, a) using `forth.pop`, stores SSA values. Referencing a local emits `forth.push_value`. SSA values from the entry block dominate all control flow, so locals work across IF/ELSE/THEN, loops, etc. On GPU, locals map directly to registers. 
- **User-defined Words**: Modeled as `func.func` with signature `(!forth.stack) -> !forth.stack`, called via `func.call` diff --git a/gpu_test/conftest.py b/gpu_test/conftest.py index 1fc20e4..5e3a6a7 100644 --- a/gpu_test/conftest.py +++ b/gpu_test/conftest.py @@ -39,7 +39,7 @@ class ParamDecl: name: str is_array: bool size: int # 0 for scalars - base_type: str = "i64" # "i64" or "f64" + base_type: str = "i32" # "i32" or "f32" class CompileError(Exception): @@ -320,7 +320,7 @@ def _parse_array_type(type_spec: str) -> tuple[str, int]: raise ValueError(msg) base, size_str = type_spec[:-1].split("[", 1) base_lower = base.lower() - if base_lower not in ("i64", "f64"): + if base_lower not in ("i32", "f32"): msg = f"Unsupported base type: {base}" raise ValueError(msg) return base_lower, int(size_str) @@ -377,7 +377,7 @@ def _parse_param_declarations(forth_source: str) -> list[ParamDecl]: decls.append(ParamDecl(name=name, is_array=True, size=size, base_type=base_type)) else: base_type = type_spec.lower() - if base_type not in ("i64", "f64"): + if base_type not in ("i32", "f32"): msg = f"Unsupported scalar type: {type_spec}" raise ValueError(msg) decls.append(ParamDecl(name=name, is_array=False, size=0, base_type=base_type)) @@ -453,13 +453,13 @@ def run( if not isinstance(values, list): msg = f"Array param '{decl.name}' expects a list, got {type(values).__name__}" raise TypeError(msg) - zero = 0.0 if decl.base_type == "f64" else 0 + zero = 0.0 if decl.base_type == "f32" else 0 buf = [zero] * decl.size for i, v in enumerate(values): buf[i] = v cmd_parts.extend(["--param", f"{decl.base_type}[]:{','.join(str(v) for v in buf)}"]) else: - value = params.get(decl.name, 0.0 if decl.base_type == "f64" else 0) + value = params.get(decl.name, 0.0 if decl.base_type == "f32" else 0) if isinstance(value, list): msg = f"Scalar param '{decl.name}' expects a scalar, got list" raise TypeError(msg) @@ -484,7 +484,7 @@ def run( # Parse CSV output — type depends on the output param out_type 
= decls[output_param].base_type - parse = float if out_type == "f64" else int + parse = float if out_type == "f32" else int return [parse(v) for v in stdout.strip().split(",")] diff --git a/gpu_test/test_kernels.py b/gpu_test/test_kernels.py index e6e9376..4a34731 100644 --- a/gpu_test/test_kernels.py +++ b/gpu_test/test_kernels.py @@ -19,7 +19,7 @@ def test_addition(kernel_runner: KernelRunner) -> None: """3 + 4 = 7.""" result = kernel_runner.run( - forth_source="\\! kernel main\n\\! param DATA i64[256]\n3 4 +\n0 CELLS DATA + !", + forth_source="\\! kernel main\n\\! param DATA i32[256]\n3 4 +\n0 CELLS DATA + !", ) assert result[0] == 7 @@ -27,7 +27,7 @@ def test_addition(kernel_runner: KernelRunner) -> None: def test_subtraction(kernel_runner: KernelRunner) -> None: """10 - 3 = 7.""" result = kernel_runner.run( - forth_source="\\! kernel main\n\\! param DATA i64[256]\n10 3 -\n0 CELLS DATA + !", + forth_source="\\! kernel main\n\\! param DATA i32[256]\n10 3 -\n0 CELLS DATA + !", ) assert result[0] == 7 @@ -35,7 +35,7 @@ def test_subtraction(kernel_runner: KernelRunner) -> None: def test_multiplication(kernel_runner: KernelRunner) -> None: """6 * 7 = 42.""" result = kernel_runner.run( - forth_source="\\! kernel main\n\\! param DATA i64[256]\n6 7 *\n0 CELLS DATA + !", + forth_source="\\! kernel main\n\\! param DATA i32[256]\n6 7 *\n0 CELLS DATA + !", ) assert result[0] == 42 @@ -43,7 +43,7 @@ def test_multiplication(kernel_runner: KernelRunner) -> None: def test_division(kernel_runner: KernelRunner) -> None: """42 / 6 = 7.""" result = kernel_runner.run( - forth_source="\\! kernel main\n\\! param DATA i64[256]\n42 6 /\n0 CELLS DATA + !", + forth_source="\\! kernel main\n\\! param DATA i32[256]\n42 6 /\n0 CELLS DATA + !", ) assert result[0] == 7 @@ -51,7 +51,7 @@ def test_division(kernel_runner: KernelRunner) -> None: def test_modulo(kernel_runner: KernelRunner) -> None: """17 MOD 5 = 2.""" result = kernel_runner.run( - forth_source="\\! kernel main\n\\! 
param DATA i64[256]\n17 5 MOD\n0 CELLS DATA + !", + forth_source="\\! kernel main\n\\! param DATA i32[256]\n17 5 MOD\n0 CELLS DATA + !", ) assert result[0] == 2 @@ -63,7 +63,7 @@ def test_dup(kernel_runner: KernelRunner) -> None: """DUP duplicates top of stack: 5 DUP → [5, 5].""" result = kernel_runner.run( forth_source=( - "\\! kernel main\n\\! param DATA i64[256]\n5 DUP\n1 CELLS DATA + !\n0 CELLS DATA + !" + "\\! kernel main\n\\! param DATA i32[256]\n5 DUP\n1 CELLS DATA + !\n0 CELLS DATA + !" ), output_count=2, ) @@ -74,7 +74,7 @@ def test_swap(kernel_runner: KernelRunner) -> None: """SWAP exchanges top two: 1 2 SWAP → [2, 1].""" result = kernel_runner.run( forth_source=( - "\\! kernel main\n\\! param DATA i64[256]\n1 2 SWAP\n1 CELLS DATA + !\n0 CELLS DATA + !" + "\\! kernel main\n\\! param DATA i32[256]\n1 2 SWAP\n1 CELLS DATA + !\n0 CELLS DATA + !" ), output_count=2, ) @@ -86,7 +86,7 @@ def test_over(kernel_runner: KernelRunner) -> None: result = kernel_runner.run( forth_source=( "\\! kernel main\n" - "\\! param DATA i64[256]\n" + "\\! param DATA i32[256]\n" "1 2 OVER\n" "2 CELLS DATA + !\n" "1 CELLS DATA + !\n" @@ -102,7 +102,7 @@ def test_rot(kernel_runner: KernelRunner) -> None: result = kernel_runner.run( forth_source=( "\\! kernel main\n" - "\\! param DATA i64[256]\n" + "\\! param DATA i32[256]\n" "1 2 3 ROT\n" "2 CELLS DATA + !\n" "1 CELLS DATA + !\n" @@ -116,7 +116,7 @@ def test_rot(kernel_runner: KernelRunner) -> None: def test_drop(kernel_runner: KernelRunner) -> None: """DROP removes top: 1 2 DROP → [1].""" result = kernel_runner.run( - forth_source=("\\! kernel main\n\\! param DATA i64[256]\n1 2 DROP\n0 CELLS DATA + !"), + forth_source=("\\! kernel main\n\\! param DATA i32[256]\n1 2 DROP\n0 CELLS DATA + !"), ) assert result[0] == 1 @@ -128,7 +128,7 @@ def test_comparisons(kernel_runner: KernelRunner) -> None: """Test =, <, >, 0= in a single kernel. True = -1, False = 0.""" result = kernel_runner.run( forth_source=( - "\\! kernel main\n\\! 
param DATA i64[256]\n" + "\\! kernel main\n\\! param DATA i32[256]\n" "5 5 = 0 CELLS DATA + !\n" "3 5 < 1 CELLS DATA + !\n" "5 3 > 2 CELLS DATA + !\n" @@ -147,7 +147,7 @@ def test_if_else_then(kernel_runner: KernelRunner) -> None: result = kernel_runner.run( forth_source=( "\\! kernel main\n" - "\\! param DATA i64[256]\n" + "\\! param DATA i32[256]\n" "0 CELLS DATA + @\n" "0 >\n" "IF 1 ELSE 2 THEN\n" @@ -163,7 +163,7 @@ def test_begin_until(kernel_runner: KernelRunner) -> None: """BEGIN/UNTIL countdown: 10 BEGIN 1- DUP 0= UNTIL → final value is 0.""" result = kernel_runner.run( forth_source=( - "\\! kernel main\n\\! param DATA i64[256]\n10 BEGIN 1 - DUP 0= UNTIL\n0 CELLS DATA + !" + "\\! kernel main\n\\! param DATA i32[256]\n10 BEGIN 1 - DUP 0= UNTIL\n0 CELLS DATA + !" ), ) assert result[0] == 0 @@ -173,7 +173,7 @@ def test_do_loop(kernel_runner: KernelRunner) -> None: """DO/LOOP: write I values 0..4 to DATA[0..4].""" result = kernel_runner.run( forth_source=( - "\\! kernel main\n\\! param DATA i64[256]\n5 0 DO\n I I CELLS DATA + !\nLOOP" + "\\! kernel main\n\\! param DATA i32[256]\n5 0 DO\n I I CELLS DATA + !\nLOOP" ), output_count=5, ) @@ -185,7 +185,7 @@ def test_do_plus_loop(kernel_runner: KernelRunner) -> None: result = kernel_runner.run( forth_source=( "\\! kernel main\n" - "\\! param DATA i64[256]\n" + "\\! param DATA i32[256]\n" "0\n" "10 0 DO\n" " I OVER CELLS DATA + !\n" @@ -203,7 +203,7 @@ def test_do_plus_loop_negative(kernel_runner: KernelRunner) -> None: result = kernel_runner.run( forth_source=( "\\! kernel main\n" - "\\! param DATA i64[256]\n" + "\\! param DATA i32[256]\n" "0\n" "0 10 DO\n" " I OVER CELLS DATA + !\n" @@ -224,7 +224,7 @@ def test_multi_while(kernel_runner: KernelRunner) -> None: """ result = kernel_runner.run( forth_source=( - "\\! kernel main\n\\! param DATA i64[256]\n" + "\\! kernel main\n\\! param DATA i32[256]\n" "20 BEGIN DUP 10 > WHILE DUP 2 MOD 0= WHILE 1 - REPEAT THEN\n" "0 CELLS DATA + !" 
), @@ -241,7 +241,7 @@ def test_while_until(kernel_runner: KernelRunner) -> None: result = kernel_runner.run( forth_source=( "\\! kernel main\n" - "\\! param DATA i64[256]\n" + "\\! param DATA i32[256]\n" "10 BEGIN DUP 0 > WHILE 1 - DUP 5 = UNTIL THEN\n" "0 CELLS DATA + !" ), @@ -255,7 +255,7 @@ def test_while_until(kernel_runner: KernelRunner) -> None: def test_global_id(kernel_runner: KernelRunner) -> None: """4 threads each write GLOBAL-ID to DATA[GLOBAL-ID].""" result = kernel_runner.run( - forth_source=("\\! kernel main\n\\! param DATA i64[256]\nGLOBAL-ID\nDUP CELLS DATA + !"), + forth_source=("\\! kernel main\n\\! param DATA i32[256]\nGLOBAL-ID\nDUP CELLS DATA + !"), block=(4, 1, 1), output_count=4, ) @@ -266,8 +266,8 @@ def test_multi_param(kernel_runner: KernelRunner) -> None: """Two params: each thread reads INPUT[i], doubles it, writes OUTPUT[i].""" result = kernel_runner.run( forth_source=( - "\\! kernel main\n\\! param INPUT i64[4]\n" - "\\! param OUTPUT i64[4]\n" + "\\! kernel main\n\\! param INPUT i32[4]\n" + "\\! param OUTPUT i32[4]\n" "GLOBAL-ID\n" "DUP CELLS INPUT + @\n" "DUP +\n" @@ -285,9 +285,9 @@ def test_scalar_param(kernel_runner: KernelRunner) -> None: """Scalar + array params: each thread multiplies INPUT[i] by SCALE, writes OUTPUT[i].""" result = kernel_runner.run( forth_source=( - "\\! kernel main\n\\! param SCALE i64\n" - "\\! param INPUT i64[4]\n" - "\\! param OUTPUT i64[4]\n" + "\\! kernel main\n\\! param SCALE i32\n" + "\\! param INPUT i32[4]\n" + "\\! param OUTPUT i32[4]\n" "GLOBAL-ID\n" "DUP CELLS INPUT + @\n" "SCALE *\n" @@ -304,15 +304,15 @@ def test_scalar_param(kernel_runner: KernelRunner) -> None: # --- Matmul --- -def test_naive_matmul_i64(kernel_runner: KernelRunner) -> None: - """Naive i64 matmul: C = A(2x4) * B(4x3) -> C(2x3).""" +def test_naive_matmul_i32(kernel_runner: KernelRunner) -> None: + """Naive i32 matmul: C = A(2x4) * B(4x3) -> C(2x3).""" # Work partition: one thread per output element. 
# GLOBAL-ID maps to (row, col) with row = gid / N, col = gid MOD N. result = kernel_runner.run( forth_source=( - "\\! kernel main\n\\! param A i64[8]\n" - "\\! param B i64[12]\n" - "\\! param C i64[6]\n" + "\\! kernel main\n\\! param A i32[8]\n" + "\\! param B i32[12]\n" + "\\! param C i32[6]\n" "GLOBAL-ID\n" "DUP 3 /\n" "SWAP 3 MOD\n" @@ -338,8 +338,8 @@ def test_naive_matmul_i64(kernel_runner: KernelRunner) -> None: assert result == [12, 6, 9, 28, 14, 29] -def test_tiled_matmul_i64(kernel_runner: KernelRunner) -> None: - """Tiled i64 matmul with shared memory: C = A(4x4) * B(4x4) -> C(4x4). +def test_tiled_matmul_i32(kernel_runner: KernelRunner) -> None: + """Tiled i32 matmul with shared memory: C = A(4x4) * B(4x4) -> C(4x4). Uses 2x2 tiles, shared memory for A/B tiles, and BARRIER for sync. Grid: (2,2,1), Block: (2,2,1) — 4 blocks of 4 threads each. @@ -347,11 +347,11 @@ def test_tiled_matmul_i64(kernel_runner: KernelRunner) -> None: result = kernel_runner.run( forth_source=( "\\! kernel main\n" - "\\! param A i64[16]\n" - "\\! param B i64[16]\n" - "\\! param C i64[16]\n" - "\\! shared SA i64[4]\n" - "\\! shared SB i64[4]\n" + "\\! param A i32[16]\n" + "\\! param B i32[16]\n" + "\\! param C i32[16]\n" + "\\! shared SA i32[4]\n" + "\\! shared SB i32[4]\n" "BID-Y 2 * TID-Y +\n" "BID-X 2 * TID-X +\n" "0\n" @@ -400,8 +400,8 @@ def test_tiled_matmul_i64(kernel_runner: KernelRunner) -> None: assert result == expected -def test_tiled_matmul_f64(kernel_runner: KernelRunner) -> None: - """Tiled f64 matmul with shared memory: C = A(4x4) * B(4x4) -> C(4x4). +def test_tiled_matmul_f32(kernel_runner: KernelRunner) -> None: + """Tiled f32 matmul with shared memory: C = A(4x4) * B(4x4) -> C(4x4). Uses 2x2 tiles, float shared memory for A/B tiles, and BARRIER for sync. Grid: (2,2,1), Block: (2,2,1) — 4 blocks of 4 threads each. @@ -409,11 +409,11 @@ def test_tiled_matmul_f64(kernel_runner: KernelRunner) -> None: result = kernel_runner.run( forth_source=( "\\! 
kernel main\n" - "\\! param A f64[16]\n" - "\\! param B f64[16]\n" - "\\! param C f64[16]\n" - "\\! shared SA f64[4]\n" - "\\! shared SB f64[4]\n" + "\\! param A f32[16]\n" + "\\! param B f32[16]\n" + "\\! param C f32[16]\n" + "\\! shared SA f32[4]\n" + "\\! shared SB f32[4]\n" "BID-Y 2 * TID-Y +\n" "BID-X 2 * TID-X +\n" "0.0\n" @@ -469,7 +469,7 @@ def test_user_defined_word(kernel_runner: KernelRunner) -> None: """: DOUBLE DUP + ; then 5 DOUBLE → 10.""" result = kernel_runner.run( forth_source=( - "\\! kernel main\n\\! param DATA i64[256]\n: DOUBLE DUP + ;\n5 DOUBLE\n0 CELLS DATA + !" + "\\! kernel main\n\\! param DATA i32[256]\n: DOUBLE DUP + ;\n5 DOUBLE\n0 CELLS DATA + !" ), ) assert result[0] == 10 @@ -481,7 +481,7 @@ def test_user_defined_word(kernel_runner: KernelRunner) -> None: def test_float_addition(kernel_runner: KernelRunner) -> None: """F+: 1.5 + 2.5 = 4.0.""" result = kernel_runner.run( - forth_source="\\! kernel main\n\\! param DATA f64[256]\n1.5 2.5 F+\n0 CELLS DATA + F!", + forth_source="\\! kernel main\n\\! param DATA f32[256]\n1.5 2.5 F+\n0 CELLS DATA + F!", ) assert result[0] == pytest.approx(4.0) @@ -489,7 +489,7 @@ def test_float_addition(kernel_runner: KernelRunner) -> None: def test_float_subtraction(kernel_runner: KernelRunner) -> None: """F-: 10.0 - 3.5 = 6.5.""" result = kernel_runner.run( - forth_source="\\! kernel main\n\\! param DATA f64[256]\n10.0 3.5 F-\n0 CELLS DATA + F!", + forth_source="\\! kernel main\n\\! param DATA f32[256]\n10.0 3.5 F-\n0 CELLS DATA + F!", ) assert result[0] == pytest.approx(6.5) @@ -497,7 +497,7 @@ def test_float_subtraction(kernel_runner: KernelRunner) -> None: def test_float_multiplication(kernel_runner: KernelRunner) -> None: """F*: 6.0 * 7.5 = 45.0.""" result = kernel_runner.run( - forth_source="\\! kernel main\n\\! param DATA f64[256]\n6.0 7.5 F*\n0 CELLS DATA + F!", + forth_source="\\! kernel main\n\\! 
param DATA f32[256]\n6.0 7.5 F*\n0 CELLS DATA + F!", ) assert result[0] == pytest.approx(45.0) @@ -505,7 +505,7 @@ def test_float_multiplication(kernel_runner: KernelRunner) -> None: def test_float_division(kernel_runner: KernelRunner) -> None: """F/: 42.0 / 6.0 = 7.0.""" result = kernel_runner.run( - forth_source="\\! kernel main\n\\! param DATA f64[256]\n42.0 6.0 F/\n0 CELLS DATA + F!", + forth_source="\\! kernel main\n\\! param DATA f32[256]\n42.0 6.0 F/\n0 CELLS DATA + F!", ) assert result[0] == pytest.approx(7.0) @@ -517,7 +517,7 @@ def test_float_load_store(kernel_runner: KernelRunner) -> None: """F@ and F!: read from DATA[0], multiply by 2, write to DATA[1].""" result = kernel_runner.run( forth_source=( - "\\! kernel main\n\\! param DATA f64[256]\n0 CELLS DATA + F@\n2.0 F*\n1 CELLS DATA + F!" + "\\! kernel main\n\\! param DATA f32[256]\n0 CELLS DATA + F@\n2.0 F*\n1 CELLS DATA + F!" ), params={"DATA": [3.14]}, output_count=2, @@ -529,10 +529,10 @@ def test_float_load_store(kernel_runner: KernelRunner) -> None: def test_float_scalar_param(kernel_runner: KernelRunner) -> None: - """Scalar f64 param: each thread scales DATA[i] by SCALE.""" + """Scalar f32 param: each thread scales DATA[i] by SCALE.""" result = kernel_runner.run( forth_source=( - "\\! kernel main\n\\! param DATA f64[256]\n\\! param SCALE f64\n" + "\\! kernel main\n\\! param DATA f32[256]\n\\! param SCALE f32\n" "GLOBAL-ID\n" "DUP CELLS DATA + F@\n" "SCALE F*\n" @@ -554,10 +554,10 @@ def test_float_scalar_param(kernel_runner: KernelRunner) -> None: def test_float_comparisons(kernel_runner: KernelRunner) -> None: - """F=, F<, F>: True = -1, False = 0 (pushed as i64 on the stack).""" + """F=, F<, F>: True = -1, False = 0 (pushed as i32 on the stack).""" result = kernel_runner.run( forth_source=( - "\\! kernel main\n\\! param DATA i64[256]\n" + "\\! kernel main\n\\! param DATA i32[256]\n" "3.14 3.14 F= 0 CELLS DATA + !\n" "1.0 2.0 F< 1 CELLS DATA + !\n" "5.0 3.0 F> 2 CELLS DATA + !" 
@@ -571,17 +571,17 @@ def test_float_comparisons(kernel_runner: KernelRunner) -> None: def test_int_to_float_conversion(kernel_runner: KernelRunner) -> None: - """S>F: convert int 7 to float, multiply by 1.5, store as f64.""" + """S>F: convert int 7 to float, multiply by 1.5, store as f32.""" result = kernel_runner.run( - forth_source=("\\! kernel main\n\\! param DATA f64[256]\n7 S>F 1.5 F*\n0 CELLS DATA + F!"), + forth_source=("\\! kernel main\n\\! param DATA f32[256]\n7 S>F 1.5 F*\n0 CELLS DATA + F!"), ) assert result[0] == pytest.approx(10.5) def test_float_to_int_conversion(kernel_runner: KernelRunner) -> None: - """F>S: convert float 7.9 to int (truncates to 7), store as i64.""" + """F>S: convert float 7.9 to int (truncates to 7), store as i32.""" result = kernel_runner.run( - forth_source=("\\! kernel main\n\\! param DATA i64[256]\n7.9 F>S\n0 CELLS DATA + !"), + forth_source=("\\! kernel main\n\\! param DATA i32[256]\n7.9 F>S\n0 CELLS DATA + !"), ) assert result[0] == 7 @@ -590,14 +590,14 @@ def test_float_to_int_conversion(kernel_runner: KernelRunner) -> None: _ATTENTION_KERNEL = """\ \\! kernel attention -\\! param Q f64[{n}] -\\! param K f64[{n}] -\\! param V f64[{n}] -\\! param O f64[{n}] -\\! param SEQ_LEN i64 -\\! param HEAD_DIM i64 -\\! shared SCORES f64[{seq_len}] -\\! shared SCRATCH f64[{seq_len}] +\\! param Q f32[{n}] +\\! param K f32[{n}] +\\! param V f32[{n}] +\\! param O f32[{n}] +\\! param SEQ_LEN i32 +\\! param HEAD_DIM i32 +\\! shared SCORES f32[{seq_len}] +\\! 
shared SCRATCH f32[{seq_len}] BID-X TID-X 0.0 @@ -648,17 +648,20 @@ def test_float_to_int_conversion(kernel_runner: KernelRunner) -> None: def _attention_reference(q: np.ndarray, k: np.ndarray, v: np.ndarray, seq_len: int) -> list[float]: - """Compute scaled dot-product attention with causal mask (NumPy reference).""" + """Compute scaled dot-product attention with causal mask (NumPy reference, f32).""" + q = q.astype(np.float32) + k = k.astype(np.float32) + v = v.astype(np.float32) head_dim = q.shape[1] - scores = q @ k.T / np.sqrt(head_dim) + scores = q @ k.T / np.sqrt(np.float32(head_dim)) causal_mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1) - scores[causal_mask] = -1e30 + scores[causal_mask] = np.float32(-1e30) exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True)) attn = exp_scores / exp_scores.sum(axis=1, keepdims=True) return (attn @ v).flatten().tolist() -def test_naive_attention_f64(kernel_runner: KernelRunner) -> None: +def test_naive_attention_f32(kernel_runner: KernelRunner) -> None: """Naive scaled dot-product attention with causal mask. O = softmax(Q @ K^T / sqrt(d_k)) @ V, seq_len=4, head_dim=4. 
@@ -708,10 +711,10 @@ def test_naive_attention_f64(kernel_runner: KernelRunner) -> None: output_param=3, output_count=n, ) - assert result == [pytest.approx(v) for v in expected] + assert result == [pytest.approx(v, rel=1e-4) for v in expected] -def test_naive_attention_f64_16x64(kernel_runner: KernelRunner) -> None: +def test_naive_attention_f32_16x64(kernel_runner: KernelRunner) -> None: """Naive scaled dot-product attention, seq_len=16, head_dim=64.""" seq_len, head_dim = 16, 64 @@ -737,4 +740,4 @@ def test_naive_attention_f64_16x64(kernel_runner: KernelRunner) -> None: output_param=3, output_count=n, ) - assert result == [pytest.approx(v) for v in expected] + assert result == [pytest.approx(v, rel=1e-3) for v in expected] diff --git a/include/warpforth/Dialect/Forth/ForthOps.td b/include/warpforth/Dialect/Forth/ForthOps.td index e1b0fa1..42fea1c 100644 --- a/include/warpforth/Dialect/Forth/ForthOps.td +++ b/include/warpforth/Dialect/Forth/ForthOps.td @@ -124,7 +124,7 @@ def Forth_ConstantOp : Forth_Op<"constant", [Pure]> { Forth semantics: ( -- n ) }]; - let arguments = (ins Forth_StackType:$input_stack, AnyAttrOf<[I64Attr, F64Attr]>:$value); + let arguments = (ins Forth_StackType:$input_stack, AnyAttrOf<[I32Attr, F32Attr]>:$value); let results = (outs Forth_StackType:$output_stack); let assemblyFormat = [{ @@ -183,7 +183,7 @@ def Forth_ModOp : Forth_StackOpBase<"mod"> { def Forth_AddFOp : Forth_StackOpBase<"addf"> { let summary = "Add top two stack elements (float)"; let description = [{ - Pops two i64 values, bitcasts to f64, adds, bitcasts result back to i64. + Pops two i64 values, truncates to i32, bitcasts to f32, adds, bitcasts result back to i32, extends to i64. 
Forth semantics: ( f1 f2 -- f1+f2 ) }]; } @@ -191,7 +191,7 @@ def Forth_AddFOp : Forth_StackOpBase<"addf"> { def Forth_SubFOp : Forth_StackOpBase<"subf"> { let summary = "Subtract top two stack elements (float)"; let description = [{ - Pops two i64 values, bitcasts to f64, subtracts, bitcasts result back to i64. + Pops two i64 values, truncates to i32, bitcasts to f32, subtracts, bitcasts result back to i32, extends to i64. Forth semantics: ( f1 f2 -- f1-f2 ) }]; } @@ -199,7 +199,7 @@ def Forth_SubFOp : Forth_StackOpBase<"subf"> { def Forth_MulFOp : Forth_StackOpBase<"mulf"> { let summary = "Multiply top two stack elements (float)"; let description = [{ - Pops two i64 values, bitcasts to f64, multiplies, bitcasts result back to i64. + Pops two i64 values, truncates to i32, bitcasts to f32, multiplies, bitcasts result back to i32, extends to i64. Forth semantics: ( f1 f2 -- f1*f2 ) }]; } @@ -207,7 +207,7 @@ def Forth_MulFOp : Forth_StackOpBase<"mulf"> { def Forth_DivFOp : Forth_StackOpBase<"divf"> { let summary = "Divide top two stack elements (float)"; let description = [{ - Pops two i64 values, bitcasts to f64, divides, bitcasts result back to i64. + Pops two i64 values, truncates to i32, bitcasts to f32, divides, bitcasts result back to i32, extends to i64. Forth semantics: ( f1 f2 -- f1/f2 ) }]; } @@ -219,8 +219,8 @@ def Forth_DivFOp : Forth_StackOpBase<"divf"> { def Forth_ExpFOp : Forth_StackOpBase<"expf"> { let summary = "Exponential of top stack element (float)"; let description = [{ - Pops an i64 (f64 bit pattern), bitcasts to f64, computes e^x, - bitcasts result back to i64. + Pops an i64, truncates to i32 (f32 bit pattern), bitcasts to f32, computes e^x, + bitcasts result back to i32, extends to i64. 
Forth semantics: ( f -- exp(f) ) }]; } @@ -228,8 +228,8 @@ def Forth_ExpFOp : Forth_StackOpBase<"expf"> { def Forth_SqrtFOp : Forth_StackOpBase<"sqrtf"> { let summary = "Square root of top stack element (float)"; let description = [{ - Pops an i64 (f64 bit pattern), bitcasts to f64, computes sqrt(x), - bitcasts result back to i64. + Pops an i64, truncates to i32 (f32 bit pattern), bitcasts to f32, computes sqrt(x), + bitcasts result back to i32, extends to i64. Forth semantics: ( f -- sqrt(f) ) }]; } @@ -237,8 +237,8 @@ def Forth_SqrtFOp : Forth_StackOpBase<"sqrtf"> { def Forth_LogFOp : Forth_StackOpBase<"logf"> { let summary = "Natural logarithm of top stack element (float)"; let description = [{ - Pops an i64 (f64 bit pattern), bitcasts to f64, computes ln(x), - bitcasts result back to i64. + Pops an i64, truncates to i32 (f32 bit pattern), bitcasts to f32, computes ln(x), + bitcasts result back to i32, extends to i64. Forth semantics: ( f -- log(f) ) }]; } @@ -246,8 +246,8 @@ def Forth_LogFOp : Forth_StackOpBase<"logf"> { def Forth_AbsFOp : Forth_StackOpBase<"absf"> { let summary = "Absolute value of top stack element (float)"; let description = [{ - Pops an i64 (f64 bit pattern), bitcasts to f64, computes |x|, - bitcasts result back to i64. + Pops an i64, truncates to i32 (f32 bit pattern), bitcasts to f32, computes |x|, + bitcasts result back to i32, extends to i64. Forth semantics: ( f -- |f| ) }]; } @@ -255,8 +255,8 @@ def Forth_AbsFOp : Forth_StackOpBase<"absf"> { def Forth_NegFOp : Forth_StackOpBase<"negf"> { let summary = "Negate top stack element (float)"; let description = [{ - Pops an i64 (f64 bit pattern), bitcasts to f64, negates, - bitcasts result back to i64. + Pops an i64, truncates to i32 (f32 bit pattern), bitcasts to f32, negates, + bitcasts result back to i32, extends to i64. 
Forth semantics: ( f -- -f ) }]; } @@ -264,7 +264,7 @@ def Forth_NegFOp : Forth_StackOpBase<"negf"> { def Forth_MaxFOp : Forth_StackOpBase<"maxf"> { let summary = "Maximum of top two stack elements (float)"; let description = [{ - Pops two i64 values, bitcasts to f64, computes max, bitcasts result back to i64. + Pops two i64 values, truncates to i32, bitcasts to f32, computes max, bitcasts result back to i32, extends to i64. Forth semantics: ( f1 f2 -- max(f1,f2) ) }]; } @@ -272,7 +272,7 @@ def Forth_MaxFOp : Forth_StackOpBase<"maxf"> { def Forth_MinFOp : Forth_StackOpBase<"minf"> { let summary = "Minimum of top two stack elements (float)"; let description = [{ - Pops two i64 values, bitcasts to f64, computes min, bitcasts result back to i64. + Pops two i64 values, truncates to i32, bitcasts to f32, computes min, bitcasts result back to i32, extends to i64. Forth semantics: ( f1 f2 -- min(f1,f2) ) }]; } @@ -334,72 +334,220 @@ def Forth_RshiftOp : Forth_StackOpBase<"rshift"> { //===----------------------------------------------------------------------===// def Forth_LoadIOp : Forth_StackOpBase<"loadi"> { - let summary = "Load i64 value from memory buffer"; + let summary = "Load i32 value from memory buffer"; let description = [{ - Pops an address from the stack, loads an i64 value from memory, - and pushes the loaded value onto the stack. + Pops an address from the stack, loads an i32 value from memory, + sign-extends to i64, and pushes the loaded value onto the stack. Forth semantics: ( addr -- value ) }]; } def Forth_StoreIOp : Forth_StackOpBase<"storei"> { - let summary = "Store i64 value to memory buffer"; + let summary = "Store i32 value to memory buffer"; let description = [{ - Pops an address and value from the stack, stores the i64 value to memory. + Pops an address and value from the stack, truncates to i32, stores to memory. 
Forth semantics: ( x addr -- ) }]; } def Forth_LoadFOp : Forth_StackOpBase<"loadf"> { - let summary = "Load f64 value from memory buffer"; + let summary = "Load f32 value from memory buffer"; let description = [{ - Pops an address from the stack, loads an f64 value from memory, - bitcasts to i64, and pushes onto the stack. + Pops an address from the stack, loads an f32 value from memory, + bitcasts to i32, sign-extends to i64, and pushes onto the stack. Forth semantics: ( addr -- value ) }]; } def Forth_StoreFOp : Forth_StackOpBase<"storef"> { - let summary = "Store f64 value to memory buffer"; + let summary = "Store f32 value to memory buffer"; let description = [{ - Pops an address and value (i64 bit pattern of f64) from the stack, - bitcasts to f64, stores to memory. + Pops an address and value from the stack, truncates to i32, + bitcasts to f32, stores to memory. Forth semantics: ( x addr -- ) }]; } def Forth_SharedLoadIOp : Forth_StackOpBase<"shared_loadi"> { - let summary = "Load i64 value from shared memory buffer"; + let summary = "Load i32 value from shared memory buffer"; let description = [{ - Pops an address from the stack, loads an i64 value from shared/workgroup memory, - and pushes the loaded value onto the stack. + Pops an address from the stack, loads an i32 value from shared/workgroup memory, + sign-extends to i64, and pushes the loaded value onto the stack. Forth semantics: ( addr -- value ) }]; } def Forth_SharedStoreIOp : Forth_StackOpBase<"shared_storei"> { - let summary = "Store i64 value to shared memory buffer"; + let summary = "Store i32 value to shared memory buffer"; let description = [{ - Pops an address and value from the stack, stores the i64 value to + Pops an address and value from the stack, truncates to i32, stores to shared/workgroup memory. 
Forth semantics: ( x addr -- ) }]; } def Forth_SharedLoadFOp : Forth_StackOpBase<"shared_loadf"> { - let summary = "Load f64 value from shared memory buffer"; + let summary = "Load f32 value from shared memory buffer"; let description = [{ - Pops an address from the stack, loads an f64 value from shared/workgroup memory, - bitcasts to i64, and pushes onto the stack. + Pops an address from the stack, loads an f32 value from shared/workgroup memory, + bitcasts to i32, sign-extends to i64, and pushes onto the stack. Forth semantics: ( addr -- value ) }]; } def Forth_SharedStoreFOp : Forth_StackOpBase<"shared_storef"> { - let summary = "Store f64 value to shared memory buffer"; + let summary = "Store f32 value to shared memory buffer"; let description = [{ - Pops an address and value (i64 bit pattern of f64) from the stack, - bitcasts to f64, stores to shared/workgroup memory. + Pops an address and value from the stack, truncates to i32, + bitcasts to f32, stores to shared/workgroup memory. + Forth semantics: ( x addr -- ) + }]; +} + +//===----------------------------------------------------------------------===// +// Reduced-precision memory operations. +//===----------------------------------------------------------------------===// + +def Forth_LoadHFOp : Forth_StackOpBase<"loadhf"> { + let summary = "Load f16 value from memory buffer"; + let description = [{ + Pops an address from the stack, loads an f16 value from memory, + extends to f32, bitcasts to i32, sign-extends to i64, and pushes. + Forth semantics: ( addr -- value ) + }]; +} + +def Forth_StoreHFOp : Forth_StackOpBase<"storehf"> { + let summary = "Store f16 value to memory buffer"; + let description = [{ + Pops an address and value from the stack, truncates to i32, + bitcasts to f32, truncates to f16, stores to memory. 
+ Forth semantics: ( x addr -- ) + }]; +} + +def Forth_LoadBFOp : Forth_StackOpBase<"loadbf"> { + let summary = "Load bf16 value from memory buffer"; + let description = [{ + Pops an address from the stack, loads a bf16 value from memory, + extends to f32, bitcasts to i32, sign-extends to i64, and pushes. + Forth semantics: ( addr -- value ) + }]; +} + +def Forth_StoreBFOp : Forth_StackOpBase<"storebf"> { + let summary = "Store bf16 value to memory buffer"; + let description = [{ + Pops an address and value from the stack, truncates to i32, + bitcasts to f32, truncates to bf16, stores to memory. + Forth semantics: ( x addr -- ) + }]; +} + +def Forth_LoadI8Op : Forth_StackOpBase<"loadi8"> { + let summary = "Load i8 value from memory buffer"; + let description = [{ + Pops an address from the stack, loads an i8 value from memory, + sign-extends to i32, sign-extends to i64, and pushes. + Forth semantics: ( addr -- value ) + }]; +} + +def Forth_StoreI8Op : Forth_StackOpBase<"storei8"> { + let summary = "Store i8 value to memory buffer"; + let description = [{ + Pops an address and value from the stack, truncates to i32, + truncates to i8, stores to memory. + Forth semantics: ( x addr -- ) + }]; +} + +def Forth_LoadI16Op : Forth_StackOpBase<"loadi16"> { + let summary = "Load i16 value from memory buffer"; + let description = [{ + Pops an address from the stack, loads an i16 value from memory, + sign-extends to i32, sign-extends to i64, and pushes. + Forth semantics: ( addr -- value ) + }]; +} + +def Forth_StoreI16Op : Forth_StackOpBase<"storei16"> { + let summary = "Store i16 value to memory buffer"; + let description = [{ + Pops an address and value from the stack, truncates to i32, + truncates to i16, stores to memory. 
+ Forth semantics: ( x addr -- ) + }]; +} + +def Forth_SharedLoadHFOp : Forth_StackOpBase<"shared_loadhf"> { + let summary = "Load f16 value from shared memory buffer"; + let description = [{ + Pops an address from the stack, loads an f16 value from shared memory, + extends to f32, bitcasts to i32, sign-extends to i64, and pushes. + Forth semantics: ( addr -- value ) + }]; +} + +def Forth_SharedStoreHFOp : Forth_StackOpBase<"shared_storehf"> { + let summary = "Store f16 value to shared memory buffer"; + let description = [{ + Pops an address and value from the stack, truncates to i32, + bitcasts to f32, truncates to f16, stores to shared memory. + Forth semantics: ( x addr -- ) + }]; +} + +def Forth_SharedLoadBFOp : Forth_StackOpBase<"shared_loadbf"> { + let summary = "Load bf16 value from shared memory buffer"; + let description = [{ + Pops an address from the stack, loads a bf16 value from shared memory, + extends to f32, bitcasts to i32, sign-extends to i64, and pushes. + Forth semantics: ( addr -- value ) + }]; +} + +def Forth_SharedStoreBFOp : Forth_StackOpBase<"shared_storebf"> { + let summary = "Store bf16 value to shared memory buffer"; + let description = [{ + Pops an address and value from the stack, truncates to i32, + bitcasts to f32, truncates to bf16, stores to shared memory. + Forth semantics: ( x addr -- ) + }]; +} + +def Forth_SharedLoadI8Op : Forth_StackOpBase<"shared_loadi8"> { + let summary = "Load i8 value from shared memory buffer"; + let description = [{ + Pops an address from the stack, loads an i8 value from shared memory, + sign-extends to i32, sign-extends to i64, and pushes. + Forth semantics: ( addr -- value ) + }]; +} + +def Forth_SharedStoreI8Op : Forth_StackOpBase<"shared_storei8"> { + let summary = "Store i8 value to shared memory buffer"; + let description = [{ + Pops an address and value from the stack, truncates to i32, + truncates to i8, stores to shared memory. 
+ Forth semantics: ( x addr -- ) + }]; +} + +def Forth_SharedLoadI16Op : Forth_StackOpBase<"shared_loadi16"> { + let summary = "Load i16 value from shared memory buffer"; + let description = [{ + Pops an address from the stack, loads an i16 value from shared memory, + sign-extends to i32, sign-extends to i64, and pushes. + Forth semantics: ( addr -- value ) + }]; +} + +def Forth_SharedStoreI16Op : Forth_StackOpBase<"shared_storei16"> { + let summary = "Store i16 value to shared memory buffer"; + let description = [{ + Pops an address and value from the stack, truncates to i32, + truncates to i16, stores to shared memory. Forth semantics: ( x addr -- ) }]; } @@ -611,7 +759,7 @@ def Forth_GeIOp : Forth_StackOpBase<"gei"> { def Forth_EqFOp : Forth_StackOpBase<"eqf"> { let summary = "Test equality of top two stack elements (float)"; let description = [{ - Pops two i64 values, bitcasts to f64, compares for equality (ordered). + Pops two i64 values, truncates to i32, bitcasts to f32, compares for equality (ordered). Pushes -1 (true) or 0 (false). Forth semantics: ( f1 f2 -- flag ) }]; @@ -620,7 +768,7 @@ def Forth_EqFOp : Forth_StackOpBase<"eqf"> { def Forth_LtFOp : Forth_StackOpBase<"ltf"> { let summary = "Test less-than of top two stack elements (float)"; let description = [{ - Pops two i64 values, bitcasts to f64, compares f1 < f2 (ordered). + Pops two i64 values, truncates to i32, bitcasts to f32, compares f1 < f2 (ordered). Pushes -1 (true) or 0 (false). Forth semantics: ( f1 f2 -- flag ) }]; @@ -629,7 +777,7 @@ def Forth_LtFOp : Forth_StackOpBase<"ltf"> { def Forth_GtFOp : Forth_StackOpBase<"gtf"> { let summary = "Test greater-than of top two stack elements (float)"; let description = [{ - Pops two i64 values, bitcasts to f64, compares f1 > f2 (ordered). + Pops two i64 values, truncates to i32, bitcasts to f32, compares f1 > f2 (ordered). Pushes -1 (true) or 0 (false). 
Forth semantics: ( f1 f2 -- flag ) }]; @@ -638,7 +786,7 @@ def Forth_GtFOp : Forth_StackOpBase<"gtf"> { def Forth_NeFOp : Forth_StackOpBase<"nef"> { let summary = "Test inequality of top two stack elements (float)"; let description = [{ - Pops two i64 values, bitcasts to f64, compares for inequality (ordered). + Pops two i64 values, truncates to i32, bitcasts to f32, compares for inequality (ordered). Pushes -1 (true) or 0 (false). Forth semantics: ( f1 f2 -- flag ) }]; @@ -647,7 +795,7 @@ def Forth_NeFOp : Forth_StackOpBase<"nef"> { def Forth_LeFOp : Forth_StackOpBase<"lef"> { let summary = "Test less-than-or-equal of top two stack elements (float)"; let description = [{ - Pops two i64 values, bitcasts to f64, compares f1 <= f2 (ordered). + Pops two i64 values, truncates to i32, bitcasts to f32, compares f1 <= f2 (ordered). Pushes -1 (true) or 0 (false). Forth semantics: ( f1 f2 -- flag ) }]; @@ -656,7 +804,7 @@ def Forth_LeFOp : Forth_StackOpBase<"lef"> { def Forth_GeFOp : Forth_StackOpBase<"gef"> { let summary = "Test greater-than-or-equal of top two stack elements (float)"; let description = [{ - Pops two i64 values, bitcasts to f64, compares f1 >= f2 (ordered). + Pops two i64 values, truncates to i32, bitcasts to f32, compares f1 >= f2 (ordered). Pushes -1 (true) or 0 (false). Forth semantics: ( f1 f2 -- flag ) }]; @@ -677,8 +825,8 @@ def Forth_ZeroEqOp : Forth_StackOpBase<"zero_eq"> { def Forth_IToFOp : Forth_StackOpBase<"itof"> { let summary = "Convert integer to float"; let description = [{ - Pops an i64 integer, converts to f64 via sitofp, bitcasts result - to i64 bit pattern, pushes onto stack. + Pops an i64, truncates to i32, converts to f32 via sitofp, bitcasts + f32 to i32, sign-extends to i64, pushes onto stack. 
Forth semantics: ( n -- f ) }]; } @@ -686,8 +834,8 @@ def Forth_IToFOp : Forth_StackOpBase<"itof"> { def Forth_FToIOp : Forth_StackOpBase<"ftoi"> { let summary = "Convert float to integer"; let description = [{ - Pops an i64 (f64 bit pattern), bitcasts to f64, converts to i64 - via fptosi, pushes onto stack. + Pops an i64, truncates to i32 (f32 bit pattern), bitcasts to f32, + converts to i32 via fptosi, sign-extends to i64, pushes onto stack. Forth semantics: ( f -- n ) }]; } diff --git a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp index 550b77e..781d31f 100644 --- a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp +++ b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp @@ -107,15 +107,19 @@ struct ConstantOpConversion : public OpConversionPattern { Value valueToPush; auto typedValue = cast(op.getValueAttr()); if (isa(typedValue)) { - // Float: create f64 constant, bitcast to i64 - Value f64Value = rewriter.create( - loc, rewriter.getF64Type(), typedValue); - valueToPush = rewriter.create( - loc, rewriter.getI64Type(), f64Value); + // Float: create f32 constant, bitcast to i32, sext to i64 + Value f32Value = rewriter.create( + loc, rewriter.getF32Type(), typedValue); + Value i32Value = rewriter.create( + loc, rewriter.getI32Type(), f32Value); + valueToPush = + rewriter.create(loc, rewriter.getI64Type(), i32Value); } else { - // Integer: create i64 constant directly - valueToPush = rewriter.create( - loc, rewriter.getI64Type(), typedValue); + // Integer: create i32 constant, sext to i64 + Value i32Value = rewriter.create( + loc, rewriter.getI32Type(), typedValue); + valueToPush = + rewriter.create(loc, rewriter.getI64Type(), i32Value); } Value newSP = pushValue(loc, rewriter, memref, stackPtr, valueToPush); @@ -444,7 +448,9 @@ struct RollOpConversion : public OpConversionPattern { /// Base template for binary arithmetic operations. 
/// Pops two values, applies operation, pushes result: (a b -- result) -/// When IsFloat=true, bitcasts i64->f64 before the op and f64->i64 after. +/// Arithmetic is performed in i32/f32; values are truncated from and +/// sign-extended back to i64 at stack boundaries. +/// When IsFloat=true, bitcasts i32->f32 before the op and f32->i32 after. template struct BinaryArithOpConversion : public OpConversionPattern { BinaryArithOpConversion(const TypeConverter &typeConverter, @@ -463,25 +469,33 @@ struct BinaryArithOpConversion : public OpConversionPattern { Value one = rewriter.create(loc, 1); - // Load top two values (b at SP, a at SP-1) - Value b = rewriter.create(loc, memref, stackPtr); + // Load top two values (b at SP, a at SP-1) as i64 + Value bI64 = rewriter.create(loc, memref, stackPtr); Value spMinus1 = rewriter.create(loc, stackPtr, one); - Value a = rewriter.create(loc, memref, spMinus1); + Value aI64 = rewriter.create(loc, memref, spMinus1); + + // Truncate to i32 + auto i32Type = rewriter.getI32Type(); + Value aI32 = rewriter.create(loc, i32Type, aI64); + Value bI32 = rewriter.create(loc, i32Type, bI64); - Value result; + Value resultI32; if constexpr (IsFloat) { - // Bitcast i64 -> f64 - auto f64Type = rewriter.getF64Type(); - Value aF = rewriter.create(loc, f64Type, a); - Value bF = rewriter.create(loc, f64Type, b); + // Bitcast i32 -> f32 + auto f32Type = rewriter.getF32Type(); + Value aF = rewriter.create(loc, f32Type, aI32); + Value bF = rewriter.create(loc, f32Type, bI32); Value resF = rewriter.create(loc, aF, bF); - // Bitcast f64 -> i64 - result = - rewriter.create(loc, rewriter.getI64Type(), resF); + // Bitcast f32 -> i32 + resultI32 = rewriter.create(loc, i32Type, resF); } else { - result = rewriter.create(loc, a, b); + resultI32 = rewriter.create(loc, aI32, bI32); } + // Sign-extend i32 -> i64 for stack storage + Value result = + rewriter.create(loc, rewriter.getI64Type(), resultI32); + // Store result at SP-1 (effectively popping both and 
pushing result) rewriter.create(loc, result, memref, spMinus1); @@ -523,7 +537,8 @@ using MinFOpConversion = /// Base template for unary float operations. /// Pops one value, applies operation, pushes result: (f -- result) -/// Bitcasts i64->f64 before the op and f64->i64 after. +/// Truncates i64->i32, bitcasts i32->f32 before the op, then +/// bitcasts f32->i32, sign-extends i32->i64 after. template struct UnaryFloatOpConversion : public OpConversionPattern { UnaryFloatOpConversion(const TypeConverter &typeConverter, @@ -540,19 +555,22 @@ struct UnaryFloatOpConversion : public OpConversionPattern { Value memref = inputStack[0]; Value stackPtr = inputStack[1]; - // Load value from top of stack - Value a = rewriter.create(loc, memref, stackPtr); + // Load i64 value from top of stack, truncate to i32 + Value aI64 = rewriter.create(loc, memref, stackPtr); + auto i32Type = rewriter.getI32Type(); + Value aI32 = rewriter.create(loc, i32Type, aI64); - // Bitcast i64 -> f64 - auto f64Type = rewriter.getF64Type(); - Value aF = rewriter.create(loc, f64Type, a); + // Bitcast i32 -> f32 + auto f32Type = rewriter.getF32Type(); + Value aF = rewriter.create(loc, f32Type, aI32); // Apply math/arith op Value resF = rewriter.create(loc, aF); - // Bitcast f64 -> i64 + // Bitcast f32 -> i32, sign-extend to i64 + Value resI32 = rewriter.create(loc, i32Type, resF); Value result = - rewriter.create(loc, rewriter.getI64Type(), resF); + rewriter.create(loc, rewriter.getI64Type(), resI32); // Store result at same position (SP unchanged — unary op) rewriter.create(loc, result, memref, stackPtr); @@ -570,8 +588,9 @@ using AbsFOpConversion = UnaryFloatOpConversion; using NegFOpConversion = UnaryFloatOpConversion; /// Base template for binary comparison operations. -/// Pops two values, compares, pushes -1 (true) or 0 (false): (a b -- flag) -/// When IsFloat=true, bitcasts i64->f64 before comparing. 
+/// Pops two values, compares in i32/f32, pushes -1 (true) or 0 (false): (a b -- +/// flag) When IsFloat=true, truncates to i32, bitcasts i32->f32 before +/// comparing. template struct BinaryCmpOpConversion : public OpConversionPattern { @@ -591,19 +610,24 @@ struct BinaryCmpOpConversion : public OpConversionPattern { Value one = rewriter.create(loc, 1); - // Load top two values (b at SP, a at SP-1) - Value b = rewriter.create(loc, memref, stackPtr); + // Load top two values (b at SP, a at SP-1) as i64 + Value bI64 = rewriter.create(loc, memref, stackPtr); Value spMinus1 = rewriter.create(loc, stackPtr, one); - Value a = rewriter.create(loc, memref, spMinus1); + Value aI64 = rewriter.create(loc, memref, spMinus1); + + // Truncate to i32 + auto i32Type = rewriter.getI32Type(); + Value aI32 = rewriter.create(loc, i32Type, aI64); + Value bI32 = rewriter.create(loc, i32Type, bI64); Value cmp; if constexpr (IsFloat) { - auto f64Type = rewriter.getF64Type(); - Value aF = rewriter.create(loc, f64Type, a); - Value bF = rewriter.create(loc, f64Type, b); + auto f32Type = rewriter.getF32Type(); + Value aF = rewriter.create(loc, f32Type, aI32); + Value bF = rewriter.create(loc, f32Type, bI32); cmp = rewriter.create(loc, Predicate, aF, bF); } else { - cmp = rewriter.create(loc, Predicate, a, b); + cmp = rewriter.create(loc, Predicate, aI32, bI32); } // Extend i1 to i64: true = -1 (all bits set), false = 0 @@ -662,13 +686,19 @@ struct NotOpConversion : public OpConversionPattern { Value memref = inputStack[0]; Value stackPtr = inputStack[1]; - // Load top value - Value a = rewriter.create(loc, memref, stackPtr); + // Load top value (i64), truncate to i32 + Value aI64 = rewriter.create(loc, memref, stackPtr); + auto i32Type = rewriter.getI32Type(); + Value aI32 = rewriter.create(loc, i32Type, aI64); - // XOR with -1 (all bits set) to flip all bits + // XOR with -1 (all bits set) to flip all bits (i32) Value allOnes = rewriter.create( - loc, rewriter.getI64Type(), 
rewriter.getI64IntegerAttr(-1)); - Value result = rewriter.create(loc, a, allOnes); + loc, i32Type, rewriter.getI32IntegerAttr(-1)); + Value resultI32 = rewriter.create(loc, aI32, allOnes); + + // Sign-extend back to i64 + Value result = + rewriter.create(loc, rewriter.getI64Type(), resultI32); // Store result at same position (SP unchanged) rewriter.create(loc, result, memref, stackPtr); @@ -693,14 +723,16 @@ struct ZeroEqOpConversion : public OpConversionPattern { Value memref = inputStack[0]; Value stackPtr = inputStack[1]; - // Load top value - Value a = rewriter.create(loc, memref, stackPtr); + // Load top value (i64), truncate to i32 + Value aI64 = rewriter.create(loc, memref, stackPtr); + auto i32Type = rewriter.getI32Type(); + Value aI32 = rewriter.create(loc, i32Type, aI64); - // Compare with zero + // Compare with zero (i32) Value zero = rewriter.create( - loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(0)); - Value cmp = - rewriter.create(loc, arith::CmpIPredicate::eq, a, zero); + loc, i32Type, rewriter.getI32IntegerAttr(0)); + Value cmp = rewriter.create(loc, arith::CmpIPredicate::eq, + aI32, zero); // Extend i1 to i64: true = -1, false = 0 Value result = @@ -750,18 +782,21 @@ struct ParamRefOpConversion : public OpConversionPattern { Value valueToPush; if (auto memrefType = dyn_cast(memrefArg.getType())) { - // Extract pointer as index, then cast to i64 + // Extract pointer as index, then cast to i64 (pointers stay 64-bit) Value ptrIndex = rewriter.create( loc, memrefArg); valueToPush = rewriter.create( loc, rewriter.getI64Type(), ptrIndex); - } else if (memrefArg.getType().isInteger(64)) { - // Scalar i64 param: push value directly. - valueToPush = memrefArg; - } else if (memrefArg.getType().isF64()) { - // Scalar f64 param: bitcast to i64 for stack storage. - valueToPush = rewriter.create( - loc, rewriter.getI64Type(), memrefArg); + } else if (memrefArg.getType().isInteger(32)) { + // Scalar i32 param: sign-extend to i64 for stack. 
+ valueToPush = rewriter.create(loc, rewriter.getI64Type(), + memrefArg); + } else if (memrefArg.getType().isF32()) { + // Scalar f32 param: bitcast to i32, sign-extend to i64. + Value i32Value = rewriter.create( + loc, rewriter.getI32Type(), memrefArg); + valueToPush = + rewriter.create(loc, rewriter.getI64Type(), i32Value); } else { return rewriter.notifyMatchFailure( op, "unsupported param argument type for param_ref"); @@ -777,7 +812,8 @@ struct ParamRefOpConversion : public OpConversionPattern { /// Generalized memory load template. /// Pops address from stack, loads value via pointer, pushes value. -/// When IsFloat=true, loads f64 from memory and bitcasts to i64 for stack. +/// When IsFloat=true, loads f32 from memory, bitcasts to i32, sext to i64. +/// When IsFloat=false, loads i32, sext to i64. /// AddressSpace selects global (0) or workgroup memory. template struct MemoryLoadOpConversion : public OpConversionPattern { @@ -804,18 +840,21 @@ struct MemoryLoadOpConversion : public OpConversionPattern { // Load value from memory via pointer Value ptr = rewriter.create(loc, ptrType, addrValue); - Value valueToPush; + Value i32Value; if constexpr (IsFloat) { - // Load f64 from memory, then bitcast to i64 for stack storage - Value loadedF64 = - rewriter.create(loc, rewriter.getF64Type(), ptr); - valueToPush = rewriter.create( - loc, rewriter.getI64Type(), loadedF64); + // Load f32 from memory, then bitcast to i32 + Value loadedF32 = + rewriter.create(loc, rewriter.getF32Type(), ptr); + i32Value = rewriter.create(loc, rewriter.getI32Type(), + loadedF32); } else { - valueToPush = - rewriter.create(loc, rewriter.getI64Type(), ptr); + i32Value = rewriter.create(loc, rewriter.getI32Type(), ptr); } + // Sign-extend i32 to i64 for stack storage + Value valueToPush = + rewriter.create(loc, rewriter.getI64Type(), i32Value); + // Store loaded value back at same position (replaces address) rewriter.create(loc, valueToPush, memref, stackPtr); @@ -833,8 +872,8 @@ using 
SharedLoadFOpConversion = MemoryLoadOpConversion; /// Generalized memory store template. -/// Pops address and value from stack, stores value to memory. -/// When IsFloat=true, bitcasts i64->f64 before storing. +/// Pops address and value from stack, truncates to i32, stores to memory. +/// When IsFloat=true, bitcasts i32->f32 before storing. template struct MemoryStoreOpConversion : public OpConversionPattern { MemoryStoreOpConversion(const TypeConverter &typeConverter, @@ -860,18 +899,22 @@ struct MemoryStoreOpConversion : public OpConversionPattern { // Pop value from stack Value one = rewriter.create(loc, 1); Value spMinus1 = rewriter.create(loc, stackPtr, one); - Value value = rewriter.create(loc, memref, spMinus1); + Value valueI64 = rewriter.create(loc, memref, spMinus1); + + // Truncate i64 -> i32 + Value valueI32 = + rewriter.create(loc, rewriter.getI32Type(), valueI64); // Store value to memory via pointer Value ptr = rewriter.create(loc, ptrType, addrValue); if constexpr (IsFloat) { - // Bitcast i64 -> f64 before storing - Value f64Value = - rewriter.create(loc, rewriter.getF64Type(), value); - rewriter.create(loc, f64Value, ptr); + // Bitcast i32 -> f32 before storing + Value f32Value = rewriter.create( + loc, rewriter.getF32Type(), valueI32); + rewriter.create(loc, f32Value, ptr); } else { - rewriter.create(loc, value, ptr); + rewriter.create(loc, valueI32, ptr); } // New stack pointer is SP-2 (popped both address and value) @@ -891,8 +934,179 @@ using SharedStoreFOpConversion = MemoryStoreOpConversion; +/// Generalized narrow-type memory load template. +/// Pops address from stack, loads a narrow type via pointer, widens to i64. +/// For float types (f16, bf16): load narrow → extf to f32 → bitcast to i32 → +/// extsi to i64 For int types (i8, i16): load narrow → extsi to i32 → extsi to +/// i64 MemTypeTag: Float16Type, BFloat16Type for floats; use NarrowIntTag +/// for ints. 
+template struct NarrowIntTag { + static IntegerType get(MLIRContext *ctx) { + return IntegerType::get(ctx, BitWidth); + } +}; + +template +struct NarrowLoadOpConversion : public OpConversionPattern { + NarrowLoadOpConversion(const TypeConverter &typeConverter, + MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + using OneToNOpAdaptor = + typename OpConversionPattern::OneToNOpAdaptor; + + LogicalResult + matchAndRewrite(ForthOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + ValueRange inputStack = adaptor.getOperands()[0]; + Value memref = inputStack[0]; + Value stackPtr = inputStack[1]; + + auto ptrType = + LLVM::LLVMPointerType::get(rewriter.getContext(), AddressSpace); + + // Load address from stack + Value addrValue = rewriter.create(loc, memref, stackPtr); + + // Load narrow value from memory via pointer + Value ptr = rewriter.create(loc, ptrType, addrValue); + auto memType = MemTypeTag::get(rewriter.getContext()); + Value narrowVal = rewriter.create(loc, memType, ptr); + + Value i32Value; + if constexpr (IsFloat) { + // extf narrow float → f32, bitcast f32 → i32 + Value f32Val = + rewriter.create(loc, rewriter.getF32Type(), narrowVal); + i32Value = + rewriter.create(loc, rewriter.getI32Type(), f32Val); + } else { + // extsi narrow int → i32 + i32Value = rewriter.create(loc, rewriter.getI32Type(), + narrowVal); + } + + // extsi i32 → i64 for stack storage + Value valueToPush = + rewriter.create(loc, rewriter.getI64Type(), i32Value); + + rewriter.create(loc, valueToPush, memref, stackPtr); + rewriter.replaceOpWithMultiple(op, {{memref, stackPtr}}); + return success(); + } +}; + +/// Generalized narrow-type memory store template. +/// Pops address and value from stack, narrows to target type, stores. 
+/// For float types (f16, bf16): trunci i64→i32 → bitcast→f32 → truncf→narrow → +/// store For int types (i8, i16): trunci i64→i32 → trunci→narrow → store +template +struct NarrowStoreOpConversion : public OpConversionPattern { + NarrowStoreOpConversion(const TypeConverter &typeConverter, + MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + using OneToNOpAdaptor = + typename OpConversionPattern::OneToNOpAdaptor; + + LogicalResult + matchAndRewrite(ForthOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + ValueRange inputStack = adaptor.getOperands()[0]; + Value memref = inputStack[0]; + Value stackPtr = inputStack[1]; + + auto ptrType = + LLVM::LLVMPointerType::get(rewriter.getContext(), AddressSpace); + + // Pop address from stack + Value addrValue = rewriter.create(loc, memref, stackPtr); + + // Pop value from stack + Value one = rewriter.create(loc, 1); + Value spMinus1 = rewriter.create(loc, stackPtr, one); + Value valueI64 = rewriter.create(loc, memref, spMinus1); + + // Truncate i64 → i32 + Value valueI32 = + rewriter.create(loc, rewriter.getI32Type(), valueI64); + + // Store via pointer + Value ptr = rewriter.create(loc, ptrType, addrValue); + auto memType = MemTypeTag::get(rewriter.getContext()); + + if constexpr (IsFloat) { + // bitcast i32 → f32, truncf f32 → narrow float, store + Value f32Val = rewriter.create( + loc, rewriter.getF32Type(), valueI32); + Value narrowVal = rewriter.create(loc, memType, f32Val); + rewriter.create(loc, narrowVal, ptr); + } else { + // trunci i32 → narrow int, store + Value narrowVal = + rewriter.create(loc, memType, valueI32); + rewriter.create(loc, narrowVal, ptr); + } + + // New stack pointer is SP-2 + Value spMinus2 = rewriter.create(loc, spMinus1, one); + rewriter.replaceOpWithMultiple(op, {{memref, spMinus2}}); + return success(); + } +}; + +// Narrow load instantiations — global +using LoadHFOpConversion = + NarrowLoadOpConversion; 
+using LoadBFOpConversion = + NarrowLoadOpConversion; +using LoadI8OpConversion = + NarrowLoadOpConversion>; +using LoadI16OpConversion = + NarrowLoadOpConversion>; + +// Narrow load instantiations — shared +using SharedLoadHFOpConversion = + NarrowLoadOpConversion; +using SharedLoadBFOpConversion = + NarrowLoadOpConversion; +using SharedLoadI8OpConversion = + NarrowLoadOpConversion, false, + kWorkgroupAddressSpace>; +using SharedLoadI16OpConversion = + NarrowLoadOpConversion, false, + kWorkgroupAddressSpace>; + +// Narrow store instantiations — global +using StoreHFOpConversion = + NarrowStoreOpConversion; +using StoreBFOpConversion = + NarrowStoreOpConversion; +using StoreI8OpConversion = + NarrowStoreOpConversion>; +using StoreI16OpConversion = + NarrowStoreOpConversion>; + +// Narrow store instantiations — shared +using SharedStoreHFOpConversion = + NarrowStoreOpConversion; +using SharedStoreBFOpConversion = + NarrowStoreOpConversion; +using SharedStoreI8OpConversion = + NarrowStoreOpConversion, false, + kWorkgroupAddressSpace>; +using SharedStoreI16OpConversion = + NarrowStoreOpConversion, false, + kWorkgroupAddressSpace>; + /// Conversion pattern for forth.itof (S>F). -/// Pops i64, converts to f64 via sitofp, bitcasts back to i64, pushes. +/// Pops i64, truncates to i32, converts to f32, bitcasts to i32, sext to i64. 
struct IToFOpConversion : public OpConversionPattern { IToFOpConversion(const TypeConverter &typeConverter, MLIRContext *context) : OpConversionPattern(typeConverter, context) {} @@ -906,16 +1120,20 @@ struct IToFOpConversion : public OpConversionPattern { Value memref = inputStack[0]; Value stackPtr = inputStack[1]; - // Load i64 value from top of stack + // Load i64 value from top of stack, truncate to i32 Value i64Value = rewriter.create(loc, memref, stackPtr); + Value i32Value = + rewriter.create(loc, rewriter.getI32Type(), i64Value); - // Convert i64 -> f64 via SIToFPOp - Value f64Value = - rewriter.create(loc, rewriter.getF64Type(), i64Value); + // Convert i32 -> f32 via SIToFPOp + Value f32Value = + rewriter.create(loc, rewriter.getF32Type(), i32Value); - // Bitcast f64 -> i64 for stack storage + // Bitcast f32 -> i32, sign-extend to i64 for stack storage + Value resI32 = + rewriter.create(loc, rewriter.getI32Type(), f32Value); Value result = - rewriter.create(loc, rewriter.getI64Type(), f64Value); + rewriter.create(loc, rewriter.getI64Type(), resI32); // Store result (SP unchanged — unary op) rewriter.create(loc, result, memref, stackPtr); @@ -926,7 +1144,8 @@ struct IToFOpConversion : public OpConversionPattern { }; /// Conversion pattern for forth.ftoi (F>S). -/// Pops i64 (f64 bits), bitcasts to f64, converts to i64 via fptosi, pushes. +/// Pops i64, truncates to i32 (f32 bits), bitcasts to f32, fptosi to i32, sext +/// to i64. 
struct FToIOpConversion : public OpConversionPattern { FToIOpConversion(const TypeConverter &typeConverter, MLIRContext *context) : OpConversionPattern(typeConverter, context) {} @@ -940,16 +1159,22 @@ struct FToIOpConversion : public OpConversionPattern { Value memref = inputStack[0]; Value stackPtr = inputStack[1]; - // Load i64 (f64 bit pattern) from top of stack + // Load i64 from top of stack, truncate to i32 Value i64Bits = rewriter.create(loc, memref, stackPtr); + Value i32Bits = + rewriter.create(loc, rewriter.getI32Type(), i64Bits); + + // Bitcast i32 -> f32 + Value f32Value = + rewriter.create(loc, rewriter.getF32Type(), i32Bits); - // Bitcast i64 -> f64 - Value f64Value = - rewriter.create(loc, rewriter.getF64Type(), i64Bits); + // Convert f32 -> i32 via FPToSIOp + Value resI32 = + rewriter.create(loc, rewriter.getI32Type(), f32Value); - // Convert f64 -> i64 via FPToSIOp + // Sign-extend i32 -> i64 Value result = - rewriter.create(loc, rewriter.getI64Type(), f64Value); + rewriter.create(loc, rewriter.getI64Type(), resI32); // Store result (SP unchanged — unary op) rewriter.create(loc, result, memref, stackPtr); @@ -1279,6 +1504,14 @@ struct ConvertForthToMemRefPass LoadIOpConversion, StoreIOpConversion, LoadFOpConversion, StoreFOpConversion, SharedLoadIOpConversion, SharedStoreIOpConversion, SharedLoadFOpConversion, SharedStoreFOpConversion, + // Narrow memory ops (f16, bf16, i8, i16 — global + shared) + LoadHFOpConversion, StoreHFOpConversion, LoadBFOpConversion, + StoreBFOpConversion, LoadI8OpConversion, StoreI8OpConversion, + LoadI16OpConversion, StoreI16OpConversion, SharedLoadHFOpConversion, + SharedStoreHFOpConversion, SharedLoadBFOpConversion, + SharedStoreBFOpConversion, SharedLoadI8OpConversion, + SharedStoreI8OpConversion, SharedLoadI16OpConversion, + SharedStoreI16OpConversion, // Type conversions IToFOpConversion, FToIOpConversion, // Control flow diff --git a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp 
b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp index 409e737..1362df7 100644 --- a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp +++ b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp @@ -320,23 +320,23 @@ LogicalResult ForthParser::parseHeader() { llvm::StringRef typeToken = tokens[2]; bool isArray = false; int64_t size = 0; - BaseType baseType = BaseType::I64; + BaseType baseType = BaseType::I32; size_t lbracket = typeToken.find('['); if (lbracket != llvm::StringRef::npos) { size_t rbracket = typeToken.find(']'); if (rbracket == llvm::StringRef::npos || rbracket != typeToken.size() - 1) { return emitErrorAt(lineLoc, - "array type must use suffix [N], e.g. i64[4]"); + "array type must use suffix [N], e.g. i32[4]"); } llvm::StringRef base = typeToken.substr(0, lbracket); llvm::StringRef sizeStr = typeToken.substr(lbracket + 1, rbracket - lbracket - 1); std::string baseUpper = toUpperCase(base); - if (baseUpper == "I64") { - baseType = BaseType::I64; - } else if (baseUpper == "F64") { - baseType = BaseType::F64; + if (baseUpper == "I32") { + baseType = BaseType::I32; + } else if (baseUpper == "F32") { + baseType = BaseType::F32; } else { return emitErrorAt(lineLoc, "unsupported base type: " + base.str()); } @@ -348,10 +348,10 @@ LogicalResult ForthParser::parseHeader() { isArray = true; } else { std::string typeUpper = toUpperCase(typeToken); - if (typeUpper == "I64") { - baseType = BaseType::I64; - } else if (typeUpper == "F64") { - baseType = BaseType::F64; + if (typeUpper == "I32") { + baseType = BaseType::I32; + } else if (typeUpper == "F32") { + baseType = BaseType::F32; } else { return emitErrorAt(lineLoc, "unsupported scalar type: " + typeToken.str()); @@ -473,13 +473,13 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack, .getResult(0); } - // CELLS: multiply by 8 (sizeof i64 = sizeof f64) for byte addressing + // CELLS: multiply by 4 (sizeof i32 = sizeof f32) for byte addressing if (word == "CELLS") { - Value lit8 = builder + Value lit4 = builder 
.create(loc, stackType, inputStack, - builder.getI64IntegerAttr(8)) + builder.getI32IntegerAttr(4)) .getResult(); - return builder.create(loc, stackType, lit8).getResult(); + return builder.create(loc, stackType, lit4).getResult(); } // Built-in operations @@ -592,6 +592,54 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack, } else if (word == "SF!") { return builder.create(loc, stackType, inputStack) .getResult(); + } else if (word == "HF@") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "HF!") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "BF@") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "BF!") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "I8@") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "I8!") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "I16@") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "I16!") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "SHF@") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "SHF!") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "SBF@") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "SBF!") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "SI8@") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "SI8!") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "SI16@") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "SI16!") { + return builder.create(loc, stackType, 
inputStack) + .getResult(); } else if (word == "TID-X") { return builder.create(loc, stackType, inputStack) .getResult(); @@ -674,11 +722,14 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack, std::to_string(depth + 1) + " nested DO/LOOP(s)"); return nullptr; } - // Load counter from the appropriate loop context + // Load counter (i32) from the appropriate loop context, + // sign-extend to i64 for PushValueOp. auto &ctx = loopStack[loopStack.size() - 1 - depth]; Value c0 = builder.create(loc, 0); - Value idx = + Value idxI32 = builder.create(loc, ctx.counter, ValueRange{c0}); + Value idx = + builder.create(loc, builder.getI64Type(), idxI32); return builder.create(loc, stackType, inputStack, idx) .getOutputStack(); } @@ -707,9 +758,9 @@ std::pair ForthParser::emitPopFlag(Location loc, Value stack) { void ForthParser::emitLoopEnd(Location loc, const LoopContext &ctx, Value step, Value &stack) { - auto i64Type = builder.getI64Type(); + auto i32Type = builder.getI32Type(); - // Load old counter, compute new = old + step, store. + // Load old counter (i32), compute new = old + step, store. 
Value c0 = builder.create(loc, 0); Value oldIdx = builder.create(loc, ctx.counter, ValueRange{c0}); @@ -721,8 +772,8 @@ void ForthParser::emitLoopEnd(Location loc, const LoopContext &ctx, Value step, Value oldDiff = builder.create(loc, oldIdx, ctx.limit); Value newDiff = builder.create(loc, newIdx, ctx.limit); Value xorVal = builder.create(loc, oldDiff, newDiff); - Value zero = builder.create(loc, i64Type, - builder.getI64IntegerAttr(0)); + Value zero = builder.create(loc, i32Type, + builder.getI32IntegerAttr(0)); Value crossed = builder.create(loc, arith::CmpIPredicate::slt, xorVal, zero); @@ -744,18 +795,18 @@ LogicalResult ForthParser::parseBody(Value &stack) { if (currentToken.kind == Token::Kind::Number) { Location tokenLoc = getLoc(); - int64_t value = std::stoll(currentToken.text); + int32_t value = std::stol(currentToken.text); stack = builder .create(tokenLoc, stackType, stack, - builder.getI64IntegerAttr(value)) + builder.getI32IntegerAttr(value)) .getResult(); consume(); } else if (currentToken.kind == Token::Kind::Float) { Location tokenLoc = getLoc(); - double value = std::stod(currentToken.text); + float value = std::stof(currentToken.text); stack = builder .create(tokenLoc, stackType, stack, - builder.getF64FloatAttr(value)) + builder.getF32FloatAttr(value)) .getResult(); consume(); } else if (currentToken.kind == Token::Kind::Word) { @@ -987,20 +1038,25 @@ LogicalResult ForthParser::parseBody(Value &stack) { consume(); Region *parentRegion = builder.getInsertionBlock()->getParent(); auto i64Type = builder.getI64Type(); + auto i32Type = builder.getI32Type(); - // Pop start and limit from the Forth stack. + // Pop start and limit from the Forth stack (as i64). 
auto popStart = builder.create(loc, stackType, i64Type, stack); Value s1 = popStart.getOutputStack(); - Value start = popStart.getValue(); + Value startI64 = popStart.getValue(); auto popLimit = builder.create(loc, stackType, i64Type, s1); Value s2 = popLimit.getOutputStack(); - Value limit = popLimit.getValue(); + Value limitI64 = popLimit.getValue(); + + // Truncate to i32 for loop counter arithmetic. + Value start = builder.create(loc, i32Type, startI64); + Value limit = builder.create(loc, i32Type, limitI64); - // Allocate counter storage. - auto counterType = MemRefType::get({1}, i64Type); + // Allocate counter storage (i32). + auto counterType = MemRefType::get({1}, i32Type); Value counter = builder.create(loc, counterType); Value c0 = builder.create(loc, 0); builder.create(loc, start, counter, ValueRange{c0}); @@ -1029,7 +1085,7 @@ LogicalResult ForthParser::parseBody(Value &stack) { auto ctx = loopStack.pop_back_val(); Value one = builder.create( - loc, builder.getI64Type(), builder.getI64IntegerAttr(1)); + loc, builder.getI32Type(), builder.getI32IntegerAttr(1)); emitLoopEnd(loc, ctx, one, stack); //=== +LOOP === @@ -1042,11 +1098,12 @@ LogicalResult ForthParser::parseBody(Value &stack) { auto ctx = loopStack.pop_back_val(); - // Pop step from data stack. + // Pop step from data stack (i64) and truncate to i32. auto popOp = builder.create( loc, forth::StackType::get(context), builder.getI64Type(), stack); stack = popOp.getOutputStack(); - Value step = popOp.getValue(); + Value step = builder.create(loc, builder.getI32Type(), + popOp.getValue()); emitLoopEnd(loc, ctx, step, stack); //=== { outside word definition === @@ -1273,9 +1330,9 @@ OwningOpRef ForthParser::parseModule() { // Build function argument types from param declarations SmallVector argTypes; for (const auto ¶m : paramDecls) { - Type elemType = param.baseType == BaseType::F64 - ? Type(builder.getF64Type()) - : Type(builder.getI64Type()); + Type elemType = param.baseType == BaseType::F32 + ? 
Type(builder.getF32Type()) + : Type(builder.getI32Type()); if (param.isArray) { argTypes.push_back(MemRefType::get({param.size}, elemType)); } else { @@ -1301,9 +1358,9 @@ OwningOpRef ForthParser::parseModule() { // Emit shared memory allocations at kernel entry for (const auto &shared : sharedDecls) { int64_t size = shared.isArray ? shared.size : 1; - Type elemType = shared.baseType == BaseType::F64 - ? Type(builder.getF64Type()) - : Type(builder.getI64Type()); + Type elemType = shared.baseType == BaseType::F32 + ? Type(builder.getF32Type()) + : Type(builder.getI32Type()); auto memrefType = MemRefType::get({size}, elemType); Value alloca = builder.create(loc, memrefType); alloca.getDefiningOp()->setAttr("forth.shared_name", diff --git a/lib/Translation/ForthToMLIR/ForthToMLIR.h b/lib/Translation/ForthToMLIR/ForthToMLIR.h index 4059e02..7ca39f9 100644 --- a/lib/Translation/ForthToMLIR/ForthToMLIR.h +++ b/lib/Translation/ForthToMLIR/ForthToMLIR.h @@ -19,14 +19,14 @@ namespace mlir { namespace forth { /// Base element type for param/shared declarations. -enum class BaseType { I64, F64 }; +enum class BaseType { I32, F32 }; /// A declared kernel parameter: `param `. struct ParamDecl { std::string name; bool isArray = false; int64_t size = 0; - BaseType baseType = BaseType::I64; + BaseType baseType = BaseType::I32; }; /// A declared shared memory region: `shared `. @@ -34,7 +34,7 @@ struct SharedDecl { std::string name; bool isArray = false; int64_t size = 0; - BaseType baseType = BaseType::I64; + BaseType baseType = BaseType::I32; }; /// Simple token representing a Forth word or literal. @@ -115,8 +115,8 @@ class ForthParser { /// Loop context for DO/LOOP with I/J/K support. 
struct LoopContext { - Value counter; // memref<1xi64> alloca for the loop counter - Value limit; // i64 loop limit + Value counter; // memref<1xi32> alloca for the loop counter + Value limit; // i32 loop limit Block *body; // loop body block Block *exit; // loop exit block }; diff --git a/test/Conversion/ForthToMemRef/arithmetic.mlir b/test/Conversion/ForthToMemRef/arithmetic.mlir index a0bbf4b..db9e390 100644 --- a/test/Conversion/ForthToMemRef/arithmetic.mlir +++ b/test/Conversion/ForthToMemRef/arithmetic.mlir @@ -2,50 +2,65 @@ // CHECK-LABEL: func.func private @main -// add: pop two, arith.addi, store result +// add: pop two, trunci to i32, arith.addi i32, extsi to i64, store // CHECK: memref.load // CHECK: arith.subi // CHECK: memref.load -// CHECK: arith.addi %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.addi %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// sub: pop two, arith.subi +// sub: trunci, subi i32, extsi // CHECK: memref.load // CHECK: memref.load -// CHECK: arith.subi %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.subi %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// mul: pop two, arith.muli +// mul: trunci, muli i32, extsi // CHECK: memref.load // CHECK: memref.load -// CHECK: arith.muli %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.muli %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// div: pop two, arith.divsi +// div: trunci, divsi i32, extsi // CHECK: memref.load // CHECK: memref.load -// CHECK: arith.divsi %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.divsi %{{.*}}, %{{.*}} : i32 
+// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// mod: pop two, arith.remsi +// mod: trunci, remsi i32, extsi // CHECK: memref.load // CHECK: memref.load -// CHECK: arith.remsi %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.remsi %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(10 : i64) : !forth.stack -> !forth.stack - %2 = forth.constant %1(20 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(10 : i32) : !forth.stack -> !forth.stack + %2 = forth.constant %1(20 : i32) : !forth.stack -> !forth.stack %3 = forth.addi %2 : !forth.stack -> !forth.stack - %4 = forth.constant %3(3 : i64) : !forth.stack -> !forth.stack + %4 = forth.constant %3(3 : i32) : !forth.stack -> !forth.stack %5 = forth.subi %4 : !forth.stack -> !forth.stack - %6 = forth.constant %5(4 : i64) : !forth.stack -> !forth.stack + %6 = forth.constant %5(4 : i32) : !forth.stack -> !forth.stack %7 = forth.muli %6 : !forth.stack -> !forth.stack - %8 = forth.constant %7(2 : i64) : !forth.stack -> !forth.stack + %8 = forth.constant %7(2 : i32) : !forth.stack -> !forth.stack %9 = forth.divi %8 : !forth.stack -> !forth.stack - %10 = forth.constant %9(5 : i64) : !forth.stack -> !forth.stack + %10 = forth.constant %9(5 : i32) : !forth.stack -> !forth.stack %11 = forth.mod %10 : !forth.stack -> !forth.stack return } diff --git a/test/Conversion/ForthToMemRef/begin-until.mlir b/test/Conversion/ForthToMemRef/begin-until.mlir index e9b5b28..002c118 100644 --- a/test/Conversion/ForthToMemRef/begin-until.mlir +++ b/test/Conversion/ForthToMemRef/begin-until.mlir @@ -7,17 +7,24 @@ // Stack allocation and literal 10 push: // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi64> -// CHECK: %[[C10:.*]] = arith.constant 10 : i64 -// CHECK: memref.store %[[C10]], %[[ALLOCA]] 
+// CHECK: arith.constant 10 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 +// CHECK: memref.store %{{.*}}, %[[ALLOCA]] // CHECK: cf.br ^bb1 // Loop body: push 1, subtract, dup, zero_eq, pop_flag, cond_br // CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index): -// CHECK: arith.constant 1 : i64 +// CHECK: arith.constant 1 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// CHECK: arith.subi +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.subi %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store // CHECK: memref.store +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.constant 0 : i32 // CHECK: arith.cmpi eq // CHECK: arith.extsi // CHECK: memref.store @@ -31,10 +38,10 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(10 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(10 : i32) : !forth.stack -> !forth.stack cf.br ^bb1(%1 : !forth.stack) ^bb1(%2: !forth.stack): - %3 = forth.constant %2(1 : i64) : !forth.stack -> !forth.stack + %3 = forth.constant %2(1 : i32) : !forth.stack -> !forth.stack %4 = forth.subi %3 : !forth.stack -> !forth.stack %5 = forth.dup %4 : !forth.stack -> !forth.stack %6 = forth.zero_eq %5 : !forth.stack -> !forth.stack diff --git a/test/Conversion/ForthToMemRef/begin-while-repeat.mlir b/test/Conversion/ForthToMemRef/begin-while-repeat.mlir index 99c5c2e..18f025e 100644 --- a/test/Conversion/ForthToMemRef/begin-while-repeat.mlir +++ b/test/Conversion/ForthToMemRef/begin-while-repeat.mlir @@ -7,16 +7,20 @@ // Stack allocation and literal 10 push: // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi64> -// CHECK: %[[C10:.*]] = arith.constant 10 : i64 -// CHECK: memref.store %[[C10]], %[[ALLOCA]] +// CHECK: arith.constant 10 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 +// CHECK: memref.store %{{.*}}, %[[ALLOCA]] // CHECK: cf.br ^bb1 // Condition 
block: DUP, push 0, compare >, pop_flag, cond_br // CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index): // CHECK: memref.load // CHECK: memref.store -// CHECK: arith.constant 0 : i64 +// CHECK: arith.constant 0 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 // CHECK: arith.cmpi sgt // CHECK: arith.extsi // CHECK: memref.store @@ -25,9 +29,13 @@ // Body block: push 1, subtract, branch back to condition // CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index): -// CHECK: arith.constant 1 : i64 +// CHECK: arith.constant 1 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// CHECK: arith.subi +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.subi %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store // CHECK: cf.br ^bb1 @@ -38,16 +46,16 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(10 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(10 : i32) : !forth.stack -> !forth.stack cf.br ^bb1(%1 : !forth.stack) ^bb1(%2: !forth.stack): %3 = forth.dup %2 : !forth.stack -> !forth.stack - %4 = forth.constant %3(0 : i64) : !forth.stack -> !forth.stack + %4 = forth.constant %3(0 : i32) : !forth.stack -> !forth.stack %5 = forth.gti %4 : !forth.stack -> !forth.stack %output_stack, %flag = forth.pop_flag %5 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag, ^bb2(%output_stack : !forth.stack), ^bb3(%output_stack : !forth.stack) ^bb2(%6: !forth.stack): - %7 = forth.constant %6(1 : i64) : !forth.stack -> !forth.stack + %7 = forth.constant %6(1 : i32) : !forth.stack -> !forth.stack %8 = forth.subi %7 : !forth.stack -> !forth.stack cf.br ^bb1(%8 : !forth.stack) ^bb3(%9: !forth.stack): diff --git a/test/Conversion/ForthToMemRef/bitwise.mlir b/test/Conversion/ForthToMemRef/bitwise.mlir index 
243546a..eb4b762 100644 --- a/test/Conversion/ForthToMemRef/bitwise.mlir +++ b/test/Conversion/ForthToMemRef/bitwise.mlir @@ -2,62 +2,79 @@ // CHECK-LABEL: func.func private @main -// and: pop two, arith.andi, store result +// and: pop two, trunci to i32, arith.andi i32, extsi to i64, store // CHECK: memref.load // CHECK: arith.subi // CHECK: memref.load -// CHECK: arith.andi %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.andi %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// or: pop two, arith.ori, store result +// or: trunci, ori i32, extsi // CHECK: memref.load // CHECK: memref.load -// CHECK: arith.ori %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.ori %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// xor: pop two, arith.xori, store result +// xor: trunci, xori i32, extsi // CHECK: memref.load // CHECK: memref.load -// CHECK: arith.xori %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.xori %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// not: load one value, xori with -1, store at same SP +// not: trunci to i32, xori with -1:i32, extsi // CHECK: memref.load -// CHECK: arith.constant -1 : i64 -// CHECK: arith.xori %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.constant -1 : i32 +// CHECK: arith.xori %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// lshift: pop two, arith.shli, store result +// lshift: trunci, shli i32, extsi // CHECK: memref.load // CHECK: memref.load -// CHECK: arith.shli %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// 
CHECK: arith.shli %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// rshift: pop two, arith.shrui, store result +// rshift: trunci, shrui i32, extsi // CHECK: memref.load // CHECK: memref.load -// CHECK: arith.shrui %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.shrui %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(3 : i64) : !forth.stack -> !forth.stack - %2 = forth.constant %1(5 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(3 : i32) : !forth.stack -> !forth.stack + %2 = forth.constant %1(5 : i32) : !forth.stack -> !forth.stack %3 = forth.and %2 : !forth.stack -> !forth.stack - %4 = forth.constant %3(7 : i64) : !forth.stack -> !forth.stack - %5 = forth.constant %4(8 : i64) : !forth.stack -> !forth.stack + %4 = forth.constant %3(7 : i32) : !forth.stack -> !forth.stack + %5 = forth.constant %4(8 : i32) : !forth.stack -> !forth.stack %6 = forth.or %5 : !forth.stack -> !forth.stack - %7 = forth.constant %6(15 : i64) : !forth.stack -> !forth.stack - %8 = forth.constant %7(3 : i64) : !forth.stack -> !forth.stack + %7 = forth.constant %6(15 : i32) : !forth.stack -> !forth.stack + %8 = forth.constant %7(3 : i32) : !forth.stack -> !forth.stack %9 = forth.xor %8 : !forth.stack -> !forth.stack - %10 = forth.constant %9(42 : i64) : !forth.stack -> !forth.stack + %10 = forth.constant %9(42 : i32) : !forth.stack -> !forth.stack %11 = forth.not %10 : !forth.stack -> !forth.stack - %12 = forth.constant %11(1 : i64) : !forth.stack -> !forth.stack - %13 = forth.constant %12(4 : i64) : !forth.stack -> !forth.stack + %12 = forth.constant %11(1 : i32) : !forth.stack -> !forth.stack + %13 = forth.constant %12(4 : i32) : !forth.stack -> !forth.stack %14 = forth.lshift %13 : !forth.stack -> !forth.stack - %15 = 
forth.constant %14(256 : i64) : !forth.stack -> !forth.stack - %16 = forth.constant %15(2 : i64) : !forth.stack -> !forth.stack + %15 = forth.constant %14(256 : i32) : !forth.stack -> !forth.stack + %16 = forth.constant %15(2 : i32) : !forth.stack -> !forth.stack %17 = forth.rshift %16 : !forth.stack -> !forth.stack return } diff --git a/test/Conversion/ForthToMemRef/comparison.mlir b/test/Conversion/ForthToMemRef/comparison.mlir index 4a546b2..55bfbee 100644 --- a/test/Conversion/ForthToMemRef/comparison.mlir +++ b/test/Conversion/ForthToMemRef/comparison.mlir @@ -2,78 +2,91 @@ // CHECK-LABEL: func.func private @main -// eq: load two values, arith.cmpi eq, extsi to i64, store +// eq: load two, trunci to i32, cmpi eq on i32, extsi i1->i64, store // CHECK: memref.load // CHECK: arith.subi // CHECK: memref.load -// CHECK: arith.cmpi eq, %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.cmpi eq, %{{.*}}, %{{.*}} : i32 // CHECK: arith.extsi %{{.*}} : i1 to i64 // CHECK: memref.store -// lt: load two values, arith.cmpi slt, extsi to i64, store +// lt: trunci, cmpi slt on i32, extsi // CHECK: memref.load // CHECK: memref.load -// CHECK: arith.cmpi slt, %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.cmpi slt, %{{.*}}, %{{.*}} : i32 // CHECK: arith.extsi %{{.*}} : i1 to i64 // CHECK: memref.store -// gt: load two values, arith.cmpi sgt, extsi to i64, store +// gt: trunci, cmpi sgt on i32, extsi // CHECK: memref.load // CHECK: memref.load -// CHECK: arith.cmpi sgt, %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.cmpi sgt, %{{.*}}, %{{.*}} : i32 // CHECK: arith.extsi %{{.*}} : i1 to i64 // CHECK: memref.store -// zero_eq: load one value, compare with 0, extsi, store at same SP +// zero_eq: trunci to i32, compare with 0:i32, extsi 
i1->i64, store // CHECK: memref.load -// CHECK: arith.constant 0 : i64 -// CHECK: arith.cmpi eq, %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.constant 0 : i32 +// CHECK: arith.cmpi eq, %{{.*}}, %{{.*}} : i32 // CHECK: arith.extsi %{{.*}} : i1 to i64 // CHECK: memref.store -// ne: load two values, arith.cmpi ne, extsi to i64, store +// ne: trunci, cmpi ne on i32, extsi // CHECK: memref.load // CHECK: memref.load -// CHECK: arith.cmpi ne, %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.cmpi ne, %{{.*}}, %{{.*}} : i32 // CHECK: arith.extsi %{{.*}} : i1 to i64 // CHECK: memref.store -// le: load two values, arith.cmpi sle, extsi to i64, store +// le: trunci, cmpi sle on i32, extsi // CHECK: memref.load // CHECK: memref.load -// CHECK: arith.cmpi sle, %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.cmpi sle, %{{.*}}, %{{.*}} : i32 // CHECK: arith.extsi %{{.*}} : i1 to i64 // CHECK: memref.store -// ge: load two values, arith.cmpi sge, extsi to i64, store +// ge: trunci, cmpi sge on i32, extsi // CHECK: memref.load // CHECK: memref.load -// CHECK: arith.cmpi sge, %{{.*}}, %{{.*}} : i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.cmpi sge, %{{.*}}, %{{.*}} : i32 // CHECK: arith.extsi %{{.*}} : i1 to i64 // CHECK: memref.store module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(1 : i64) : !forth.stack -> !forth.stack - %2 = forth.constant %1(2 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(1 : i32) : !forth.stack -> !forth.stack + %2 = forth.constant %1(2 : i32) : !forth.stack -> !forth.stack %3 = forth.eqi %2 : !forth.stack -> !forth.stack - %4 = forth.constant %3(3 : i64) : !forth.stack -> !forth.stack - %5 = forth.constant %4(4 : i64) : !forth.stack -> 
!forth.stack + %4 = forth.constant %3(3 : i32) : !forth.stack -> !forth.stack + %5 = forth.constant %4(4 : i32) : !forth.stack -> !forth.stack %6 = forth.lti %5 : !forth.stack -> !forth.stack - %7 = forth.constant %6(5 : i64) : !forth.stack -> !forth.stack - %8 = forth.constant %7(6 : i64) : !forth.stack -> !forth.stack + %7 = forth.constant %6(5 : i32) : !forth.stack -> !forth.stack + %8 = forth.constant %7(6 : i32) : !forth.stack -> !forth.stack %9 = forth.gti %8 : !forth.stack -> !forth.stack - %10 = forth.constant %9(0 : i64) : !forth.stack -> !forth.stack + %10 = forth.constant %9(0 : i32) : !forth.stack -> !forth.stack %11 = forth.zero_eq %10 : !forth.stack -> !forth.stack - %12 = forth.constant %11(7 : i64) : !forth.stack -> !forth.stack - %13 = forth.constant %12(8 : i64) : !forth.stack -> !forth.stack + %12 = forth.constant %11(7 : i32) : !forth.stack -> !forth.stack + %13 = forth.constant %12(8 : i32) : !forth.stack -> !forth.stack %14 = forth.nei %13 : !forth.stack -> !forth.stack - %15 = forth.constant %14(9 : i64) : !forth.stack -> !forth.stack - %16 = forth.constant %15(10 : i64) : !forth.stack -> !forth.stack + %15 = forth.constant %14(9 : i32) : !forth.stack -> !forth.stack + %16 = forth.constant %15(10 : i32) : !forth.stack -> !forth.stack %17 = forth.lei %16 : !forth.stack -> !forth.stack - %18 = forth.constant %17(11 : i64) : !forth.stack -> !forth.stack - %19 = forth.constant %18(12 : i64) : !forth.stack -> !forth.stack + %18 = forth.constant %17(11 : i32) : !forth.stack -> !forth.stack + %19 = forth.constant %18(12 : i32) : !forth.stack -> !forth.stack %20 = forth.gei %19 : !forth.stack -> !forth.stack return } diff --git a/test/Conversion/ForthToMemRef/control-flow.mlir b/test/Conversion/ForthToMemRef/control-flow.mlir index a43493a..2f704cc 100644 --- a/test/Conversion/ForthToMemRef/control-flow.mlir +++ b/test/Conversion/ForthToMemRef/control-flow.mlir @@ -7,8 +7,9 @@ // Stack allocation and literal 1 push: // CHECK: %[[ALLOCA:.*]] = 
memref.alloca() : memref<256xi64> -// CHECK: %[[C1:.*]] = arith.constant 1 : i64 -// CHECK: memref.store %[[C1]], %[[ALLOCA]] +// CHECK: arith.constant 1 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 +// CHECK: memref.store %{{.*}}, %[[ALLOCA]] // Pop flag and conditional branch: // CHECK: %[[FLAG1:.*]] = memref.load @@ -18,26 +19,30 @@ // Then branch: push 42 // CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index): -// CHECK: arith.constant 42 : i64 +// CHECK: arith.constant 42 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store // CHECK: cf.br ^bb3 // Else branch: push 99 // CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index): -// CHECK: arith.constant 99 : i64 +// CHECK: arith.constant 99 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store // CHECK: cf.br ^bb3 // Merge block: push 0, pop flag, second conditional branch // CHECK: ^bb3(%{{.*}}: memref<256xi64>, %{{.*}}: index): -// CHECK: arith.constant 0 : i64 +// CHECK: arith.constant 0 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store // CHECK: arith.cmpi ne // CHECK: cf.cond_br // Second IF true branch: push 7 // CHECK: ^bb4(%{{.*}}: memref<256xi64>, %{{.*}}: index): -// CHECK: arith.constant 7 : i64 +// CHECK: arith.constant 7 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store // Final merge and return @@ -47,21 +52,21 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(1 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(1 : i32) : !forth.stack -> !forth.stack %output_stack, %flag = forth.pop_flag %1 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag, ^bb1(%output_stack : !forth.stack), ^bb2(%output_stack : !forth.stack) ^bb1(%2: !forth.stack): - %3 = forth.constant %2(42 : i64) : !forth.stack -> !forth.stack + %3 = forth.constant %2(42 : i32) : !forth.stack -> !forth.stack cf.br ^bb3(%3 : !forth.stack) ^bb2(%4: !forth.stack): - %5 = forth.constant %4(99 : i64) 
: !forth.stack -> !forth.stack + %5 = forth.constant %4(99 : i32) : !forth.stack -> !forth.stack cf.br ^bb3(%5 : !forth.stack) ^bb3(%6: !forth.stack): - %7 = forth.constant %6(0 : i64) : !forth.stack -> !forth.stack + %7 = forth.constant %6(0 : i32) : !forth.stack -> !forth.stack %output_stack_0, %flag_1 = forth.pop_flag %7 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag_1, ^bb4(%output_stack_0 : !forth.stack), ^bb5(%output_stack_0 : !forth.stack) ^bb4(%8: !forth.stack): - %9 = forth.constant %8(7 : i64) : !forth.stack -> !forth.stack + %9 = forth.constant %8(7 : i32) : !forth.stack -> !forth.stack cf.br ^bb5(%9 : !forth.stack) ^bb5(%10: !forth.stack): return diff --git a/test/Conversion/ForthToMemRef/do-loop.mlir b/test/Conversion/ForthToMemRef/do-loop.mlir index b196aff..eddda77 100644 --- a/test/Conversion/ForthToMemRef/do-loop.mlir +++ b/test/Conversion/ForthToMemRef/do-loop.mlir @@ -7,9 +7,11 @@ // Stack allocation and push 10, 0: // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi64> -// CHECK: arith.constant 10 : i64 +// CHECK: arith.constant 10 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// CHECK: arith.constant 0 : i64 +// CHECK: arith.constant 0 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store // Pop start and limit from stack: @@ -42,8 +44,8 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(10 : i64) : !forth.stack -> !forth.stack - %2 = forth.constant %1(0 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(10 : i32) : !forth.stack -> !forth.stack + %2 = forth.constant %1(0 : i32) : !forth.stack -> !forth.stack %output_stack, %value = forth.pop %2 : !forth.stack -> !forth.stack, i64 %output_stack_0, %value_1 = forth.pop %output_stack : !forth.stack -> !forth.stack, i64 %alloca = memref.alloca() : memref<1xi64> diff --git a/test/Conversion/ForthToMemRef/float-arithmetic.mlir b/test/Conversion/ForthToMemRef/float-arithmetic.mlir index 
638b9b6..6fe46ef 100644 --- a/test/Conversion/ForthToMemRef/float-arithmetic.mlir +++ b/test/Conversion/ForthToMemRef/float-arithmetic.mlir @@ -2,39 +2,51 @@ // CHECK-LABEL: func.func private @main -// addf: pop two, bitcast to f64, arith.addf, bitcast back, store +// addf: pop two, trunci to i32, bitcast i32->f32, arith.addf f32, bitcast f32->i32, extsi to i64, store // CHECK: memref.load // CHECK: arith.subi // CHECK: memref.load -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: arith.addf %{{.*}}, %{{.*}} : f64 -// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.addf %{{.*}}, %{{.*}} : f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// subf: bitcast, subf, bitcast -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: arith.subf %{{.*}}, %{{.*}} : f64 -// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// subf: trunci, bitcast, subf f32, bitcast, extsi +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.subf %{{.*}}, %{{.*}} : f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 -// mulf: bitcast, mulf, bitcast -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: arith.mulf %{{.*}}, %{{.*}} : f64 -// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// mulf: trunci, bitcast, mulf f32, bitcast, extsi +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// 
CHECK: arith.mulf %{{.*}}, %{{.*}} : f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 -// divf: bitcast, divf, bitcast -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: arith.divf %{{.*}}, %{{.*}} : f64 -// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// divf: trunci, bitcast, divf f32, bitcast, extsi +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.divf %{{.*}}, %{{.*}} : f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(1.000000e+00 : f64) : !forth.stack -> !forth.stack - %2 = forth.constant %1(2.000000e+00 : f64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(1.000000e+00 : f32) : !forth.stack -> !forth.stack + %2 = forth.constant %1(2.000000e+00 : f32) : !forth.stack -> !forth.stack %3 = forth.addf %2 : !forth.stack -> !forth.stack %4 = forth.subf %3 : !forth.stack -> !forth.stack %5 = forth.mulf %4 : !forth.stack -> !forth.stack diff --git a/test/Conversion/ForthToMemRef/float-math-intrinsics.mlir b/test/Conversion/ForthToMemRef/float-math-intrinsics.mlir index 3153538..8116d00 100644 --- a/test/Conversion/ForthToMemRef/float-math-intrinsics.mlir +++ b/test/Conversion/ForthToMemRef/float-math-intrinsics.mlir @@ -2,67 +2,83 @@ // CHECK-LABEL: func.func private @main -// expf: load, bitcast i64->f64, math.exp, bitcast f64->i64, store (SP unchanged) +// expf: load, trunci to i32, bitcast i32->f32, math.exp f32, bitcast f32->i32, extsi to i64, store (SP unchanged) // CHECK: memref.load -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: math.exp %{{.*}} : f64 -// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.bitcast 
%{{.*}} : i32 to f32 +// CHECK: math.exp %{{.*}} : f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// sqrtf: load, bitcast, math.sqrt, bitcast, store +// sqrtf: trunci, bitcast, math.sqrt f32, bitcast, extsi // CHECK: memref.load -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: math.sqrt %{{.*}} : f64 -// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: math.sqrt %{{.*}} : f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// logf: load, bitcast, math.log, bitcast, store +// logf: trunci, bitcast, math.log f32, bitcast, extsi // CHECK: memref.load -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: math.log %{{.*}} : f64 -// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: math.log %{{.*}} : f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// absf: load, bitcast, math.absf, bitcast, store +// absf: trunci, bitcast, math.absf f32, bitcast, extsi // CHECK: memref.load -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: math.absf %{{.*}} : f64 -// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: math.absf %{{.*}} : f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// negf: load, bitcast, arith.negf, bitcast, store +// negf: trunci, bitcast, arith.negf f32, bitcast, extsi // CHECK: memref.load -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: arith.negf %{{.*}} : f64 -// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.bitcast %{{.*}} : i32 
to f32 +// CHECK: arith.negf %{{.*}} : f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// maxf: binary — pop two, bitcast, arith.maximumf, bitcast, store +// maxf: binary — pop two, trunci, bitcast, arith.maximumf f32, bitcast, extsi, store // CHECK: memref.load // CHECK: arith.subi // CHECK: memref.load -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: arith.maximumf %{{.*}}, %{{.*}} : f64 -// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.maximumf %{{.*}}, %{{.*}} : f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// minf: binary — pop two, bitcast, arith.minimumf, bitcast, store -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: arith.minimumf %{{.*}}, %{{.*}} : f64 -// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// minf: binary — trunci, bitcast, arith.minimumf f32, bitcast, extsi +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.minimumf %{{.*}}, %{{.*}} : f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(1.000000e+00 : f64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(1.000000e+00 : f32) : !forth.stack -> !forth.stack %2 = forth.expf %1 : !forth.stack -> !forth.stack %3 = forth.sqrtf %2 : !forth.stack -> !forth.stack %4 = forth.logf %3 : !forth.stack -> !forth.stack %5 = forth.absf %4 : !forth.stack -> !forth.stack %6 = forth.negf %5 : 
!forth.stack -> !forth.stack - %7 = forth.constant %6(2.000000e+00 : f64) : !forth.stack -> !forth.stack + %7 = forth.constant %6(2.000000e+00 : f32) : !forth.stack -> !forth.stack %8 = forth.maxf %7 : !forth.stack -> !forth.stack %9 = forth.minf %8 : !forth.stack -> !forth.stack return diff --git a/test/Conversion/ForthToMemRef/float-memory.mlir b/test/Conversion/ForthToMemRef/float-memory.mlir index 76c3462..6c267d3 100644 --- a/test/Conversion/ForthToMemRef/float-memory.mlir +++ b/test/Conversion/ForthToMemRef/float-memory.mlir @@ -2,27 +2,29 @@ // CHECK-LABEL: func.func private @main -// loadf: load addr, inttoptr, llvm.load f64, bitcast f64->i64, store +// loadf: load addr, inttoptr, llvm.load f32, bitcast f32->i32, extsi i32->i64, store // CHECK: memref.load // CHECK: llvm.inttoptr -// CHECK: llvm.load %{{.*}} : !llvm.ptr -> f64 -// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: llvm.load %{{.*}} : !llvm.ptr -> f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store -// storef: load addr, load value, inttoptr, bitcast i64->f64, llvm.store +// storef: load addr, load value, trunci i64->i32, inttoptr, bitcast i32->f32, llvm.store f32 // CHECK: memref.load // CHECK: memref.load +// CHECK: arith.trunci %{{.*}} : i64 to i32 // CHECK: llvm.inttoptr -// CHECK: arith.bitcast %{{.*}} : i64 to f64 -// CHECK: llvm.store %{{.*}}, %{{.*}} : f64 +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: llvm.store %{{.*}}, %{{.*}} : f32 module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(1000 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(1000 : i32) : !forth.stack -> !forth.stack %2 = forth.loadf %1 : !forth.stack -> !forth.stack - %3 = forth.constant %2(3.140000e+00 : f64) : !forth.stack -> !forth.stack - %4 = forth.constant %3(2000 : i64) : !forth.stack -> !forth.stack + %3 = forth.constant %2(3.140000e+00 : f32) : !forth.stack -> !forth.stack + %4 = 
forth.constant %3(2000 : i32) : !forth.stack -> !forth.stack %5 = forth.storef %4 : !forth.stack -> !forth.stack return } diff --git a/test/Conversion/ForthToMemRef/leave.mlir b/test/Conversion/ForthToMemRef/leave.mlir index d010ed0..cfc0ce4 100644 --- a/test/Conversion/ForthToMemRef/leave.mlir +++ b/test/Conversion/ForthToMemRef/leave.mlir @@ -14,8 +14,8 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(10 : i64) : !forth.stack -> !forth.stack - %2 = forth.constant %1(0 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(10 : i32) : !forth.stack -> !forth.stack + %2 = forth.constant %1(0 : i32) : !forth.stack -> !forth.stack %output_stack, %value = forth.pop %2 : !forth.stack -> !forth.stack, i64 %output_stack_0, %value_1 = forth.pop %output_stack : !forth.stack -> !forth.stack, i64 %alloca = memref.alloca() : memref<1xi64> diff --git a/test/Conversion/ForthToMemRef/literal.mlir b/test/Conversion/ForthToMemRef/literal.mlir index fce38e4..da320f4 100644 --- a/test/Conversion/ForthToMemRef/literal.mlir +++ b/test/Conversion/ForthToMemRef/literal.mlir @@ -3,7 +3,8 @@ // CHECK-LABEL: func.func private @main // CHECK: memref.alloca() : memref<256xi64> // CHECK: arith.constant 0 : index -// CHECK: arith.constant 42 : i64 +// CHECK: arith.constant 42 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: arith.constant 1 : index // CHECK: arith.addi %{{.*}}, %{{.*}} : index // CHECK: memref.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<256xi64> @@ -11,7 +12,7 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(42 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(42 : i32) : !forth.stack -> !forth.stack return } } diff --git a/test/Conversion/ForthToMemRef/memory-ops.mlir b/test/Conversion/ForthToMemRef/memory-ops.mlir index 81c8e00..8596a42 100644 --- a/test/Conversion/ForthToMemRef/memory-ops.mlir +++ b/test/Conversion/ForthToMemRef/memory-ops.mlir @@ -2,39 
+2,43 @@ // CHECK-LABEL: func.func private @main -// load (@): pop address, inttoptr, llvm.load, store back +// load (@): pop address, inttoptr, llvm.load i32, extsi i32->i64, store back // CHECK: memref.load %{{.*}}[%{{.*}}] : memref<256xi64> // CHECK: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr -// CHECK: llvm.load %{{.*}} : !llvm.ptr -> i64 +// CHECK: llvm.load %{{.*}} : !llvm.ptr -> i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<256xi64> -// store (!): pop address, pop value, inttoptr, llvm.store +// store (!): pop address, pop value, trunci i64->i32, inttoptr, llvm.store i32 // CHECK: memref.load // CHECK: arith.subi // CHECK: memref.load +// CHECK: arith.trunci %{{.*}} : i64 to i32 // CHECK: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr -// CHECK: llvm.store %{{.*}}, %{{.*}} : i64, !llvm.ptr +// CHECK: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr -// shared load (S@): pop address, inttoptr shared addrspace, llvm.load +// shared load (S@): pop address, inttoptr shared addrspace, llvm.load i32, extsi // CHECK: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<3> -// CHECK: llvm.load %{{.*}} : !llvm.ptr<3> -> i64 +// CHECK: llvm.load %{{.*}} : !llvm.ptr<3> -> i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 -// shared store (S!): pop address + value, inttoptr shared addrspace, llvm.store +// shared store (S!): pop address + value, trunci, inttoptr shared addrspace, llvm.store i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 // CHECK: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<3> -// CHECK: llvm.store %{{.*}}, %{{.*}} : i64, !llvm.ptr<3> +// CHECK: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr<3> module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(1 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(1 : i32) : !forth.stack -> !forth.stack %2 = forth.loadi %1 : !forth.stack -> !forth.stack - %3 = forth.constant %2(42 : i64) : !forth.stack -> !forth.stack - %4 = forth.constant 
%3(100 : i64) : !forth.stack -> !forth.stack + %3 = forth.constant %2(42 : i32) : !forth.stack -> !forth.stack + %4 = forth.constant %3(100 : i32) : !forth.stack -> !forth.stack %5 = forth.storei %4 : !forth.stack -> !forth.stack - %6 = forth.constant %5(2 : i64) : !forth.stack -> !forth.stack + %6 = forth.constant %5(2 : i32) : !forth.stack -> !forth.stack %7 = forth.shared_loadi %6 : !forth.stack -> !forth.stack - %8 = forth.constant %7(9 : i64) : !forth.stack -> !forth.stack - %9 = forth.constant %8(3 : i64) : !forth.stack -> !forth.stack + %8 = forth.constant %7(9 : i32) : !forth.stack -> !forth.stack + %9 = forth.constant %8(3 : i32) : !forth.stack -> !forth.stack %10 = forth.shared_storei %9 : !forth.stack -> !forth.stack return } diff --git a/test/Conversion/ForthToMemRef/narrow-memory.mlir b/test/Conversion/ForthToMemRef/narrow-memory.mlir new file mode 100644 index 0000000..f7d6f12 --- /dev/null +++ b/test/Conversion/ForthToMemRef/narrow-memory.mlir @@ -0,0 +1,154 @@ +// RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s + +// CHECK-LABEL: func.func private @main + +// HF@ (loadhf): load addr, inttoptr, llvm.load f16, extf f16->f32, bitcast f32->i32, extsi i32->i64, store +// CHECK: memref.load +// CHECK: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr +// CHECK: llvm.load %{{.*}} : !llvm.ptr -> f16 +// CHECK: arith.extf %{{.*}} : f16 to f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 +// CHECK: memref.store + +// HF! 
(storehf): pop addr, pop value, trunci i64->i32, inttoptr, bitcast i32->f32, truncf f32->f16, store +// CHECK: memref.load +// CHECK: arith.subi +// CHECK: memref.load +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.truncf %{{.*}} : f32 to f16 +// CHECK: llvm.store %{{.*}}, %{{.*}} : f16 + +// BF@ (loadbf): load addr, inttoptr, llvm.load bf16, extf bf16->f32, bitcast f32->i32, extsi i32->i64 +// CHECK: llvm.load %{{.*}} : !llvm.ptr -> bf16 +// CHECK: arith.extf %{{.*}} : bf16 to f32 +// CHECK: arith.bitcast %{{.*}} : f32 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 + +// BF! (storebf): trunci i64->i32, bitcast i32->f32, truncf f32->bf16, store +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: llvm.inttoptr +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.truncf %{{.*}} : f32 to bf16 +// CHECK: llvm.store %{{.*}}, %{{.*}} : bf16 + +// I8@ (loadi8): load addr, inttoptr, llvm.load i8, extsi i8->i32, extsi i32->i64 +// CHECK: llvm.load %{{.*}} : !llvm.ptr -> i8 +// CHECK: arith.extsi %{{.*}} : i8 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 + +// I8! (storei8): trunci i64->i32, trunci i32->i8, store +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: llvm.inttoptr +// CHECK: arith.trunci %{{.*}} : i32 to i8 +// CHECK: llvm.store %{{.*}}, %{{.*}} : i8 + +// I16@ (loadi16): llvm.load i16, extsi i16->i32, extsi i32->i64 +// CHECK: llvm.load %{{.*}} : !llvm.ptr -> i16 +// CHECK: arith.extsi %{{.*}} : i16 to i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 + +// I16! 
(storei16): trunci i64->i32, trunci i32->i16, store +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: llvm.inttoptr +// CHECK: arith.trunci %{{.*}} : i32 to i16 +// CHECK: llvm.store %{{.*}}, %{{.*}} : i16 + +// SHF@ (shared_loadhf): inttoptr to shared ptr, llvm.load f16 +// CHECK: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<3> +// CHECK: llvm.load %{{.*}} : !llvm.ptr<3> -> f16 +// CHECK: arith.extf %{{.*}} : f16 to f32 + +// SHF! (shared_storehf): truncf f32->f16, store to shared +// CHECK: arith.bitcast %{{.*}} : i32 to f32 +// CHECK: arith.truncf %{{.*}} : f32 to f16 +// CHECK: llvm.store %{{.*}}, %{{.*}} : f16, !llvm.ptr<3> + +// SBF@ (shared_loadbf): inttoptr shared, llvm.load bf16 +// CHECK: llvm.load %{{.*}} : !llvm.ptr<3> -> bf16 +// CHECK: arith.extf %{{.*}} : bf16 to f32 + +// SBF! (shared_storebf): truncf f32->bf16, store to shared +// CHECK: arith.truncf %{{.*}} : f32 to bf16 +// CHECK: llvm.store %{{.*}}, %{{.*}} : bf16, !llvm.ptr<3> + +// SI8@ (shared_loadi8): llvm.load i8 from shared, extsi i8->i32 +// CHECK: llvm.load %{{.*}} : !llvm.ptr<3> -> i8 +// CHECK: arith.extsi %{{.*}} : i8 to i32 + +// SI8! (shared_storei8): trunci i32->i8, store to shared +// CHECK: arith.trunci %{{.*}} : i32 to i8 +// CHECK: llvm.store %{{.*}}, %{{.*}} : i8, !llvm.ptr<3> + +// SI16@ (shared_loadi16): llvm.load i16 from shared, extsi +// CHECK: llvm.load %{{.*}} : !llvm.ptr<3> -> i16 +// CHECK: arith.extsi %{{.*}} : i16 to i32 + +// SI16! (shared_storei16): trunci i32->i16, store to shared +// CHECK: arith.trunci %{{.*}} : i32 to i16 +// CHECK: llvm.store %{{.*}}, %{{.*}} : i16, !llvm.ptr<3> + +module { + func.func private @main() { + %0 = forth.stack !forth.stack + // HF@ test: push address, load f16 + %1 = forth.constant %0(1000 : i32) : !forth.stack -> !forth.stack + %2 = forth.loadhf %1 : !forth.stack -> !forth.stack + // HF! 
test: push value, push address, store f16 + %3 = forth.constant %2(3.140000e+00 : f32) : !forth.stack -> !forth.stack + %4 = forth.constant %3(2000 : i32) : !forth.stack -> !forth.stack + %5 = forth.storehf %4 : !forth.stack -> !forth.stack + // BF@ test + %6 = forth.constant %5(3000 : i32) : !forth.stack -> !forth.stack + %7 = forth.loadbf %6 : !forth.stack -> !forth.stack + // BF! test + %8 = forth.constant %7(2.710000e+00 : f32) : !forth.stack -> !forth.stack + %9 = forth.constant %8(4000 : i32) : !forth.stack -> !forth.stack + %10 = forth.storebf %9 : !forth.stack -> !forth.stack + // I8@ test + %11 = forth.constant %10(5000 : i32) : !forth.stack -> !forth.stack + %12 = forth.loadi8 %11 : !forth.stack -> !forth.stack + // I8! test + %13 = forth.constant %12(42 : i32) : !forth.stack -> !forth.stack + %14 = forth.constant %13(6000 : i32) : !forth.stack -> !forth.stack + %15 = forth.storei8 %14 : !forth.stack -> !forth.stack + // I16@ test + %16 = forth.constant %15(7000 : i32) : !forth.stack -> !forth.stack + %17 = forth.loadi16 %16 : !forth.stack -> !forth.stack + // I16! test + %18 = forth.constant %17(999 : i32) : !forth.stack -> !forth.stack + %19 = forth.constant %18(8000 : i32) : !forth.stack -> !forth.stack + %20 = forth.storei16 %19 : !forth.stack -> !forth.stack + // SHF@ test (shared) + %21 = forth.constant %20(100 : i32) : !forth.stack -> !forth.stack + %22 = forth.shared_loadhf %21 : !forth.stack -> !forth.stack + // SHF! test (shared) + %23 = forth.constant %22(1.500000e+00 : f32) : !forth.stack -> !forth.stack + %24 = forth.constant %23(200 : i32) : !forth.stack -> !forth.stack + %25 = forth.shared_storehf %24 : !forth.stack -> !forth.stack + // SBF@ test (shared) + %26 = forth.constant %25(300 : i32) : !forth.stack -> !forth.stack + %27 = forth.shared_loadbf %26 : !forth.stack -> !forth.stack + // SBF! 
test (shared) + %28 = forth.constant %27(2.500000e+00 : f32) : !forth.stack -> !forth.stack + %29 = forth.constant %28(400 : i32) : !forth.stack -> !forth.stack + %30 = forth.shared_storebf %29 : !forth.stack -> !forth.stack + // SI8@ test (shared) + %31 = forth.constant %30(500 : i32) : !forth.stack -> !forth.stack + %32 = forth.shared_loadi8 %31 : !forth.stack -> !forth.stack + // SI8! test (shared) + %33 = forth.constant %32(7 : i32) : !forth.stack -> !forth.stack + %34 = forth.constant %33(600 : i32) : !forth.stack -> !forth.stack + %35 = forth.shared_storei8 %34 : !forth.stack -> !forth.stack + // SI16@ test (shared) + %36 = forth.constant %35(700 : i32) : !forth.stack -> !forth.stack + %37 = forth.shared_loadi16 %36 : !forth.stack -> !forth.stack + // SI16! test (shared) + %38 = forth.constant %37(123 : i32) : !forth.stack -> !forth.stack + %39 = forth.constant %38(800 : i32) : !forth.stack -> !forth.stack + %40 = forth.shared_storei16 %39 : !forth.stack -> !forth.stack + return + } +} diff --git a/test/Conversion/ForthToMemRef/nested-control-flow.mlir b/test/Conversion/ForthToMemRef/nested-control-flow.mlir index 8c09f33..ecffff4 100644 --- a/test/Conversion/ForthToMemRef/nested-control-flow.mlir +++ b/test/Conversion/ForthToMemRef/nested-control-flow.mlir @@ -2,14 +2,16 @@ // === Nested IF: 1 IF 2 IF 3 THEN THEN === // CHECK-LABEL: func.func private @TEST__NESTED__IF -// CHECK: arith.constant 1 : i64 +// CHECK: arith.constant 1 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store // CHECK: arith.cmpi ne // CHECK: cf.cond_br %{{.*}}, ^bb1({{.*}}), ^bb2({{.*}}) // Inner IF: push 2, pop_flag, cond_br // CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index): -// CHECK: arith.constant 2 : i64 +// CHECK: arith.constant 2 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store // CHECK: arith.cmpi ne // CHECK: cf.cond_br %{{.*}}, ^bb3({{.*}}), ^bb4({{.*}}) @@ -20,7 +22,8 @@ // Inner true: push 3 // CHECK: ^bb3(%{{.*}}: 
memref<256xi64>, %{{.*}}: index): -// CHECK: arith.constant 3 : i64 +// CHECK: arith.constant 3 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store // CHECK: cf.br ^bb4 @@ -32,8 +35,10 @@ // CHECK-LABEL: func.func private @TEST__IF__INSIDE__DO // DO loop setup: pop start/limit, alloca counter -// CHECK: arith.constant 10 : i64 -// CHECK: arith.constant 0 : i64 +// CHECK: arith.constant 10 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 +// CHECK: arith.constant 0 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: %[[COUNTER1:.*]] = memref.alloca() : memref<1xi64> // CHECK: memref.store %{{.*}}, %[[COUNTER1]] // CHECK: cf.br ^bb1 @@ -48,7 +53,10 @@ // CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index): // CHECK: memref.load %[[COUNTER1]] // CHECK: memref.store -// CHECK: arith.constant 5 : i64 +// CHECK: arith.constant 5 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 // CHECK: arith.cmpi sgt // CHECK: arith.cmpi ne // CHECK: cf.cond_br @@ -74,8 +82,10 @@ // CHECK-LABEL: func.func private @TEST__NESTED__DO__J // Outer DO setup -// CHECK: arith.constant 3 : i64 -// CHECK: arith.constant 0 : i64 +// CHECK: arith.constant 3 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 +// CHECK: arith.constant 0 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: %[[OUTER:.*]] = memref.alloca() : memref<1xi64> // CHECK: memref.store %{{.*}}, %[[OUTER]] // CHECK: cf.br ^bb1 @@ -88,8 +98,10 @@ // Outer loop body: inner DO setup (4 0 DO) // CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index): -// CHECK: arith.constant 4 : i64 -// CHECK: arith.constant 0 : i64 +// CHECK: arith.constant 4 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 +// CHECK: arith.constant 0 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: %[[INNER:.*]] = memref.alloca() : memref<1xi64> // CHECK: memref.store %{{.*}}, %[[INNER]] // CHECK: cf.br ^bb4 @@ -110,7 +122,10 
@@ // CHECK: memref.store // CHECK: memref.load %[[INNER]] // CHECK: memref.store -// CHECK: arith.addi +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.addi %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store // CHECK: memref.load %[[INNER]] // CHECK: arith.addi @@ -128,7 +143,8 @@ // CHECK-LABEL: func.func private @TEST__WHILE__INSIDE__IF // Push 5, pop_flag, cond_br (IF) -// CHECK: arith.constant 5 : i64 +// CHECK: arith.constant 5 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store // CHECK: arith.cmpi ne // CHECK: cf.cond_br %{{.*}}, ^bb1({{.*}}), ^bb2({{.*}}) @@ -150,8 +166,12 @@ // WHILE body: push 1, subtract, loop back // CHECK: ^bb4(%{{.*}}: memref<256xi64>, %{{.*}}: index): -// CHECK: arith.constant 1 : i64 -// CHECK: arith.subi +// CHECK: arith.constant 1 : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.subi %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: cf.br ^bb3 // WHILE exit -> merge with IF false @@ -160,24 +180,24 @@ module { func.func private @TEST__NESTED__IF(%arg0: !forth.stack) -> !forth.stack { - %0 = forth.constant %arg0(1 : i64) : !forth.stack -> !forth.stack + %0 = forth.constant %arg0(1 : i32) : !forth.stack -> !forth.stack %output_stack, %flag = forth.pop_flag %0 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag, ^bb1(%output_stack : !forth.stack), ^bb2(%output_stack : !forth.stack) ^bb1(%1: !forth.stack): - %2 = forth.constant %1(2 : i64) : !forth.stack -> !forth.stack + %2 = forth.constant %1(2 : i32) : !forth.stack -> !forth.stack %output_stack_0, %flag_1 = forth.pop_flag %2 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag_1, ^bb3(%output_stack_0 : !forth.stack), ^bb4(%output_stack_0 : !forth.stack) ^bb2(%3: !forth.stack): return %3 : !forth.stack ^bb3(%4: !forth.stack): - %5 = 
forth.constant %4(3 : i64) : !forth.stack -> !forth.stack + %5 = forth.constant %4(3 : i32) : !forth.stack -> !forth.stack cf.br ^bb4(%5 : !forth.stack) ^bb4(%6: !forth.stack): cf.br ^bb2(%6 : !forth.stack) } func.func private @TEST__IF__INSIDE__DO(%arg0: !forth.stack) -> !forth.stack { - %0 = forth.constant %arg0(10 : i64) : !forth.stack -> !forth.stack - %1 = forth.constant %0(0 : i64) : !forth.stack -> !forth.stack + %0 = forth.constant %arg0(10 : i32) : !forth.stack -> !forth.stack + %1 = forth.constant %0(0 : i32) : !forth.stack -> !forth.stack %output_stack, %value = forth.pop %1 : !forth.stack -> !forth.stack, i64 %output_stack_0, %value_1 = forth.pop %output_stack : !forth.stack -> !forth.stack, i64 %alloca = memref.alloca() : memref<1xi64> @@ -193,7 +213,7 @@ module { %c0_3 = arith.constant 0 : index %6 = memref.load %alloca[%c0_3] : memref<1xi64> %7 = forth.push_value %5, %6 : !forth.stack, i64 -> !forth.stack - %8 = forth.constant %7(5 : i64) : !forth.stack -> !forth.stack + %8 = forth.constant %7(5 : i32) : !forth.stack -> !forth.stack %9 = forth.gti %8 : !forth.stack -> !forth.stack %output_stack_4, %flag = forth.pop_flag %9 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag, ^bb4(%output_stack_4 : !forth.stack), ^bb5(%output_stack_4 : !forth.stack) @@ -213,8 +233,8 @@ module { cf.br ^bb1(%14 : !forth.stack) } func.func private @TEST__NESTED__DO__J(%arg0: !forth.stack) -> !forth.stack { - %0 = forth.constant %arg0(3 : i64) : !forth.stack -> !forth.stack - %1 = forth.constant %0(0 : i64) : !forth.stack -> !forth.stack + %0 = forth.constant %arg0(3 : i32) : !forth.stack -> !forth.stack + %1 = forth.constant %0(0 : i32) : !forth.stack -> !forth.stack %output_stack, %value = forth.pop %1 : !forth.stack -> !forth.stack, i64 %output_stack_0, %value_1 = forth.pop %output_stack : !forth.stack -> !forth.stack, i64 %alloca = memref.alloca() : memref<1xi64> @@ -227,8 +247,8 @@ module { %4 = arith.cmpi slt, %3, %value_1 : i64 cf.cond_br %4, ^bb2(%2 : 
!forth.stack), ^bb3(%2 : !forth.stack) ^bb2(%5: !forth.stack): - %6 = forth.constant %5(4 : i64) : !forth.stack -> !forth.stack - %7 = forth.constant %6(0 : i64) : !forth.stack -> !forth.stack + %6 = forth.constant %5(4 : i32) : !forth.stack -> !forth.stack + %7 = forth.constant %6(0 : i32) : !forth.stack -> !forth.stack %output_stack_3, %value_4 = forth.pop %7 : !forth.stack -> !forth.stack, i64 %output_stack_5, %value_6 = forth.pop %output_stack_3 : !forth.stack -> !forth.stack, i64 %alloca_7 = memref.alloca() : memref<1xi64> @@ -265,7 +285,7 @@ module { cf.br ^bb1(%20 : !forth.stack) } func.func private @TEST__WHILE__INSIDE__IF(%arg0: !forth.stack) -> !forth.stack { - %0 = forth.constant %arg0(5 : i64) : !forth.stack -> !forth.stack + %0 = forth.constant %arg0(5 : i32) : !forth.stack -> !forth.stack %output_stack, %flag = forth.pop_flag %0 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag, ^bb1(%output_stack : !forth.stack), ^bb2(%output_stack : !forth.stack) ^bb1(%1: !forth.stack): @@ -277,7 +297,7 @@ module { %output_stack_0, %flag_1 = forth.pop_flag %4 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag_1, ^bb4(%output_stack_0 : !forth.stack), ^bb5(%output_stack_0 : !forth.stack) ^bb4(%5: !forth.stack): - %6 = forth.constant %5(1 : i64) : !forth.stack -> !forth.stack + %6 = forth.constant %5(1 : i32) : !forth.stack -> !forth.stack %7 = forth.subi %6 : !forth.stack -> !forth.stack cf.br ^bb3(%7 : !forth.stack) ^bb5(%8: !forth.stack): diff --git a/test/Conversion/ForthToMemRef/param-ref.mlir b/test/Conversion/ForthToMemRef/param-ref.mlir index 5e357a5..57575a7 100644 --- a/test/Conversion/ForthToMemRef/param-ref.mlir +++ b/test/Conversion/ForthToMemRef/param-ref.mlir @@ -1,13 +1,13 @@ // RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s -// CHECK-LABEL: func.func private @main(%{{.*}}: memref<256xi64> {forth.param_name = "data"}) +// CHECK-LABEL: func.func private @main(%{{.*}}: memref<256xi32> {forth.param_name = "data"}) // CHECK: 
memref.alloca() : memref<256xi64> -// CHECK: memref.extract_aligned_pointer_as_index %{{.*}} : memref<256xi64> -> index +// CHECK: memref.extract_aligned_pointer_as_index %{{.*}} : memref<256xi32> -> index // CHECK: arith.index_cast %{{.*}} : index to i64 // CHECK: memref.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<256xi64> module { - func.func private @main(%arg0: memref<256xi64> {forth.param_name = "data"}) { + func.func private @main(%arg0: memref<256xi32> {forth.param_name = "data"}) { %0 = forth.stack !forth.stack %1 = forth.param_ref %0 "data" : !forth.stack -> !forth.stack return diff --git a/test/Conversion/ForthToMemRef/stack-manipulation.mlir b/test/Conversion/ForthToMemRef/stack-manipulation.mlir index 4da0597..4028d00 100644 --- a/test/Conversion/ForthToMemRef/stack-manipulation.mlir +++ b/test/Conversion/ForthToMemRef/stack-manipulation.mlir @@ -76,9 +76,9 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(1 : i64) : !forth.stack -> !forth.stack - %2 = forth.constant %1(2 : i64) : !forth.stack -> !forth.stack - %3 = forth.constant %2(3 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(1 : i32) : !forth.stack -> !forth.stack + %2 = forth.constant %1(2 : i32) : !forth.stack -> !forth.stack + %3 = forth.constant %2(3 : i32) : !forth.stack -> !forth.stack %4 = forth.dup %3 : !forth.stack -> !forth.stack %5 = forth.drop %4 : !forth.stack -> !forth.stack %6 = forth.swap %5 : !forth.stack -> !forth.stack @@ -86,9 +86,9 @@ module { %8 = forth.rot %7 : !forth.stack -> !forth.stack %9 = forth.nip %8 : !forth.stack -> !forth.stack %10 = forth.tuck %9 : !forth.stack -> !forth.stack - %11 = forth.constant %10(2 : i64) : !forth.stack -> !forth.stack + %11 = forth.constant %10(2 : i32) : !forth.stack -> !forth.stack %12 = forth.pick %11 : !forth.stack -> !forth.stack - %13 = forth.constant %12(2 : i64) : !forth.stack -> !forth.stack + %13 = forth.constant %12(2 : i32) : !forth.stack -> !forth.stack %14 = 
forth.roll %13 : !forth.stack -> !forth.stack return } diff --git a/test/Conversion/ForthToMemRef/user-defined-words.mlir b/test/Conversion/ForthToMemRef/user-defined-words.mlir index 6634934..96a2216 100644 --- a/test/Conversion/ForthToMemRef/user-defined-words.mlir +++ b/test/Conversion/ForthToMemRef/user-defined-words.mlir @@ -4,7 +4,10 @@ // (memref<256xi64>, index) -> (memref<256xi64>, index) // CHECK-LABEL: func.func private @double(%{{.*}}: memref<256xi64>, %{{.*}}: index) -> (memref<256xi64>, index) // CHECK: memref.load -// CHECK: arith.addi +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.trunci %{{.*}} : i64 to i32 +// CHECK: arith.addi %{{.*}}, %{{.*}} : i32 +// CHECK: arith.extsi %{{.*}} : i32 to i64 // CHECK: memref.store // CHECK: return %{{.*}}, %{{.*}} : memref<256xi64>, index @@ -19,7 +22,7 @@ module { } func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.constant %0(5 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(5 : i32) : !forth.stack -> !forth.stack %2 = call @double(%1) : (!forth.stack) -> !forth.stack return } diff --git a/test/Pipeline/attention.forth b/test/Pipeline/attention.forth index 094df16..462544a 100644 --- a/test/Pipeline/attention.forth +++ b/test/Pipeline/attention.forth @@ -5,14 +5,14 @@ \ CHECK: gpu.binary @warpforth_module \! kernel attention -\! param Q f64[16] -\! param K f64[16] -\! param V f64[16] -\! param O f64[16] -\! param SEQ_LEN i64 -\! param HEAD_DIM i64 -\! shared SCORES f64[4] -\! shared SCRATCH f64[4] +\! param Q f32[16] +\! param K f32[16] +\! param V f32[16] +\! param O f32[16] +\! param SEQ_LEN i32 +\! param HEAD_DIM i32 +\! shared SCORES f32[4] +\! shared SCRATCH f32[4] \ row = BID-X, t = TID-X BID-X diff --git a/test/Pipeline/barrier.forth b/test/Pipeline/barrier.forth index ef1c2b8..19edb8f 100644 --- a/test/Pipeline/barrier.forth +++ b/test/Pipeline/barrier.forth @@ -5,6 +5,6 @@ \ MID: gpu.barrier \ MID: gpu.return \! kernel main -\! param DATA i64[256] -\! 
shared SCRATCH i64[256] +\! param DATA i32[256] +\! shared SCRATCH i32[256] GLOBAL-ID CELLS SCRATCH + @ BARRIER GLOBAL-ID CELLS DATA + ! diff --git a/test/Pipeline/begin-until.forth b/test/Pipeline/begin-until.forth index 189b487..44a3c5c 100644 --- a/test/Pipeline/begin-until.forth +++ b/test/Pipeline/begin-until.forth @@ -6,11 +6,11 @@ \ Verify intermediate MLIR: gpu.func with loop back-edge and conditional branch \ MID: gpu.module @warpforth_module -\ MID: gpu.func @main(%arg0: memref<4xi64> {forth.param_name = "DATA"}) kernel +\ MID: gpu.func @main(%arg0: memref<4xi32> {forth.param_name = "DATA"}) kernel \ MID: cf.br \ MID: cf.cond_br \ MID: gpu.return \! kernel main -\! param DATA i64[4] +\! param DATA i32[4] 10 BEGIN 1 - DUP 0= UNTIL DATA 0 CELLS + ! diff --git a/test/Pipeline/begin-while-repeat.forth b/test/Pipeline/begin-while-repeat.forth index b2998fe..1fddc71 100644 --- a/test/Pipeline/begin-while-repeat.forth +++ b/test/Pipeline/begin-while-repeat.forth @@ -6,11 +6,11 @@ \ Verify intermediate MLIR: gpu.func with conditional branch \ MID: gpu.module @warpforth_module -\ MID: gpu.func @main(%arg0: memref<4xi64> {forth.param_name = "DATA"}) kernel +\ MID: gpu.func @main(%arg0: memref<4xi32> {forth.param_name = "DATA"}) kernel \ MID: cf.br \ MID: cf.cond_br \ MID: gpu.return \! kernel main -\! param DATA i64[4] +\! param DATA i32[4] 10 BEGIN DUP 0 > WHILE 1 - REPEAT DATA 0 CELLS + ! diff --git a/test/Pipeline/control-flow.forth b/test/Pipeline/control-flow.forth index 6fac4d8..db5e7ab 100644 --- a/test/Pipeline/control-flow.forth +++ b/test/Pipeline/control-flow.forth @@ -6,12 +6,12 @@ \ Verify intermediate MLIR: gpu.func with conditional branching \ MID: gpu.module @warpforth_module -\ MID: gpu.func @main(%arg0: memref<256xi64> {forth.param_name = "DATA"}) kernel +\ MID: gpu.func @main(%arg0: memref<256xi32> {forth.param_name = "DATA"}) kernel \ MID: memref.load \ MID: arith.cmpi ne \ MID: cf.cond_br \ MID: gpu.return \! kernel main -\! 
param DATA i64[256] +\! param DATA i32[256] DATA @ 5 > IF DATA @ 1 + DATA ! THEN diff --git a/test/Pipeline/do-loop.forth b/test/Pipeline/do-loop.forth index 57fb4b8..7d58eae 100644 --- a/test/Pipeline/do-loop.forth +++ b/test/Pipeline/do-loop.forth @@ -6,11 +6,11 @@ \ Verify intermediate MLIR: gpu.func with loop structure \ MID: gpu.module @warpforth_module -\ MID: gpu.func @main(%arg0: memref<4xi64> {forth.param_name = "DATA"}) kernel +\ MID: gpu.func @main(%arg0: memref<4xi32> {forth.param_name = "DATA"}) kernel \ MID: cf.br \ MID: cf.cond_br \ MID: gpu.return \! kernel main -\! param DATA i64[4] +\! param DATA i32[4] 10 0 DO I LOOP DATA 0 CELLS + ! diff --git a/test/Pipeline/exit.forth b/test/Pipeline/exit.forth index 6dfa82b..5898976 100644 --- a/test/Pipeline/exit.forth +++ b/test/Pipeline/exit.forth @@ -2,6 +2,6 @@ \ CHECK: gpu.binary @warpforth_module \! kernel main -\! param DATA i64[4] +\! param DATA i32[4] : DO-EXIT 1 IF EXIT THEN 42 ; DO-EXIT DATA 0 CELLS + ! diff --git a/test/Pipeline/float-math-intrinsics.forth b/test/Pipeline/float-math-intrinsics.forth index 15d4e7f..53c581e 100644 --- a/test/Pipeline/float-math-intrinsics.forth +++ b/test/Pipeline/float-math-intrinsics.forth @@ -4,7 +4,7 @@ \ CHECK: gpu.binary @warpforth_module \! kernel main -\! param data f64[256] +\! 
param data f32[256] GLOBAL-ID CELLS data + F@ FABS FEXP FSQRT FLOG FNEG GLOBAL-ID CELLS data + F@ diff --git a/test/Pipeline/float-pipeline.forth b/test/Pipeline/float-pipeline.forth index 2036b97..907b49b 100644 --- a/test/Pipeline/float-pipeline.forth +++ b/test/Pipeline/float-pipeline.forth @@ -1,20 +1,20 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s \ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --convert-forth-to-memref --convert-forth-to-gpu | %FileCheck %s --check-prefix=MID -\ Verify that Forth with f64 params through the full pipeline produces a gpu.binary +\ Verify that Forth with f32 params through the full pipeline produces a gpu.binary \ CHECK: gpu.binary @warpforth_module \ Verify intermediate MLIR structure at the memref+gpu stage \ MID: gpu.module @warpforth_module -\ MID: gpu.func @main(%arg0: memref<256xf64> {forth.param_name = "DATA"}, %arg1: f64 {forth.param_name = "SCALE"}) kernel +\ MID: gpu.func @main(%arg0: memref<256xf32> {forth.param_name = "DATA"}, %arg1: f32 {forth.param_name = "SCALE"}) kernel \ MID: memref.alloca() : memref<256xi64> \ MID: memref.extract_aligned_pointer_as_index %arg0 -\ MID: arith.bitcast %{{.*}} : f64 to i64 +\ MID: arith.bitcast %{{.*}} : f32 to i32 \ MID: gpu.return \! kernel main -\! param DATA f64[256] -\! param SCALE f64 +\! param DATA f32[256] +\! param SCALE f32 GLOBAL-ID CELLS DATA + F@ SCALE F* GLOBAL-ID CELLS DATA + F! 
diff --git a/test/Pipeline/full-pipeline.forth b/test/Pipeline/full-pipeline.forth index c2c77fb..a41066c 100644 --- a/test/Pipeline/full-pipeline.forth +++ b/test/Pipeline/full-pipeline.forth @@ -6,7 +6,7 @@ \ Verify intermediate MLIR structure at the memref+gpu stage \ MID: gpu.module @warpforth_module -\ MID: gpu.func @main(%arg0: memref<256xi64> {forth.param_name = "DATA"}) kernel +\ MID: gpu.func @main(%arg0: memref<256xi32> {forth.param_name = "DATA"}) kernel \ MID: memref.alloca() : memref<256xi64> \ MID: gpu.thread_id x \ MID: memref.extract_aligned_pointer_as_index %arg0 @@ -15,7 +15,7 @@ \ MID: gpu.return \! kernel main -\! param DATA i64[256] +\! param DATA i32[256] GLOBAL-ID CELLS DATA + @ 1 + GLOBAL-ID CELLS DATA + ! diff --git a/test/Pipeline/interleaved-control-flow.forth b/test/Pipeline/interleaved-control-flow.forth index 2b5573d..de3adcd 100644 --- a/test/Pipeline/interleaved-control-flow.forth +++ b/test/Pipeline/interleaved-control-flow.forth @@ -7,7 +7,7 @@ \ Verify intermediate MLIR: gpu.func with cf branches, no scf ops \ MID: gpu.module @warpforth_module -\ MID: gpu.func @main(%arg0: memref<4xi64> {forth.param_name = "DATA"}) kernel +\ MID: gpu.func @main(%arg0: memref<4xi32> {forth.param_name = "DATA"}) kernel \ MID: gpu.return \ Multi-WHILE: two cond_br exits + one unconditional back-edge @@ -23,7 +23,7 @@ \ MID: cf.br \! kernel main -\! param DATA i64[4] +\! 
param DATA i32[4] : multi-while BEGIN DUP 10 > WHILE DUP 2 MOD 0= WHILE 1 - REPEAT DROP THEN ; : while-until diff --git a/test/Pipeline/leave.forth b/test/Pipeline/leave.forth index 3fbe231..74c4e11 100644 --- a/test/Pipeline/leave.forth +++ b/test/Pipeline/leave.forth @@ -6,11 +6,11 @@ \ Verify intermediate MLIR: gpu.func with CF control flow \ MID: gpu.module @warpforth_module -\ MID: gpu.func @main(%arg0: memref<4xi64> {forth.param_name = "DATA"}) kernel +\ MID: gpu.func @main(%arg0: memref<4xi32> {forth.param_name = "DATA"}) kernel \ MID: cf.br \ MID: cf.cond_br \ MID: gpu.return \! kernel main -\! param DATA i64[4] +\! param DATA i32[4] 10 0 DO LEAVE LOOP DATA 0 CELLS + ! diff --git a/test/Pipeline/local-variables.forth b/test/Pipeline/local-variables.forth index 4e907c2..03c46de 100644 --- a/test/Pipeline/local-variables.forth +++ b/test/Pipeline/local-variables.forth @@ -4,7 +4,7 @@ \ CHECK: gpu.binary @warpforth_module \! kernel main -\! param DATA i64[256] +\! param DATA i32[256] : ADD3 { a b c -- } a b + c + ; 1 2 3 ADD3 GLOBAL-ID CELLS DATA + ! diff --git a/test/Pipeline/matmul-naive.forth b/test/Pipeline/matmul-naive.forth index 4cae506..2996c28 100644 --- a/test/Pipeline/matmul-naive.forth +++ b/test/Pipeline/matmul-naive.forth @@ -5,12 +5,12 @@ \ CHECK: gpu.binary @warpforth_module \ Verify the kernel signature at the memref+gpu stage. -\ MID: gpu.func @main(%arg0: memref<8xi64> {forth.param_name = "A"}, %arg1: memref<12xi64> {forth.param_name = "B"}, %arg2: memref<6xi64> {forth.param_name = "C"}) kernel +\ MID: gpu.func @main(%arg0: memref<8xi32> {forth.param_name = "A"}, %arg1: memref<12xi32> {forth.param_name = "B"}, %arg2: memref<6xi32> {forth.param_name = "C"}) kernel \! kernel main -\! param A i64[8] -\! param B i64[12] -\! param C i64[6] +\! param A i32[8] +\! param B i32[12] +\! param C i32[6] \ M=2, N=3, K=4. One thread computes C[row, col] where gid = row*N + col. 
GLOBAL-ID diff --git a/test/Pipeline/multi-param.forth b/test/Pipeline/multi-param.forth index 3cd9b43..56f125e 100644 --- a/test/Pipeline/multi-param.forth +++ b/test/Pipeline/multi-param.forth @@ -5,7 +5,7 @@ \ CHECK: gpu.binary @warpforth_module \ Verify both params appear in kernel signature and are used correctly -\ MID: gpu.func @main(%arg0: memref<256xi64> {forth.param_name = "INPUT"}, %arg1: memref<256xi64> {forth.param_name = "OUTPUT"}) kernel +\ MID: gpu.func @main(%arg0: memref<256xi32> {forth.param_name = "INPUT"}, %arg1: memref<256xi32> {forth.param_name = "OUTPUT"}) kernel \ MID: memref.extract_aligned_pointer_as_index %arg0 \ MID: llvm.load \ MID: memref.extract_aligned_pointer_as_index %arg1 @@ -13,8 +13,8 @@ \ MID: gpu.return \! kernel main -\! param INPUT i64[256] -\! param OUTPUT i64[256] +\! param INPUT i32[256] +\! param OUTPUT i32[256] GLOBAL-ID CELLS INPUT + @ 2 * GLOBAL-ID CELLS OUTPUT + ! diff --git a/test/Pipeline/narrow-memory.forth b/test/Pipeline/narrow-memory.forth new file mode 100644 index 0000000..5a7ae02 --- /dev/null +++ b/test/Pipeline/narrow-memory.forth @@ -0,0 +1,14 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s + +\ Verify narrow memory operations through full pipeline produce a gpu.binary +\ CHECK: gpu.binary @warpforth_module +\! kernel main +\! param DATA i32[256] +GLOBAL-ID CELLS DATA + HF@ +GLOBAL-ID CELLS DATA + HF! +GLOBAL-ID CELLS DATA + BF@ +GLOBAL-ID CELLS DATA + BF! +GLOBAL-ID CELLS DATA + I8@ +GLOBAL-ID CELLS DATA + I8! +GLOBAL-ID CELLS DATA + I16@ +GLOBAL-ID CELLS DATA + I16! 
diff --git a/test/Pipeline/nested-control-flow.forth b/test/Pipeline/nested-control-flow.forth index 38066f0..8ad2881 100644 --- a/test/Pipeline/nested-control-flow.forth +++ b/test/Pipeline/nested-control-flow.forth @@ -6,10 +6,10 @@ \ Verify intermediate MLIR: gpu.func with nested loop structure \ MID: gpu.module @warpforth_module -\ MID: gpu.func @main(%arg0: memref<4xi64> {forth.param_name = "DATA"}) kernel +\ MID: gpu.func @main(%arg0: memref<4xi32> {forth.param_name = "DATA"}) kernel \ MID: cf.br \ MID: arith.xori \! kernel main -\! param DATA i64[4] +\! param DATA i32[4] 3 0 DO 4 0 DO J I + LOOP LOOP DATA 0 CELLS + ! diff --git a/test/Pipeline/plus-loop-negative.forth b/test/Pipeline/plus-loop-negative.forth index 4070e47..bc95839 100644 --- a/test/Pipeline/plus-loop-negative.forth +++ b/test/Pipeline/plus-loop-negative.forth @@ -4,5 +4,5 @@ \ CHECK: gpu.binary @warpforth_module \! kernel main -\! param DATA i64[4] +\! param DATA i32[4] 0 10 DO I DATA 0 CELLS + ! -1 +LOOP diff --git a/test/Pipeline/plus-loop.forth b/test/Pipeline/plus-loop.forth index 87e0f9c..6c0a28b 100644 --- a/test/Pipeline/plus-loop.forth +++ b/test/Pipeline/plus-loop.forth @@ -4,5 +4,5 @@ \ CHECK: gpu.binary @warpforth_module \! kernel main -\! param DATA i64[4] +\! param DATA i32[4] 10 0 DO I DATA 0 CELLS + ! 
2 +LOOP diff --git a/test/Pipeline/scalar-param.forth b/test/Pipeline/scalar-param.forth index 00771b0..9935826 100644 --- a/test/Pipeline/scalar-param.forth +++ b/test/Pipeline/scalar-param.forth @@ -4,16 +4,16 @@ \ Verify mixed scalar + array params survive the full pipeline \ CHECK: gpu.binary @warpforth_module -\ Verify scalar becomes i64 arg, array becomes memref +\ Verify scalar becomes i32 arg, array becomes memref \ MID: gpu.func @main( -\ MID-SAME: i64 {forth.param_name = "SCALE"} -\ MID-SAME: memref<256xi64> {forth.param_name = "DATA"} +\ MID-SAME: i32 {forth.param_name = "SCALE"} +\ MID-SAME: memref<256xi32> {forth.param_name = "DATA"} \ MID-SAME: kernel \ MID: gpu.return \! kernel main -\! param SCALE i64 -\! param DATA i64[256] +\! param SCALE i32 +\! param DATA i32[256] GLOBAL-ID DUP CELLS DATA + @ SCALE * diff --git a/test/Pipeline/shared-memory.forth b/test/Pipeline/shared-memory.forth index 635e8f1..4afe9c3 100644 --- a/test/Pipeline/shared-memory.forth +++ b/test/Pipeline/shared-memory.forth @@ -6,16 +6,16 @@ \ Verify intermediate MLIR structure: shared alloca becomes workgroup attribution \ MID: gpu.module @warpforth_module -\ MID: gpu.func @main(%arg0: memref<256xi64> {forth.param_name = "DATA"}) -\ MID-SAME: workgroup(%{{.*}}: memref<256xi64, #gpu.address_space<workgroup>>) +\ MID: gpu.func @main(%arg0: memref<256xi32> {forth.param_name = "DATA"}) +\ MID-SAME: workgroup(%{{.*}}: memref<256xi32, #gpu.address_space<workgroup>>) \ MID-SAME: kernel -\ MID: memref.extract_aligned_pointer_as_index %{{.*}} : memref<256xi64, #gpu.address_space<workgroup>> +\ MID: memref.extract_aligned_pointer_as_index %{{.*}} : memref<256xi32, #gpu.address_space<workgroup>> \ MID: llvm.store \ MID: gpu.return \! kernel main -\! param DATA i64[256] -\! shared SCRATCH i64[256] +\! param DATA i32[256] +\! shared SCRATCH i32[256] GLOBAL-ID CELLS SCRATCH + ! GLOBAL-ID CELLS SCRATCH + @ GLOBAL-ID CELLS DATA + ! 
diff --git a/test/Pipeline/unloop-exit.forth b/test/Pipeline/unloop-exit.forth index 5de0b28..dc97482 100644 --- a/test/Pipeline/unloop-exit.forth +++ b/test/Pipeline/unloop-exit.forth @@ -2,6 +2,6 @@ \ CHECK: gpu.binary @warpforth_module \! kernel main -\! param DATA i64[4] +\! param DATA i32[4] : FIND-FIVE 10 0 DO I 5 = IF UNLOOP EXIT THEN LOOP 0 ; FIND-FIVE DATA 0 CELLS + ! diff --git a/test/Translation/Forth/barrier.forth b/test/Translation/Forth/barrier.forth index 963c38d..5379963 100644 --- a/test/Translation/Forth/barrier.forth +++ b/test/Translation/Forth/barrier.forth @@ -1,5 +1,5 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s \ CHECK: forth.barrier \! kernel main -\! param DATA i64[256] +\! param DATA i32[256] GLOBAL-ID CELLS DATA + @ BARRIER DROP diff --git a/test/Translation/Forth/basic-literals.forth b/test/Translation/Forth/basic-literals.forth index e739e93..0dee888 100644 --- a/test/Translation/Forth/basic-literals.forth +++ b/test/Translation/Forth/basic-literals.forth @@ -1,8 +1,8 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s \ CHECK: forth.stack -\ CHECK-NEXT: forth.constant %{{.*}}(42 : i64) -\ CHECK-NEXT: forth.constant %{{.*}}(-7 : i64) -\ CHECK-NEXT: forth.constant %{{.*}}(0 : i64) +\ CHECK-NEXT: forth.constant %{{.*}}(42 : i32) +\ CHECK-NEXT: forth.constant %{{.*}}(-7 : i32) +\ CHECK-NEXT: forth.constant %{{.*}}(0 : i32) \! 
kernel main 42 -7 0 diff --git a/test/Translation/Forth/begin-until.forth b/test/Translation/Forth/begin-until.forth index ceca20d..fdb7dc2 100644 --- a/test/Translation/Forth/begin-until.forth +++ b/test/Translation/Forth/begin-until.forth @@ -3,10 +3,10 @@ \ Verify BEGIN/UNTIL generates loop with pop_flag + cond_br \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb1(%[[S1]] : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L1:.*]] = forth.constant %[[B1]](1 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L1:.*]] = forth.constant %[[B1]](1 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[SUB:.*]] = forth.subi %[[L1]] : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[DUP:.*]] = forth.dup %[[SUB]] : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[ZEQ:.*]] = forth.zero_eq %[[DUP]] : !forth.stack -> !forth.stack diff --git a/test/Translation/Forth/begin-while-repeat.forth b/test/Translation/Forth/begin-while-repeat.forth index 3a9e40d..b775cf9 100644 --- a/test/Translation/Forth/begin-while-repeat.forth +++ b/test/Translation/Forth/begin-while-repeat.forth @@ -3,16 +3,16 @@ \ Verify BEGIN/WHILE/REPEAT generates condition check + body loop with cond_br \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb1(%[[S1]] : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): \ CHECK-NEXT: %[[DUP:.*]] = forth.dup %[[B1]] : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[L0:.*]] = forth.constant %[[DUP]](0 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L0:.*]] = forth.constant %[[DUP]](0 : i32) : !forth.stack -> !forth.stack \ 
CHECK-NEXT: %[[GT:.*]] = forth.gti %[[L0]] : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[PF:.*]], %[[FLAG:.*]] = forth.pop_flag %[[GT]] : !forth.stack -> !forth.stack, i1 \ CHECK-NEXT: cf.cond_br %[[FLAG]], ^bb2(%[[PF]] : !forth.stack), ^bb3(%[[PF]] : !forth.stack) \ CHECK: ^bb2(%[[B2:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L1:.*]] = forth.constant %[[B2]](1 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L1:.*]] = forth.constant %[[B2]](1 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[SUB:.*]] = forth.subi %[[L1]] : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb1(%[[SUB]] : !forth.stack) \ CHECK: ^bb3(%[[B3:.*]]: !forth.stack): diff --git a/test/Translation/Forth/control-flow.forth b/test/Translation/Forth/control-flow.forth index 460fbfe..f2874dd 100644 --- a/test/Translation/Forth/control-flow.forth +++ b/test/Translation/Forth/control-flow.forth @@ -4,25 +4,25 @@ \ Basic IF/ELSE/THEN \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](1 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](1 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[PF1:.*]], %[[FLAG1:.*]] = forth.pop_flag %[[S1]] : !forth.stack -> !forth.stack, i1 \ CHECK-NEXT: cf.cond_br %[[FLAG1]], ^bb1(%[[PF1]] : !forth.stack), ^bb2(%[[PF1]] : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L42:.*]] = forth.constant %[[B1]](42 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L42:.*]] = forth.constant %[[B1]](42 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb3(%[[L42]] : !forth.stack) \ CHECK: ^bb2(%[[B2:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L99:.*]] = forth.constant %[[B2]](99 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L99:.*]] = forth.constant %[[B2]](99 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb3(%[[L99]] : !forth.stack) \! 
kernel main 1 IF 42 ELSE 99 THEN \ Basic IF/THEN (no ELSE - fallthrough on false) \ CHECK: ^bb3(%[[B3:.*]]: !forth.stack): -\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[B3]](0 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[B3]](0 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[PF2:.*]], %[[FLAG2:.*]] = forth.pop_flag %[[S2]] : !forth.stack -> !forth.stack, i1 \ CHECK-NEXT: cf.cond_br %[[FLAG2]], ^bb4(%[[PF2]] : !forth.stack), ^bb5(%[[PF2]] : !forth.stack) \ CHECK: ^bb4(%[[B4:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L7:.*]] = forth.constant %[[B4]](7 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L7:.*]] = forth.constant %[[B4]](7 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb5(%[[L7]] : !forth.stack) \ CHECK: ^bb5(%[[B5:.*]]: !forth.stack): \ CHECK-NEXT: return diff --git a/test/Translation/Forth/do-loop.forth b/test/Translation/Forth/do-loop.forth index b922258..923b69d 100644 --- a/test/Translation/Forth/do-loop.forth +++ b/test/Translation/Forth/do-loop.forth @@ -3,28 +3,31 @@ \ Verify DO/LOOP generates post-test loop with crossing test \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i64) : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[S1]](0 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i32) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[S1]](0 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[OS:.*]], %[[VAL:.*]] = forth.pop %[[S2]] : !forth.stack -> !forth.stack, i64 \ CHECK-NEXT: %[[OS2:.*]], %[[LIM:.*]] = forth.pop %[[OS]] : !forth.stack -> !forth.stack, i64 -\ CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() : memref<1xi64> +\ CHECK-NEXT: %[[TVAL:.*]] = arith.trunci %[[VAL]] : i64 to i32 +\ CHECK-NEXT: %[[TLIM:.*]] = arith.trunci %[[LIM]] : i64 to i32 +\ CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() : memref<1xi32> \ CHECK-NEXT: 
%[[C0:.*]] = arith.constant 0 : index -\ CHECK-NEXT: memref.store %[[VAL]], %[[ALLOCA]][%[[C0]]] : memref<1xi64> +\ CHECK-NEXT: memref.store %[[TVAL]], %[[ALLOCA]][%[[C0]]] : memref<1xi32> \ CHECK-NEXT: cf.br ^bb1(%[[OS2]] : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): \ CHECK-NEXT: %[[C0_2:.*]] = arith.constant 0 : index -\ CHECK-NEXT: %[[LOAD1:.*]] = memref.load %[[ALLOCA]][%[[C0_2]]] : memref<1xi64> -\ CHECK-NEXT: %[[PUSH:.*]] = forth.push_value %[[B1]], %[[LOAD1]] : !forth.stack, i64 -> !forth.stack -\ CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : i64 +\ CHECK-NEXT: %[[LOAD1:.*]] = memref.load %[[ALLOCA]][%[[C0_2]]] : memref<1xi32> +\ CHECK-NEXT: %[[EXT:.*]] = arith.extsi %[[LOAD1]] : i32 to i64 +\ CHECK-NEXT: %[[PUSH:.*]] = forth.push_value %[[B1]], %[[EXT]] : !forth.stack, i64 -> !forth.stack +\ CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : i32 \ CHECK-NEXT: %[[C0_3:.*]] = arith.constant 0 : index -\ CHECK-NEXT: %[[OLD:.*]] = memref.load %[[ALLOCA]][%[[C0_3]]] : memref<1xi64> -\ CHECK-NEXT: %[[NEW:.*]] = arith.addi %[[OLD]], %[[C1]] : i64 -\ CHECK-NEXT: memref.store %[[NEW]], %[[ALLOCA]][%[[C0_3]]] : memref<1xi64> -\ CHECK-NEXT: %[[D1:.*]] = arith.subi %[[OLD]], %[[LIM]] : i64 -\ CHECK-NEXT: %[[D2:.*]] = arith.subi %[[NEW]], %[[LIM]] : i64 -\ CHECK-NEXT: %[[XOR:.*]] = arith.xori %[[D1]], %[[D2]] : i64 -\ CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : i64 -\ CHECK-NEXT: %[[CROSSED:.*]] = arith.cmpi slt, %[[XOR]], %[[ZERO]] : i64 +\ CHECK-NEXT: %[[OLD:.*]] = memref.load %[[ALLOCA]][%[[C0_3]]] : memref<1xi32> +\ CHECK-NEXT: %[[NEW:.*]] = arith.addi %[[OLD]], %[[C1]] : i32 +\ CHECK-NEXT: memref.store %[[NEW]], %[[ALLOCA]][%[[C0_3]]] : memref<1xi32> +\ CHECK-NEXT: %[[D1:.*]] = arith.subi %[[OLD]], %[[TLIM]] : i32 +\ CHECK-NEXT: %[[D2:.*]] = arith.subi %[[NEW]], %[[TLIM]] : i32 +\ CHECK-NEXT: %[[XOR:.*]] = arith.xori %[[D1]], %[[D2]] : i32 +\ CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : i32 +\ CHECK-NEXT: %[[CROSSED:.*]] = arith.cmpi slt, %[[XOR]], %[[ZERO]] 
: i32 \ CHECK-NEXT: cf.cond_br %[[CROSSED]], ^bb2(%[[PUSH]] : !forth.stack), ^bb1(%[[PUSH]] : !forth.stack) \ CHECK: ^bb2(%[[B2:.*]]: !forth.stack): \ CHECK-NEXT: return diff --git a/test/Translation/Forth/float-literals.forth b/test/Translation/Forth/float-literals.forth index 31d8f16..c0e17ed 100644 --- a/test/Translation/Forth/float-literals.forth +++ b/test/Translation/Forth/float-literals.forth @@ -1,9 +1,9 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s \ CHECK: forth.stack -\ CHECK-NEXT: forth.constant %{{.*}}(3.140000e+00 : f64) -\ CHECK-NEXT: forth.constant %{{.*}}(-2.000000e+00 : f64) -\ CHECK-NEXT: forth.constant %{{.*}}(1.000000e-05 : f64) -\ CHECK-NEXT: forth.constant %{{.*}}(1.000000e+03 : f64) +\ CHECK-NEXT: forth.constant %{{.*}}(3.{{.*}} : f32) +\ CHECK-NEXT: forth.constant %{{.*}}(-2.{{.*}} : f32) +\ CHECK-NEXT: forth.constant %{{.*}}(9.{{.*}} : f32) +\ CHECK-NEXT: forth.constant %{{.*}}(1.{{.*}} : f32) \! kernel main 3.14 -2.0 1.0e-5 1e3 diff --git a/test/Translation/Forth/float-params.forth b/test/Translation/Forth/float-params.forth index aed8e9e..071b6e3 100644 --- a/test/Translation/Forth/float-params.forth +++ b/test/Translation/Forth/float-params.forth @@ -1,12 +1,12 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s -\ Check f64 scalar param becomes f64 function argument -\ CHECK: func.func private @main(%arg0: memref<256xf64> {forth.param_name = "DATA"}, %arg1: f64 {forth.param_name = "SCALE"}) +\ Check f32 scalar param becomes f32 function argument +\ CHECK: func.func private @main(%arg0: memref<256xf32> {forth.param_name = "DATA"}, %arg1: f32 {forth.param_name = "SCALE"}) \ Check param refs work \ CHECK: forth.param_ref %{{.*}} "DATA" \ CHECK: forth.param_ref %{{.*}} "SCALE" \! kernel main -\! param DATA f64[256] -\! param SCALE f64 +\! param DATA f32[256] +\! 
param SCALE f32 DATA SCALE diff --git a/test/Translation/Forth/header-directive-after-code-error.forth b/test/Translation/Forth/header-directive-after-code-error.forth index 3236bde..f6e1d99 100644 --- a/test/Translation/Forth/header-directive-after-code-error.forth +++ b/test/Translation/Forth/header-directive-after-code-error.forth @@ -1,6 +1,6 @@ \ RUN: %not %warpforth-translate --forth-to-mlir %s 2>&1 | %FileCheck %s \ CHECK: header directive must appear before any code \! kernel main -\! param A i64[4] +\! param A i32[4] A @ -\! param B i64[4] +\! param B i32[4] diff --git a/test/Translation/Forth/header-duplicate-param-error.forth b/test/Translation/Forth/header-duplicate-param-error.forth index f90bf26..6a0e8cb 100644 --- a/test/Translation/Forth/header-duplicate-param-error.forth +++ b/test/Translation/Forth/header-duplicate-param-error.forth @@ -1,5 +1,5 @@ \ RUN: %not %warpforth-translate --forth-to-mlir %s 2>&1 | %FileCheck %s \ CHECK: duplicate name: A (already declared as param) \! kernel main -\! param A i64[4] -\! param A i64[8] +\! param A i32[4] +\! param A i32[8] diff --git a/test/Translation/Forth/header-shared-param-duplicate-error.forth b/test/Translation/Forth/header-shared-param-duplicate-error.forth index 249206c..819aba5 100644 --- a/test/Translation/Forth/header-shared-param-duplicate-error.forth +++ b/test/Translation/Forth/header-shared-param-duplicate-error.forth @@ -1,5 +1,5 @@ \ RUN: %not %warpforth-translate --forth-to-mlir %s 2>&1 | %FileCheck %s \ CHECK: duplicate name: A (already declared as param) \! kernel main -\! param A i64[4] -\! shared A i64[8] +\! param A i32[4] +\! 
shared A i32[8] diff --git a/test/Translation/Forth/interleaved-control-flow.forth b/test/Translation/Forth/interleaved-control-flow.forth index 6865d05..5be099a 100644 --- a/test/Translation/Forth/interleaved-control-flow.forth +++ b/test/Translation/Forth/interleaved-control-flow.forth @@ -15,7 +15,7 @@ \ Loop header: DUP 10 > → WHILE(1) \ CHECK: ^bb1(%[[H:.*]]: !forth.stack): \ CHECK: forth.dup -\ CHECK: forth.constant %{{.*}}(10 : i64) +\ CHECK: forth.constant %{{.*}}(10 : i32) \ CHECK-NEXT: %{{.*}} = forth.gti \ CHECK: forth.pop_flag \ CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2(%{{.*}} : !forth.stack), ^bb3(%{{.*}} : !forth.stack) @@ -23,7 +23,7 @@ \ WHILE(1) body: DUP 2 MOD 0= → WHILE(2) \ CHECK: ^bb2(%{{.*}}: !forth.stack): \ CHECK: forth.dup -\ CHECK: forth.constant %{{.*}}(2 : i64) +\ CHECK: forth.constant %{{.*}}(2 : i32) \ CHECK-NEXT: %{{.*}} = forth.mod \ CHECK-NEXT: %{{.*}} = forth.zero_eq \ CHECK: forth.pop_flag @@ -35,7 +35,7 @@ \ WHILE(2) body: 1 - → REPEAT (branch back to loop header) \ CHECK: ^bb4(%[[B4:.*]]: !forth.stack): -\ CHECK-NEXT: %{{.*}} = forth.constant %[[B4]](1 : i64) +\ CHECK-NEXT: %{{.*}} = forth.constant %[[B4]](1 : i32) \ CHECK-NEXT: %{{.*}} = forth.subi \ CHECK-NEXT: cf.br ^bb1 @@ -59,7 +59,7 @@ \ Loop header: DUP 0 > → WHILE \ CHECK: ^bb1(%{{.*}}: !forth.stack): \ CHECK: forth.dup -\ CHECK: forth.constant %{{.*}}(0 : i64) +\ CHECK: forth.constant %{{.*}}(0 : i32) \ CHECK-NEXT: %{{.*}} = forth.gti \ CHECK: forth.pop_flag \ CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2(%{{.*}} : !forth.stack), ^bb3(%{{.*}} : !forth.stack) @@ -67,10 +67,10 @@ \ WHILE body + UNTIL: 1 - DUP 5 = UNTIL \ UNTIL true exits to ^bb4, UNTIL false loops back to ^bb1 \ CHECK: ^bb2(%[[W:.*]]: !forth.stack): -\ CHECK-NEXT: %{{.*}} = forth.constant %[[W]](1 : i64) +\ CHECK-NEXT: %{{.*}} = forth.constant %[[W]](1 : i32) \ CHECK-NEXT: %{{.*}} = forth.subi \ CHECK: forth.dup -\ CHECK: forth.constant %{{.*}}(5 : i64) +\ CHECK: forth.constant %{{.*}}(5 : i32) \ CHECK-NEXT: %{{.*}} = 
forth.eqi \ CHECK: forth.pop_flag \ CHECK-NEXT: cf.cond_br %{{.*}}, ^bb4(%{{.*}} : !forth.stack), ^bb1(%{{.*}} : !forth.stack) diff --git a/test/Translation/Forth/leave.forth b/test/Translation/Forth/leave.forth index 1da159e..0a5aaf1 100644 --- a/test/Translation/Forth/leave.forth +++ b/test/Translation/Forth/leave.forth @@ -3,8 +3,8 @@ \ Verify LEAVE branches to the loop exit block. \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i64) : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[S1]](0 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i32) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[S1]](0 : i32) : !forth.stack -> !forth.stack \ CHECK: cf.br ^bb1(%{{.*}} : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): \ CHECK-NEXT: %[[TRUE:.*]] = arith.constant true diff --git a/test/Translation/Forth/local-variables-error-param-conflict.forth b/test/Translation/Forth/local-variables-error-param-conflict.forth index edfb7dc..74ac1bb 100644 --- a/test/Translation/Forth/local-variables-error-param-conflict.forth +++ b/test/Translation/Forth/local-variables-error-param-conflict.forth @@ -1,6 +1,6 @@ \ RUN: %not %warpforth-translate --forth-to-mlir %s 2>&1 | %FileCheck %s \ CHECK: local variable name 'DATA' conflicts with parameter name \! kernel main -\! param DATA i64[256] +\! param DATA i32[256] : BAD { data -- } data ; BAD diff --git a/test/Translation/Forth/local-variables-error-shared-conflict.forth b/test/Translation/Forth/local-variables-error-shared-conflict.forth index 12b8fba..34b4922 100644 --- a/test/Translation/Forth/local-variables-error-shared-conflict.forth +++ b/test/Translation/Forth/local-variables-error-shared-conflict.forth @@ -1,6 +1,6 @@ \ RUN: %not %warpforth-translate --forth-to-mlir %s 2>&1 | %FileCheck %s \ CHECK: local variable name 'BUF' conflicts with shared memory name \! 
kernel main -\! shared BUF i64[64] +\! shared BUF i32[64] : BAD { buf -- } buf ; BAD diff --git a/test/Translation/Forth/memory-ops.forth b/test/Translation/Forth/memory-ops.forth index 0d6afe3..38e1a01 100644 --- a/test/Translation/Forth/memory-ops.forth +++ b/test/Translation/Forth/memory-ops.forth @@ -12,8 +12,9 @@ \ Test S! produces forth.shared_storei \ CHECK: forth.shared_storei %{{.*}} : !forth.stack -> !forth.stack -\ Test CELLS produces literal 8 + mul -\ CHECK: forth.constant %{{.*}}(8 : i64) +\ Test CELLS produces literal 4 + mul +\ CHECK: forth.constant %{{.*}}(4 : i32) +\ CHECK-NEXT: forth.constant %{{.*}}(4 : i32) \ CHECK-NEXT: forth.muli \! kernel main 1 @ 2 3 ! 4 S@ 5 6 S! diff --git a/test/Translation/Forth/narrow-memory.forth b/test/Translation/Forth/narrow-memory.forth new file mode 100644 index 0000000..2451492 --- /dev/null +++ b/test/Translation/Forth/narrow-memory.forth @@ -0,0 +1,58 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s + +\ Test HF@ produces forth.loadhf +\ CHECK: forth.loadhf %{{.*}} : !forth.stack -> !forth.stack + +\ Test HF! produces forth.storehf +\ CHECK: forth.storehf %{{.*}} : !forth.stack -> !forth.stack + +\ Test BF@ produces forth.loadbf +\ CHECK: forth.loadbf %{{.*}} : !forth.stack -> !forth.stack + +\ Test BF! produces forth.storebf +\ CHECK: forth.storebf %{{.*}} : !forth.stack -> !forth.stack + +\ Test I8@ produces forth.loadi8 +\ CHECK: forth.loadi8 %{{.*}} : !forth.stack -> !forth.stack + +\ Test I8! produces forth.storei8 +\ CHECK: forth.storei8 %{{.*}} : !forth.stack -> !forth.stack + +\ Test I16@ produces forth.loadi16 +\ CHECK: forth.loadi16 %{{.*}} : !forth.stack -> !forth.stack + +\ Test I16! produces forth.storei16 +\ CHECK: forth.storei16 %{{.*}} : !forth.stack -> !forth.stack + +\ Test SHF@ produces forth.shared_loadhf +\ CHECK: forth.shared_loadhf %{{.*}} : !forth.stack -> !forth.stack + +\ Test SHF! 
produces forth.shared_storehf +\ CHECK: forth.shared_storehf %{{.*}} : !forth.stack -> !forth.stack + +\ Test SBF@ produces forth.shared_loadbf +\ CHECK: forth.shared_loadbf %{{.*}} : !forth.stack -> !forth.stack + +\ Test SBF! produces forth.shared_storebf +\ CHECK: forth.shared_storebf %{{.*}} : !forth.stack -> !forth.stack + +\ Test SI8@ produces forth.shared_loadi8 +\ CHECK: forth.shared_loadi8 %{{.*}} : !forth.stack -> !forth.stack + +\ Test SI8! produces forth.shared_storei8 +\ CHECK: forth.shared_storei8 %{{.*}} : !forth.stack -> !forth.stack + +\ Test SI16@ produces forth.shared_loadi16 +\ CHECK: forth.shared_loadi16 %{{.*}} : !forth.stack -> !forth.stack + +\ Test SI16! produces forth.shared_storei16 +\ CHECK: forth.shared_storei16 %{{.*}} : !forth.stack -> !forth.stack +\! kernel main +1 HF@ 2.0 3 HF! +4 BF@ 5.0 6 BF! +7 I8@ 8 9 I8! +10 I16@ 11 12 I16! +13 SHF@ 14.0 15 SHF! +16 SBF@ 17.0 18 SBF! +19 SI8@ 20 21 SI8! +22 SI16@ 23 24 SI16! diff --git a/test/Translation/Forth/nested-control-flow.forth b/test/Translation/Forth/nested-control-flow.forth index 26cbca1..ac0702f 100644 --- a/test/Translation/Forth/nested-control-flow.forth +++ b/test/Translation/Forth/nested-control-flow.forth @@ -2,11 +2,11 @@ \ === Nested IF === \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](1 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](1 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[PF1:.*]], %[[FL1:.*]] = forth.pop_flag %[[S1]] : !forth.stack -> !forth.stack, i1 \ CHECK-NEXT: cf.cond_br %[[FL1]], ^bb1(%[[PF1]] : !forth.stack), ^bb2(%[[PF1]] : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L2:.*]] = forth.constant %[[B1]](2 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L2:.*]] = forth.constant %[[B1]](2 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[PF2:.*]], %[[FL2:.*]] = forth.pop_flag %[[L2]] : !forth.stack -> !forth.stack, i1 \ 
CHECK-NEXT: cf.cond_br %[[FL2]], ^bb3(%[[PF2]] : !forth.stack), ^bb4(%[[PF2]] : !forth.stack) \! kernel main @@ -15,19 +15,21 @@ \ === IF inside DO === \ After IF/THEN merge, set up DO loop: 10 0 DO \ CHECK: ^bb2(%[[B2:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L10:.*]] = forth.constant %[[B2]](10 : i64) : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[L0A:.*]] = forth.constant %[[L10]](0 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L10:.*]] = forth.constant %[[B2]](10 : i32) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L0A:.*]] = forth.constant %[[L10]](0 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[POP1:.*]], %[[V1:.*]] = forth.pop %[[L0A]] : !forth.stack -> !forth.stack, i64 \ CHECK-NEXT: %[[POP2:.*]], %[[V2:.*]] = forth.pop %[[POP1]] : !forth.stack -> !forth.stack, i64 -\ CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloca() : memref<1xi64> +\ CHECK-NEXT: %[[TV1:.*]] = arith.trunci %[[V1]] : i64 to i32 +\ CHECK-NEXT: %[[TV2:.*]] = arith.trunci %[[V2]] : i64 to i32 +\ CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloca() : memref<1xi32> \ CHECK-NEXT: %{{.*}} = arith.constant 0 : index -\ CHECK-NEXT: memref.store %[[V1]], %[[ALLOC1]][%{{.*}}] : memref<1xi64> +\ CHECK-NEXT: memref.store %[[TV1]], %[[ALLOC1]][%{{.*}}] : memref<1xi32> \ CHECK-NEXT: cf.br ^bb5(%[[POP2]] : !forth.stack) 10 0 DO I 5 > IF I THEN LOOP \ Nested IF: true branch pushes 3, then merges \ CHECK: ^bb3(%[[B3:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L3:.*]] = forth.constant %[[B3]](3 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L3:.*]] = forth.constant %[[B3]](3 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb4(%[[L3]] : !forth.stack) \ CHECK: ^bb4(%[[B4:.*]]: !forth.stack): \ CHECK-NEXT: cf.br ^bb2(%[[B4]] : !forth.stack) @@ -35,7 +37,7 @@ \ DO loop body (post-test: no check block): I 5 > IF I THEN \ CHECK: ^bb5(%[[B5:.*]]: !forth.stack): \ CHECK: forth.push_value %[[B5]] -\ CHECK: forth.constant %{{.*}}(5 : i64) +\ CHECK: forth.constant %{{.*}}(5 : i32) \ CHECK-NEXT: 
%{{.*}} = forth.gti \ CHECK: forth.pop_flag \ CHECK-NEXT: cf.cond_br %{{[^,]*}}, ^bb7(%{{[^)]*}} : !forth.stack), ^bb8(%{{[^)]*}} : !forth.stack) @@ -43,7 +45,7 @@ \ === Nested DO with J === \ After first DO loop exits: sets up nested DO (3 0 DO) \ CHECK: ^bb6(%[[B6:.*]]: !forth.stack): -\ CHECK-NEXT: %{{.*}} = forth.constant %[[B6]](3 : i64) +\ CHECK-NEXT: %{{.*}} = forth.constant %[[B6]](3 : i32) 3 0 DO 4 0 DO J I + LOOP LOOP \ IF I true branch: push loop index @@ -61,8 +63,8 @@ \ Outer DO body (3 0 DO) with inner DO setup (4 0 DO) \ CHECK: ^bb9(%{{.*}}: !forth.stack): -\ CHECK: forth.constant %{{.*}}(4 : i64) -\ CHECK: forth.constant %{{.*}}(0 : i64) +\ CHECK: forth.constant %{{.*}}(4 : i32) +\ CHECK: forth.constant %{{.*}}(0 : i32) \ CHECK: forth.pop \ CHECK: forth.pop \ CHECK: memref.alloca() @@ -70,7 +72,7 @@ \ === Triple-nested DO with K === \ After nested DO exits: sets up triple-nested DO (2 0 DO) \ CHECK: ^bb10(%{{.*}}: !forth.stack): -\ CHECK: forth.constant %{{.*}}(2 : i64) +\ CHECK: forth.constant %{{.*}}(2 : i32) 2 0 DO 2 0 DO 2 0 DO K J I + + LOOP LOOP LOOP \ Inner loop of J I + (bb11 body) @@ -93,13 +95,13 @@ \ Triple-nested outer loop body (bb13) \ CHECK: ^bb13(%{{.*}}: !forth.stack): -\ CHECK: forth.constant %{{.*}}(2 : i64) -\ CHECK: forth.constant %{{.*}}(0 : i64) +\ CHECK: forth.constant %{{.*}}(2 : i32) +\ CHECK: forth.constant %{{.*}}(0 : i32) \ === BEGIN/WHILE inside IF === \ After triple-nested exits: 5 IF BEGIN DUP WHILE 1 - REPEAT THEN \ CHECK: ^bb14(%{{.*}}: !forth.stack): -\ CHECK: forth.constant %{{.*}}(5 : i64) +\ CHECK: forth.constant %{{.*}}(5 : i32) \ CHECK: forth.pop_flag \ CHECK-NEXT: cf.cond_br 5 IF BEGIN DUP WHILE 1 - REPEAT THEN @@ -120,26 +122,26 @@ \ WHILE body: 1 - \ CHECK: ^bb22(%[[B22:.*]]: !forth.stack): -\ CHECK-NEXT: %{{.*}} = forth.constant %[[B22]](1 : i64) +\ CHECK-NEXT: %{{.*}} = forth.constant %[[B22]](1 : i32) \ CHECK-NEXT: %{{.*}} = forth.subi \ === IF inside BEGIN/UNTIL === \ BEGIN/UNTIL header: DUP 10 < \ 
CHECK: ^bb24(%{{.*}}: !forth.stack): \ CHECK: forth.dup -\ CHECK: forth.constant %{{.*}}(10 : i64) +\ CHECK: forth.constant %{{.*}}(10 : i32) \ CHECK-NEXT: %{{.*}} = forth.lti BEGIN DUP 10 < IF 1 + THEN DUP 20 = UNTIL \ IF true branch: 1 + \ CHECK: ^bb25(%[[B25:.*]]: !forth.stack): -\ CHECK-NEXT: %{{.*}} = forth.constant %[[B25]](1 : i64) +\ CHECK-NEXT: %{{.*}} = forth.constant %[[B25]](1 : i32) \ CHECK-NEXT: %{{.*}} = forth.addi \ UNTIL condition: DUP 20 = \ CHECK: ^bb26(%{{.*}}: !forth.stack): \ CHECK: forth.dup -\ CHECK: forth.constant %{{.*}}(20 : i64) +\ CHECK: forth.constant %{{.*}}(20 : i32) \ CHECK-NEXT: %{{.*}} = forth.eqi \ CHECK: forth.pop_flag \ CHECK-NEXT: cf.cond_br diff --git a/test/Translation/Forth/param-declarations.forth b/test/Translation/Forth/param-declarations.forth index c15fbd7..1659b46 100644 --- a/test/Translation/Forth/param-declarations.forth +++ b/test/Translation/Forth/param-declarations.forth @@ -1,10 +1,10 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s \ Verify multi-param declarations with correct types and ordering -\ CHECK: func.func private @main(%arg0: memref<256xi64> {forth.param_name = "DATA"}, %arg1: memref<128xi64> {forth.param_name = "WEIGHTS"}) +\ CHECK: func.func private @main(%arg0: memref<256xi32> {forth.param_name = "DATA"}, %arg1: memref<128xi32> {forth.param_name = "WEIGHTS"}) \ CHECK: forth.param_ref %{{.*}} "DATA" \ CHECK: forth.param_ref %{{.*}} "WEIGHTS" \! kernel main -\! param DATA i64[256] -\! param WEIGHTS i64[128] +\! param DATA i32[256] +\! 
param WEIGHTS i32[128] DATA WEIGHTS diff --git a/test/Translation/Forth/param-ref-in-word-error.forth b/test/Translation/Forth/param-ref-in-word-error.forth index 283e234..08a16d3 100644 --- a/test/Translation/Forth/param-ref-in-word-error.forth +++ b/test/Translation/Forth/param-ref-in-word-error.forth @@ -1,6 +1,6 @@ \ RUN: %not %warpforth-translate --forth-to-mlir %s 2>&1 | %FileCheck %s \ CHECK: parameter 'DATA' cannot be referenced inside a word definition \! kernel main -\! param DATA i64[256] +\! param DATA i32[256] : BAD-WORD DATA @ ; BAD-WORD diff --git a/test/Translation/Forth/plus-loop-negative.forth b/test/Translation/Forth/plus-loop-negative.forth index 0ef5e94..83e24fa 100644 --- a/test/Translation/Forth/plus-loop-negative.forth +++ b/test/Translation/Forth/plus-loop-negative.forth @@ -3,21 +3,22 @@ \ Verify +LOOP with negative step uses crossing test (handles negative direction) \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](0 : i64) : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[S1]](10 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](0 : i32) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[S1]](10 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[OS:.*]], %[[VAL:.*]] = forth.pop %[[S2]] : !forth.stack -> !forth.stack, i64 \ CHECK-NEXT: %[[OS2:.*]], %[[LIM:.*]] = forth.pop %[[OS]] : !forth.stack -> !forth.stack, i64 \ CHECK: cf.br ^bb1(%[[OS2]] : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): -\ CHECK: %[[STEP_S:.*]] = forth.constant %[[B1]](-1 : i64) : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[POP_S:.*]], %[[STEP:.*]] = forth.pop %[[STEP_S]] : !forth.stack -> !forth.stack, i64 +\ CHECK: %[[STEP_S:.*]] = forth.constant %[[B1]](-1 : i32) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[POP_S:.*]], %[[STEP64:.*]] = forth.pop %[[STEP_S]] : !forth.stack -> !forth.stack, i64 +\ CHECK-NEXT: 
%[[STEP:.*]] = arith.trunci %[[STEP64]] : i64 to i32 \ CHECK: %[[OLD:.*]] = memref.load -\ CHECK: %[[NEW:.*]] = arith.addi %[[OLD]], %[[STEP]] : i64 -\ CHECK: %[[D1:.*]] = arith.subi %[[OLD]], %[[LIM]] : i64 -\ CHECK-NEXT: %[[D2:.*]] = arith.subi %[[NEW]], %[[LIM]] : i64 -\ CHECK-NEXT: %[[XOR:.*]] = arith.xori %[[D1]], %[[D2]] : i64 -\ CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : i64 -\ CHECK-NEXT: %[[CROSSED:.*]] = arith.cmpi slt, %[[XOR]], %[[ZERO]] : i64 +\ CHECK: %[[NEW:.*]] = arith.addi %[[OLD]], %[[STEP]] : i32 +\ CHECK: %[[D1:.*]] = arith.subi %[[OLD]], %{{.*}} : i32 +\ CHECK-NEXT: %[[D2:.*]] = arith.subi %[[NEW]], %{{.*}} : i32 +\ CHECK-NEXT: %[[XOR:.*]] = arith.xori %[[D1]], %[[D2]] : i32 +\ CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : i32 +\ CHECK-NEXT: %[[CROSSED:.*]] = arith.cmpi slt, %[[XOR]], %[[ZERO]] : i32 \ CHECK-NEXT: cf.cond_br %[[CROSSED]], ^bb2(%[[POP_S]] : !forth.stack), ^bb1(%[[POP_S]] : !forth.stack) \ CHECK: ^bb2(%{{.*}}: !forth.stack): \ CHECK-NEXT: return diff --git a/test/Translation/Forth/plus-loop.forth b/test/Translation/Forth/plus-loop.forth index ed2f791..bd17dd7 100644 --- a/test/Translation/Forth/plus-loop.forth +++ b/test/Translation/Forth/plus-loop.forth @@ -3,26 +3,29 @@ \ Verify +LOOP pops step from data stack and uses it as increment \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i64) : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[S1]](0 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i32) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[S1]](0 : i32) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[OS:.*]], %[[VAL:.*]] = forth.pop %[[S2]] : !forth.stack -> !forth.stack, i64 \ CHECK-NEXT: %[[OS2:.*]], %[[LIM:.*]] = forth.pop %[[OS]] : !forth.stack -> !forth.stack, i64 -\ CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() : memref<1xi64> +\ CHECK-NEXT: %[[TVAL:.*]] = 
arith.trunci %[[VAL]] : i64 to i32 +\ CHECK-NEXT: %[[TLIM:.*]] = arith.trunci %[[LIM]] : i64 to i32 +\ CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() : memref<1xi32> \ CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index -\ CHECK-NEXT: memref.store %[[VAL]], %[[ALLOCA]][%[[C0]]] : memref<1xi64> +\ CHECK-NEXT: memref.store %[[TVAL]], %[[ALLOCA]][%[[C0]]] : memref<1xi32> \ CHECK-NEXT: cf.br ^bb1(%[[OS2]] : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): -\ CHECK: %[[STEP_S:.*]] = forth.constant %[[B1]](2 : i64) : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[POP_S:.*]], %[[STEP:.*]] = forth.pop %[[STEP_S]] : !forth.stack -> !forth.stack, i64 +\ CHECK: %[[STEP_S:.*]] = forth.constant %[[B1]](2 : i32) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[POP_S:.*]], %[[STEP64:.*]] = forth.pop %[[STEP_S]] : !forth.stack -> !forth.stack, i64 +\ CHECK-NEXT: %[[STEP:.*]] = arith.trunci %[[STEP64]] : i64 to i32 \ CHECK-NEXT: %[[C0_2:.*]] = arith.constant 0 : index -\ CHECK-NEXT: %[[OLD:.*]] = memref.load %[[ALLOCA]][%[[C0_2]]] : memref<1xi64> -\ CHECK-NEXT: %[[NEW:.*]] = arith.addi %[[OLD]], %[[STEP]] : i64 -\ CHECK-NEXT: memref.store %[[NEW]], %[[ALLOCA]][%[[C0_2]]] : memref<1xi64> -\ CHECK-NEXT: %[[D1:.*]] = arith.subi %[[OLD]], %[[LIM]] : i64 -\ CHECK-NEXT: %[[D2:.*]] = arith.subi %[[NEW]], %[[LIM]] : i64 -\ CHECK-NEXT: %[[XOR:.*]] = arith.xori %[[D1]], %[[D2]] : i64 -\ CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : i64 -\ CHECK-NEXT: %[[CROSSED:.*]] = arith.cmpi slt, %[[XOR]], %[[ZERO]] : i64 +\ CHECK-NEXT: %[[OLD:.*]] = memref.load %[[ALLOCA]][%[[C0_2]]] : memref<1xi32> +\ CHECK-NEXT: %[[NEW:.*]] = arith.addi %[[OLD]], %[[STEP]] : i32 +\ CHECK-NEXT: memref.store %[[NEW]], %[[ALLOCA]][%[[C0_2]]] : memref<1xi32> +\ CHECK-NEXT: %[[D1:.*]] = arith.subi %[[OLD]], %[[TLIM]] : i32 +\ CHECK-NEXT: %[[D2:.*]] = arith.subi %[[NEW]], %[[TLIM]] : i32 +\ CHECK-NEXT: %[[XOR:.*]] = arith.xori %[[D1]], %[[D2]] : i32 +\ CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : i32 +\ CHECK-NEXT: 
%[[CROSSED:.*]] = arith.cmpi slt, %[[XOR]], %[[ZERO]] : i32 \ CHECK-NEXT: cf.cond_br %[[CROSSED]], ^bb2(%[[POP_S]] : !forth.stack), ^bb1(%[[POP_S]] : !forth.stack) \ CHECK: ^bb2(%[[B2:.*]]: !forth.stack): \ CHECK-NEXT: return diff --git a/test/Translation/Forth/scalar-param.forth b/test/Translation/Forth/scalar-param.forth index ec99aca..b0456f8 100644 --- a/test/Translation/Forth/scalar-param.forth +++ b/test/Translation/Forth/scalar-param.forth @@ -1,8 +1,8 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s -\ Verify scalar param uses i64 argument type. -\ CHECK: func.func private @main(%arg0: i64 {forth.param_name = "SCALE"}) +\ Verify scalar param uses i32 argument type. +\ CHECK: func.func private @main(%arg0: i32 {forth.param_name = "SCALE"}) \ CHECK: forth.param_ref %{{.*}} "SCALE" \! kernel main -\! param SCALE i64 +\! param SCALE i32 SCALE diff --git a/test/Translation/Forth/shared-declarations.forth b/test/Translation/Forth/shared-declarations.forth index 065b53a..5f3aac8 100644 --- a/test/Translation/Forth/shared-declarations.forth +++ b/test/Translation/Forth/shared-declarations.forth @@ -1,12 +1,12 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s \ Verify shared memory declarations produce tagged alloca and pointer push sequence -\ CHECK: func.func private @main(%arg0: memref<256xi64> {forth.param_name = "DATA"}) -\ CHECK: memref.alloca() {forth.shared_name = "SCRATCH"} : memref<256xi64> +\ CHECK: func.func private @main(%arg0: memref<256xi32> {forth.param_name = "DATA"}) +\ CHECK: memref.alloca() {forth.shared_name = "SCRATCH"} : memref<256xi32> \ CHECK: memref.extract_aligned_pointer_as_index \ CHECK: arith.index_cast \ CHECK: forth.push_value \! kernel main -\! param DATA i64[256] -\! shared SCRATCH i64[256] +\! param DATA i32[256] +\! 
shared SCRATCH i32[256] SCRATCH diff --git a/test/Translation/Forth/shared-ref-in-word-error.forth b/test/Translation/Forth/shared-ref-in-word-error.forth index 5f7b3f0..92d3769 100644 --- a/test/Translation/Forth/shared-ref-in-word-error.forth +++ b/test/Translation/Forth/shared-ref-in-word-error.forth @@ -1,6 +1,6 @@ \ RUN: %not %warpforth-translate --forth-to-mlir %s 2>&1 | %FileCheck %s \ CHECK: shared memory 'SCRATCH' cannot be referenced inside a word definition \! kernel main -\! shared SCRATCH i64[256] +\! shared SCRATCH i32[256] : BAD-WORD SCRATCH @ ; BAD-WORD diff --git a/tools/warpforth-runner/warpforth-runner.cpp b/tools/warpforth-runner/warpforth-runner.cpp index 67e1d3f..1b8041f 100644 --- a/tools/warpforth-runner/warpforth-runner.cpp +++ b/tools/warpforth-runner/warpforth-runner.cpp @@ -5,7 +5,7 @@ /// -std=c++17`. /// /// Usage: -/// warpforth-runner kernel.ptx --param i64[]:1,2,3 --param f64:3.14 \ +/// warpforth-runner kernel.ptx --param i32[]:1,2,3 --param f32:3.14 \ /// --grid 4,1,1 --block 64,1,1 --kernel main \ /// --output-param 0 --output-count 3 @@ -46,8 +46,8 @@ template struct ScalarParam { T value; }; -using Param = std::variant, ArrayParam, - ScalarParam, ScalarParam>; +using Param = std::variant, ArrayParam, + ScalarParam, ScalarParam>; template static void allocDevice(ArrayParam &arr) { size_t bytes = arr.values.size() * sizeof(T); @@ -64,7 +64,7 @@ static void printOutput(ArrayParam &arr, size_t count) { if (i > 0) std::cout << ","; if constexpr (std::is_floating_point_v) - std::cout << std::setprecision(17) << output[i]; + std::cout << std::setprecision(9) << output[i]; else std::cout << output[i]; } @@ -72,18 +72,18 @@ static void printOutput(ArrayParam &arr, size_t count) { } static void *kernelArgPtr(Param &p) { - if (auto *a = std::get_if>(&p)) + if (auto *a = std::get_if>(&p)) return &a->devicePtr; - if (auto *a = std::get_if>(&p)) + if (auto *a = std::get_if>(&p)) return &a->devicePtr; - if (auto *s = std::get_if>(&p)) + if 
(auto *s = std::get_if>(&p)) return &s->value; - return &std::get>(p).value; + return &std::get>(p).value; } static bool isScalar(const Param &p) { - return std::holds_alternative>(p) || - std::holds_alternative>(p); + return std::holds_alternative>(p) || + std::holds_alternative>(p); } struct Dims { @@ -131,8 +131,8 @@ static Param parseParam(std::string_view s) { auto colonPos = input.find(':'); if (colonPos == std::string::npos) { - std::cerr << "Error: --param requires type prefix (e.g. i64:42 or " - "f64[]:1.0,2.0), got: " + std::cerr << "Error: --param requires type prefix (e.g. i32:42 or " + "f32[]:1.0,2.0), got: " << s << "\n"; exit(1); } @@ -157,18 +157,18 @@ static Param parseParam(std::string_view s) { return vals; }; - auto toI64 = [&](const std::string &tok) -> int64_t { + auto toI32 = [&](const std::string &tok) -> int32_t { try { - return std::stoll(tok); + return std::stoi(tok); } catch (const std::exception &) { std::cerr << "Error: invalid integer value '" << tok << "' in --param " << s << "\n"; exit(1); } }; - auto toF64 = [&](const std::string &tok) -> double { + auto toF32 = [&](const std::string &tok) -> float { try { - return std::stod(tok); + return std::stof(tok); } catch (const std::exception &) { std::cerr << "Error: invalid float value '" << tok << "' in --param " << s << "\n"; @@ -176,10 +176,10 @@ static Param parseParam(std::string_view s) { } }; - if (typePrefix == "i64[]") - return Param{ArrayParam{parseValues(toI64)}}; - if (typePrefix == "f64[]") - return Param{ArrayParam{parseValues(toF64)}}; + if (typePrefix == "i32[]") + return Param{ArrayParam{parseValues(toI32)}}; + if (typePrefix == "f32[]") + return Param{ArrayParam{parseValues(toF32)}}; // Scalars — must be exactly one value if (valueStr.find(',') != std::string::npos) { @@ -188,13 +188,13 @@ static Param parseParam(std::string_view s) { exit(1); } - if (typePrefix == "i64") - return Param{ScalarParam{toI64(valueStr)}}; - if (typePrefix == "f64") - return 
Param{ScalarParam{toF64(valueStr)}}; + if (typePrefix == "i32") + return Param{ScalarParam{toI32(valueStr)}}; + if (typePrefix == "f32") + return Param{ScalarParam{toF32(valueStr)}}; std::cerr << "Error: unsupported param type '" << typePrefix - << "' (expected i64, i64[], f64, or f64[]), got: " << s << "\n"; + << "' (expected i32, i32[], f32, or f32[]), got: " << s << "\n"; exit(1); } @@ -254,8 +254,8 @@ int main(int argc, char **argv) { if (!ptxFile) { std::cerr << "Usage: warpforth-runner kernel.ptx --kernel NAME " - "[--param i64[]:V,...] [--param f64[]:V,...] " - "[--param i64:V] [--param f64:V] [--grid X,Y,Z] " + "[--param i32[]:V,...] [--param f32[]:V,...] " + "[--param i32:V] [--param f32:V] [--grid X,Y,Z] " "[--block X,Y,Z] [--output-param N] [--output-count N]\n"; return 1; } @@ -303,9 +303,9 @@ int main(int argc, char **argv) { // Allocate device buffers for array params for (auto &p : params) { - if (auto *a = std::get_if>(&p)) + if (auto *a = std::get_if>(&p)) allocDevice(*a); - else if (auto *a = std::get_if>(&p)) + else if (auto *a = std::get_if>(&p)) allocDevice(*a); } @@ -322,12 +322,12 @@ int main(int argc, char **argv) { // Copy back and print output param size_t count = outputCount >= 0 ? static_cast(outputCount) : 0; - if (auto *iArr = std::get_if>(¶ms[outputParam])) { + if (auto *iArr = std::get_if>(¶ms[outputParam])) { if (outputCount < 0) count = iArr->values.size(); printOutput(*iArr, count); } else { - auto &fArr = std::get>(params[outputParam]); + auto &fArr = std::get>(params[outputParam]); if (outputCount < 0) count = fArr.values.size(); printOutput(fArr, count); @@ -335,9 +335,9 @@ int main(int argc, char **argv) { // Cleanup — only free device memory for array params for (auto &p : params) { - if (auto *a = std::get_if>(&p)) + if (auto *a = std::get_if>(&p)) cuMemFree(a->devicePtr); - else if (auto *a = std::get_if>(&p)) + else if (auto *a = std::get_if>(&p)) cuMemFree(a->devicePtr); } cuModuleUnload(module);