From 1c0c3e929b71de53cd7a17e673d91a60f62f1fcc Mon Sep 17 00:00:00 2001 From: Stevan Milic Date: Mon, 27 Apr 2026 21:14:40 +0000 Subject: [PATCH] feat: fusefmt tool --- examples/fusefmt.fuse | 455 ++++++++++++ examples/lambda_calc.fuse | 15 - examples/list.fuse | 12 +- examples/option.fuse | 12 +- examples/untyped_lambda.fuse | 7 +- grin/prim_ops.c | 39 +- grin/prim_ops.h | 8 +- grin/runtime.c | 8 +- src/main/scala/Compiler.scala | 52 +- src/main/scala/Fuse.scala | 63 +- src/main/scala/code/Grin.scala | 213 +++++- src/main/scala/code/GrinPrelude.scala | 16 +- src/main/scala/code/GrinUtils.scala | 161 +++-- src/main/scala/code/MonoDriver.scala | 24 +- src/main/scala/code/MonoRewrite.scala | 186 +++-- src/main/scala/code/MonoSpecialize.scala | 196 ++++-- src/main/scala/code/Syntax.scala | 32 +- src/main/scala/core/Context.scala | 31 +- src/main/scala/core/Desugar.scala | 55 +- src/main/scala/core/DesugarError.scala | 6 + src/main/scala/core/Instantiations.scala | 86 ++- src/main/scala/core/Primops.scala | 36 + src/main/scala/core/Representation.scala | 18 +- src/main/scala/core/Shifting.scala | 20 +- src/main/scala/core/Syntax.scala | 23 +- src/main/scala/core/TypeChecker.scala | 270 +++++++- src/main/scala/parser/Expressions.scala | 8 +- src/test/scala/CompilerTests.scala | 847 +++++++++++++++++++---- stdlib/io.fuse | 17 + stdlib/list.fuse | 39 ++ stdlib/option.fuse | 37 + stdlib/string.fuse | 37 + 32 files changed, 2597 insertions(+), 432 deletions(-) create mode 100644 examples/fusefmt.fuse create mode 100644 stdlib/list.fuse create mode 100644 stdlib/option.fuse create mode 100644 stdlib/string.fuse diff --git a/examples/fusefmt.fuse b/examples/fusefmt.fuse new file mode 100644 index 0000000..4cfb79b --- /dev/null +++ b/examples/fusefmt.fuse @@ -0,0 +1,455 @@ +type Token: + TKeyword(name: str) + TIdent(name: str) + TInt(value: i32) + TStr(value: str) + TOp(text: str) + TPunct(c: str) + TComment(text: str) + TNewline + TIndent(width: i32) + TEof + +type ScanStep: + StepEmit(np: i32, ls: bool, tok: Token) + StepSkip(np: i32) + StepEnd + +fun str_eq_at(s: str, t: str, i: i32, n: i32) -> bool + match i >= n: + true => true + false => { + match char_at(s, i) == char_at(t, i): + true => str_eq_at(s, t, i + 1, n) + false => false + } + +fun str_eq(s: str, t: str) -> bool + let sl = str_len(s) + let tl = str_len(t) + match sl == tl: + true => str_eq_at(s, t, 0, sl) + false => false + +fun is_keyword(s: str) -> bool + str_eq(s, "fun") || str_eq(s, "let") || str_eq(s, "match") || str_eq(s, "type") || str_eq(s, "impl") || str_eq(s, "trait") || str_eq(s, "do") || str_eq(s, "if") || str_eq(s, "import") || str_eq(s, "for") || str_eq(s, "true") || str_eq(s, "false") || str_eq(s, "Unit") || str_eq(s, "self") + +fun two_char_op(c1: i32, c2: i32) -> bool + (c1 == 61 && c2 == 62) || (c1 == 45 && c2 == 62) || (c1 == 60 && c2 == 45) || (c1 == 61 && c2 == 61) || (c1 == 33 && c2 == 61) || (c1 == 60 && c2 == 61) || (c1 == 62 && c2 == 61) || (c1 == 38 && c2 == 38) || (c1 == 124 && c2 == 124) || (c1 == 58 && c2 == 58) + +fun is_one_char_op(c: i32) -> bool + c == 61 || c == 60 || c == 62 || c == 43 || c == 45 || c == 42 || c == 47 || c == 37 || c == 33 + +fun is_punct_char(c: i32) -> bool + c == 40 || c == 41 || c == 91 || c == 93 || c == 123 || c == 125 || c == 44 || c == 58 || c == 46 || c == 59 + +fun parse_int_chars(src: str, pos: i32, end: i32, acc: i32) -> i32 + match pos >= end: + true => acc + false => parse_int_chars(src, pos + 1, end, acc * 10 + (char_at(src, pos) - 48)) + +fun count_indent_end(src: str, pos: i32, len: i32) -> i32 + match pos >= len: + true => pos + false => { + match char_at(src, pos) == 32: + true => count_indent_end(src, pos + 1, len) + false => pos + } + +fun scan_ident_end(src: str, pos: i32, len: i32) -> i32 + match pos >= len: + true => pos + false => { + match is_ident_cont(char_at(src, pos)): + true => scan_ident_end(src, pos + 1, len) + false => pos + } + +fun scan_int_end(src: str, pos: i32, len: i32) -> i32 + match pos >= len: + true => pos + false => { + match is_digit(char_at(src, pos)): + true => scan_int_end(src, pos + 1, len) + false => pos + } + +fun scan_string_end(src: str, pos: i32, len: i32) -> i32 + match pos >= len: + true => pos + false => { + match char_at(src, pos) == 34: + true => pos + false => { + match char_at(src, pos) == 92 && pos + 1 < len: + true => { + let nx = char_at(src, pos + 1) + match nx == 92 || nx == 34: + true => scan_string_end(src, pos + 2, len) + false => scan_string_end(src, pos + 1, len) + } + false => scan_string_end(src, pos + 1, len) + } + } + +fun scan_comment_end(src: str, pos: i32, len: i32) -> i32 + match pos >= len: + true => pos + false => { + match char_at(src, pos) == 10: + true => pos + false => scan_comment_end(src, pos + 1, len) + } + +fun classify_ident(text: str) -> Token + match is_keyword(text): + true => TKeyword(text) + false => TIdent(text) + +fun peek_at(src: str, pos: i32, len: i32) -> i32 + match pos < len: + true => char_at(src, pos) + false => 0 + +fun advance_after_string(e: i32, len: i32) -> i32 + match e < len: + true => e + 1 + false => e + +fun step_indent(src: str, pos: i32, len: i32) -> ScanStep + let e = count_indent_end(src, pos, len) + let n = e - pos + StepEmit(e, false, TIndent(n)) + +fun step_comment(src: str, pos: i32, len: i32) -> ScanStep + let e = scan_comment_end(src, pos + 1, len) + let txt = substring(src, pos + 1, e) + StepEmit(e, false, TComment(txt)) + +fun step_string(src: str, pos: i32, len: i32) -> ScanStep + let e = scan_string_end(src, pos + 1, len) + let txt = substring(src, pos + 1, e) + let np = advance_after_string(e, len) + StepEmit(np, false, TStr(txt)) + +fun step_int(src: str, pos: i32, len: i32) -> ScanStep + let e = scan_int_end(src, pos, len) + let v = parse_int_chars(src, pos, e, 0) + StepEmit(e, false, TInt(v)) + +fun step_ident(src: str, pos: i32, len: i32) -> ScanStep + let e = scan_ident_end(src, pos + 1, len) + let txt = substring(src, pos, e) + StepEmit(e, false, classify_ident(txt)) + +fun step_op_or_punct(src: str, pos: i32, len: i32, c: i32) -> ScanStep + let c2 = peek_at(src, pos + 1, len) + match two_char_op(c, c2): + true => StepEmit(pos + 2, false, TOp(substring(src, pos, pos + 2))) + false => { + match is_one_char_op(c): + true => StepEmit(pos + 1, false, TOp(substring(src, pos, pos + 1))) + false => { + match is_punct_char(c): + true => StepEmit(pos + 1, false, TPunct(substring(src, pos, pos + 1))) + false => StepSkip(pos + 1) + } + } + +fun step_body(src: str, pos: i32, len: i32, c: i32) -> ScanStep + match c == 10: + true => StepEmit(pos + 1, true, TNewline) + false => { + match c == 32 || c == 9: + true => StepSkip(pos + 1) + false => { + match c == 35: + true => step_comment(src, pos, len) + false => { + match c == 34: + true => step_string(src, pos, len) + false => { + match is_digit(c): + true => step_int(src, pos, len) + false => { + match is_ident_start(c): + true => step_ident(src, pos, len) + false => step_op_or_punct(src, pos, len, c) + } + } + } + } + } + +fun next_step(src: str, pos: i32, len: i32, line_start: bool) -> ScanStep + match pos >= len: + true => StepEnd + false => { + match line_start: + true => step_indent(src, pos, len) + false => step_body(src, pos, len, char_at(src, pos)) + } + +fun scan(src: str, pos: i32, len: i32, line_start: bool, acc: List[Token]) -> List[Token] + match next_step(src, pos, len, line_start): + StepEnd => Cons(TEof, acc).reverse() + StepSkip(np) => scan(src, np, len, false, acc) + StepEmit(np, ls, tok) => scan(src, np, len, ls, Cons(tok, acc)) + +fun lex(src: str) -> List[Token] + scan(src, 0, str_len(src), true, Nil[Token]) + +type Node: + NTok(t: Token) + NGroup(items: List[Node]) + NLine(indent: i32, items: List[Node]) + NBlank + +type ParseRes: + PR(rem: List[Token], items: List[Node]) + +fun open_to_close(c: str) -> str + match str_eq(c, "("): + true => ")" + false => { + match str_eq(c, "["): + true => "]" + false => "}" + } + +impl Token: + fun is_open_punct(self) -> bool + match self: + TPunct(c) => str_eq(c, "(") || str_eq(c, "[") || str_eq(c, "{") + _ => false + + fun is_close_punct(self) -> bool + match self: + TPunct(c) => str_eq(c, ")") || str_eq(c, "]") || str_eq(c, "}") + _ => false + + fun punct_text(self) -> str + match self: + TPunct(c) => c + _ => "" + + fun matching_close(self) -> str + match self: + TPunct(c) => open_to_close(c) + _ => "" + + fun matches_close(self, expected: str) -> bool + match str_len(expected) > 0: + true => self.is_close_punct() && str_eq(self.punct_text(), expected) + false => false + + fun is_no_lead_punct(self) -> bool + match self: + TPunct(c) => str_eq(c, ",") || str_eq(c, ";") || str_eq(c, ":") || str_eq(c, ".") || str_eq(c, ")") || str_eq(c, "]") || str_eq(c, "}") + TOp(s) => str_eq(s, "::") + _ => false + + fun is_no_trail_punct(self) -> bool + match self: + TPunct(c) => str_eq(c, "(") || str_eq(c, "[") || str_eq(c, "{") || str_eq(c, ".") + TOp(s) => str_eq(s, "::") + _ => false + + fun is_whitespace(self) -> bool + match self: + TNewline => true + TIndent(_) => true + _ => false + + fun is_ident(self) -> bool + match self: + TIdent(_) => true + _ => false + + fun is_call_open(self) -> bool + match self: + TPunct(c) => str_eq(c, "(") || str_eq(c, "[") + _ => false + + fun is_call_target_close(self) -> bool + match self: + TPunct(c) => str_eq(c, ")") || str_eq(c, "]") + _ => false + +fun group_seq(tokens: List[Token], close: str, acc: List[Node]) -> ParseRes + match tokens: + Nil => PR(Nil[Token], acc.reverse()) + Cons(h, rest) => { + match h.matches_close(close): + true => PR(rest, Cons(NTok(h), acc).reverse()) + false => { + match h.is_open_punct(): + true => { + let mc = h.matching_close() + match group_seq(rest, mc, Cons(NTok(h), Nil[Node])): + PR(rem, inner) => group_seq(rem, close, Cons(NGroup(inner), acc)) + } + false => group_seq(rest, close, Cons(NTok(h), acc)) + } + } + +fun group_all(tokens: List[Token]) -> List[Node] + match group_seq(tokens, "", Nil): + PR(_, items) => items + +fun close_line(indent: i32, items: List[Node]) -> Node + match items: + Nil => NBlank + Cons(_, _) => NLine(indent, items.reverse()) + +fun slice_lines_at(nodes: List[Node], ci: i32, ci_items: List[Node], acc: List[Node]) -> List[Node] + match nodes: + Nil => { + match ci_items: + Nil => acc.reverse() + Cons(_, _) => Cons(close_line(ci, ci_items), acc).reverse() + } + Cons(h, rest) => { + match h: + NTok(t) => { + match t: + TNewline => slice_lines_at(rest, 0, Nil, Cons(close_line(ci, ci_items), acc)) + TIndent(n) => slice_lines_at(rest, n, ci_items, acc) + TEof => slice_lines_at(rest, ci, ci_items, acc) + _ => slice_lines_at(rest, ci, Cons(h, ci_items), acc) + } + _ => slice_lines_at(rest, ci, Cons(h, ci_items), acc) + } + +fun slice_lines(nodes: List[Node]) -> List[Node] + slice_lines_at(nodes, 0, Nil, Nil) + +fun parse(tokens: List[Token]) -> List[Node] + slice_lines(group_all(tokens)) + +fun round_indent(n: i32) -> i32 + let level = n / 4 + let extra = n % 4 + match extra >= 2: + true => (level + 1) * 4 + false => level * 4 + +fun spaces(n: i32) -> str + match n <= 0: + true => "" + false => " " + spaces(n - 1) + +fun should_space(prev: Token, curr: Token) -> bool + match curr.is_whitespace(): + true => false + false => { + match prev.is_whitespace(): + true => false + false => { + match curr.is_no_lead_punct(): + true => false + false => { + match prev.is_no_trail_punct(): + true => false + false => { + match (prev.is_ident() || prev.is_call_target_close()) && curr.is_call_open(): + true => false + false => true + } + } + } + } + +type EmitState: + ES(text: str, prev: Option[Token]) + +fun space_before(prev: Option[Token], curr: Token) -> str + match prev: + None => "" + Some(p) => { + match should_space(p, curr): + true => " " + false => "" + } + +fun emit_token(t: Token) -> str + match t: + TKeyword(n) => n + TIdent(n) => n + TInt(v) => int_to_str(v) + TStr(s) => "\"" + s + "\"" + TOp(s) => s + TPunct(c) => c + TComment(s) => "#" + s + TNewline => "\n" + TIndent(n) => spaces(round_indent(n)) + TEof => "" + +fun emit_inline(items: List[Node], prev: Option[Token]) -> EmitState + match items: + Nil => ES("", prev) + Cons(h, t) => { + match h: + NTok(tk) => { + let sep = space_before(prev, tk) + let txt = emit_token(tk) + match emit_inline(t, Some(tk)): + ES(rest, fp) => ES(sep + txt + rest, fp) + } + NGroup(inner) => { + let g: List[Node] = inner + match emit_inline(g, prev): + ES(itext, after) => { + let p: Option[Token] = after + match emit_inline(t, p): + ES(rest, fp) => ES(itext + rest, fp) + } + } + _ => emit_inline(t, prev) + } + +fun emit_line_str(items: List[Node]) -> str + match emit_inline(items, None): + ES(text, _) => text + +fun emit_doc(lines: List[Node], blank_run: i32, acc: str) -> str + match lines: + Nil => acc + Cons(h, t) => { + match h: + NLine(ind, items) => emit_doc(t, 0, acc + spaces(round_indent(ind)) + emit_line_str(items) + "\n") + NBlank => { + match blank_run >= 2: + true => emit_doc(t, blank_run + 1, acc) + false => emit_doc(t, blank_run + 1, acc + "\n") + } + NTok(_) => emit_doc(t, blank_run, acc) + NGroup(_) => emit_doc(t, blank_run, acc) + } + +fun format(src: str) -> str + emit_doc(parse(lex(src)), 0, "") + +fun run_format(path: str) -> IO[i32] + do: + src <- read(path) + _ <- print(format(src)) + 0 + +fun run_format_stdin() -> IO[i32] + do: + src <- read_stdin() + _ <- print(format(src)) + 0 + +fun choose_action(args: List[str]) -> IO[i32] + match args.length() >= 1: + true => run_format(args.head_or("")) + false => run_format_stdin() + +fun main() -> IO[i32] + do: + args <- get_args() + r <- choose_action(args) + r diff --git a/examples/lambda_calc.fuse b/examples/lambda_calc.fuse index 20f440f..8c33539 100644 --- a/examples/lambda_calc.fuse +++ b/examples/lambda_calc.fuse @@ -2,21 +2,6 @@ type List[T]: Cons(h: T, t: List[T]) Nil -type Option[T]: - None - Some(T) - -impl Option[T]: - fun is_some(self) -> bool - match self: - Some(v) => true - _ => false - - fun is_none(self) -> bool - match self: - Some(v) => false - _ => true - type Tuple[A, B](A, B) type Type: diff --git a/examples/list.fuse b/examples/list.fuse index 1d83f98..a5375fc 100644 --- a/examples/list.fuse +++ b/examples/list.fuse @@ -1,7 +1,3 @@ -type List[A]: - Cons(h: A, t: List[A]) - Nil - trait Functor[A]: fun map[B](self, f: A -> B) -> Self[B]; @@ -39,7 +35,7 @@ impl Functor[A] for List[A]: fun fmap[A, B, T: Functor](f: A -> B, c: T[A]) -> T[B] c.map(f) -fun main() -> i32 +fun main() -> IO[Unit] let l = Cons(2, Cons(3, Nil)) let l1 = fmap(v => v + 1, l) let l2 = Cons(7, Nil) @@ -47,7 +43,5 @@ fun main() -> i32 let l4 = l3.filter(e => e > 3) let s = List::sum(l4) let p = List::product(l4) - let io = print(int_to_str(s + p)) - io.exec() - 0 - + print(int_to_str(s + p)) + diff --git a/examples/option.fuse b/examples/option.fuse index c51571d..4669bba 100644 --- a/examples/option.fuse +++ b/examples/option.fuse @@ -1,19 +1,9 @@ -type Option[A]: - None - Some(A) - -impl Option[A]: - fun map[B](self, f: A -> B) -> Option[B] - match self: - Some(v) => Some(f(v)) - _ => None - fun main() -> i32 let o = Some(5) let o1 = o.map(a => a + 1) match o1: Some(v) => { - print(int_to_str(v)) + print(int_to_str(v)).exec() 0 } None => 1 diff --git a/examples/untyped_lambda.fuse b/examples/untyped_lambda.fuse index c47d2f6..741f344 100644 --- a/examples/untyped_lambda.fuse +++ b/examples/untyped_lambda.fuse @@ -1,7 +1,3 @@ -type Option[T]: - None - Some(T) - type Term: Var(index: i32, ctxlen: i32) Abs(hint: str, body: Term) @@ -71,8 +67,7 @@ fun term_to_str(t: Term) -> str App(t1, t2) => "(" + term_to_str(t1) + " " + term_to_str(t2) + ")" fun println(s: str) -> Unit - print(s + "\n") - () + print(s + "\n").exec() fun main() -> i32 let id = Abs("x", Var(0, 1)) diff --git a/grin/prim_ops.c b/grin/prim_ops.c index 21b6be5..fb0f1e0 100644 --- a/grin/prim_ops.c +++ b/grin/prim_ops.c @@ -75,7 +75,7 @@ int64_t _prim_int_print(int64_t p1) { return 0; } -struct string* _prim_read_string() { +struct string* _prim_read_string(int64_t unit) { char *buffer = NULL; size_t len = 0; size_t read; @@ -305,3 +305,40 @@ int64_t _prim_bool_or(int64_t p1, int64_t p2) { int64_t _prim_string_ne(struct string* p1, struct string* p2) { return !_prim_string_eq(p1, p2); } + +extern int g_argc; +extern char** g_argv; + +int64_t _prim_args_count(int64_t unit) { + (void)unit; + return (int64_t)(g_argc > 0 ? g_argc - 1 : 0); +} + +struct string* _prim_args_get(int64_t idx) { + int64_t total = (int64_t)(g_argc > 0 ? g_argc - 1 : 0); + if (idx < 0 || idx >= total) { + struct string* msg = create_string_copy("args index out of bounds"); + _prim_error(msg); + return create_string_len(0); + } + return create_string_copy(g_argv[idx + 1]); +} + +int64_t _prim_string_char_at(struct string* s, int64_t idx) { + if (idx < 0 || idx >= s->length) { + struct string* msg = create_string_copy("string index out of bounds"); + _prim_error(msg); + return 0; + } + return (int64_t)(unsigned char)s->data[idx]; +} + +struct string* _prim_string_substring(struct string* s, int64_t start, int64_t end) { + if (start < 0) start = 0; + if (end > s->length) end = s->length; + if (start >= end) return create_string_len(0); + int64_t len = end - start; + struct string* r = create_string_len(len); + memcpy(r->data, s->data + start, len); + return r; +} diff --git a/grin/prim_ops.h b/grin/prim_ops.h index 2998ac3..3016c55 100644 --- a/grin/prim_ops.h +++ b/grin/prim_ops.h @@ -17,7 +17,7 @@ void cstring(char* buffer, struct string* s); int64_t _prim_string_print(struct string* p1); int64_t _prim_int_print(int64_t p1); -struct string* _prim_read_string(); +struct string* _prim_read_string(int64_t unit); int64_t _prim_usleep(int64_t p1); int64_t _prim_error(struct string* p1); int64_t _prim_ffi_file_eof(int64_t p1); @@ -45,3 +45,9 @@ float _prim_float_mod(float p1, float p2); int64_t _prim_bool_and(int64_t p1, int64_t p2); int64_t _prim_bool_or(int64_t p1, int64_t p2); int64_t _prim_string_ne(struct string* p1, struct string* p2); + +int64_t _prim_args_count(int64_t unit); +struct string* _prim_args_get(int64_t idx); + +int64_t _prim_string_char_at(struct string* s, int64_t idx); +struct string* _prim_string_substring(struct string* s, int64_t start, int64_t end); diff --git a/grin/runtime.c b/grin/runtime.c index 4859a2d..f2c5960 100644 --- a/grin/runtime.c +++ b/grin/runtime.c @@ -8,13 +8,19 @@ extern int64_t _heap_ptr_; #endif +int g_argc = 0; +char** g_argv = NULL; + int64_t grinMain(); void __runtime_error(int64_t c){ exit(c); } -int main() { +int main(int argc, char** argv) { + g_argc = argc; + g_argv = argv; + #ifdef USE_BOEHM_GC GC_INIT(); #else diff --git a/src/main/scala/Compiler.scala b/src/main/scala/Compiler.scala index b5ce6a2..6c4b1dd 100644 --- a/src/main/scala/Compiler.scala +++ b/src/main/scala/Compiler.scala @@ -10,6 +10,7 @@ import org.parboiled2.* import parser.FuseParser import parser.FuseParser.* import parser.ParserErrorFormatter +import parser.Types.* import java.io.* import java.nio.file.{Files, Path, Paths} @@ -66,11 +67,11 @@ object Compiler { */ def readStdlibFiles(dir: Path): IO[List[(Path, String)]] = IO.blocking { import scala.jdk.CollectionConverters.* - val stream = Files.newDirectoryStream(dir, "*.fuse") + val stream = Files.newDirectoryStream(dir) try - stream.asScala.toList.map(p => - p -> new String(Files.readAllBytes(p)).trim - ) + stream.asScala.toList + .filter(p => p.getFileName.toString.endsWith(".fuse")) + .map(p => p -> new String(Files.readAllBytes(p)).trim) finally stream.close() } @@ -107,13 +108,48 @@ object Compiler { case FTupleTypeDecl(_, i, _, _) => i.value case FTypeAlias(_, i, _, _) => i.value }.toSet - val requires = decls.collect { - case FTraitInstance(_, traitId, _, _, _, _) => - traitId.value - }.toSet + val requires = decls.flatMap(declTypeRefs).toSet -- provides StdlibFileDeps(provides, requires) } + /** Collect every type-name referenced by a declaration's signature, body type + * annotations, and trait/instance heads. Used by `extractStdlibFileDeps` to + * drive topological ordering: a file referencing `List` in any signature + * must load after the file that provides `List`. + */ + def declTypeRefs(decl: FDecl): Set[String] = decl match { + case FTraitInstance(_, traitId, _, typeId, _, methods) => + Set(traitId.value, typeId.value) ++ methods.flatMap(declTypeRefs) + case FFuncDecl(sig, _) => funcSigTypeRefs(sig) + case FMethodDecl(sig, _) => methodSigTypeRefs(sig) + case FTypeFuncDecls(_, _, _, methods) => methods.flatMap(declTypeRefs).toSet + case FTraitDecl(_, _, _, members) => + members.flatMap { + case Left(FMethodDecl(sig, _)) => methodSigTypeRefs(sig) + case Right(sig) => methodSigTypeRefs(sig) + }.toSet + case _ => Set.empty[String] + } + + def funcSigTypeRefs(sig: FFuncSig): Set[String] = { + val paramTypes = sig.p.toList.flatten.map(_.t) + (paramTypes :+ sig.r).flatMap(typeRefs).toSet + } + + def methodSigTypeRefs(sig: FMethodSig): Set[String] = { + val paramTypes = + sig.p.toList.flatMap(_.params.toList.flatten).map(_.t) + (paramTypes :+ sig.r).flatMap(typeRefs).toSet + } + + def typeRefs(t: FType): Set[String] = t match { + case FSimpleType(_, id, args) => + Set(id.value) ++ args.toList.flatten.flatMap(typeRefs) + case FTupleType(_, ts) => ts.flatMap(typeRefs).toSet + case FFuncType(_, ins, out) => (ins :+ out).flatMap(typeRefs).toSet + case FUnitType(_) => Set.empty + } + /** Kahn-style topological sort. Each pass emits every file whose `requires` * is satisfied by files already emitted; if no file is ready and the * frontier is non-empty, the graph has a cycle (error). A `requires` entry diff --git a/src/main/scala/Fuse.scala b/src/main/scala/Fuse.scala index 723ea73..1a31112 100644 --- a/src/main/scala/Fuse.scala +++ b/src/main/scala/Fuse.scala @@ -24,6 +24,11 @@ case class BuildFile(file: String, includeStdlib: Boolean = true) extends Command case class CheckFile(file: String, includeStdlib: Boolean = true) extends Command +case class RunFile( + file: String, + args: List[String] = Nil, + includeStdlib: Boolean = true +) extends Command /** Build pipeline errors. */ sealed trait BuildError @@ -72,12 +77,22 @@ object Fuse fileOpts.map(f => CheckFile(f)) } - val compilerCommand: Opts[Command] = buildCommand `orElse` checkCommand + val runCommand: Opts[RunFile] = + Opts.subcommand("run", "Compile and run a fuse source code file.") { + ( + fileOpts, + Opts.arguments[String](metavar = "args").orEmpty.map(_.toList) + ).mapN((f, as) => RunFile(f, as)) + } + + val compilerCommand: Opts[Command] = + buildCommand `orElse` checkCommand `orElse` runCommand override def main: Opts[IO[ExitCode]] = compilerCommand.map { case c: BuildFile => build(c) case c: CheckFile => check(c) + case c: RunFile => run(c) } /** Build pipeline using EitherT for short-circuit error handling. */ @@ -102,6 +117,52 @@ object Fuse } } + /** Compile and run a Fuse source file, forwarding stdio and exit code. */ + def run(command: RunFile): IO[ExitCode] = + runFile(command.file, command.args, command.includeStdlib, executeInherited) + .flatMap { + case Right(code) => IO.pure(code) + case Left(err) => + IO.println(formatBuildError(err)).as(ExitCode.Error) + } + + /** Shared compile-then-execute pipeline. Builds the Fuse source to a native + * binary, runs the binary via the supplied executor, and removes the `.grin` + * and `.out` intermediates regardless of executor outcome. + */ + def runFile[A]( + file: String, + args: List[String], + includeStdlib: Boolean, + executor: (Path, List[String]) => IO[A] + ): IO[Either[BuildError, A]] = { + val paths = BuildPaths.fromSource(file) + val pipeline: EitherT[IO, BuildError, A] = for { + _ <- compileFuseToGrin(BuildFile(file, includeStdlib), paths) + _ <- compileGrinWithGC(paths) + _ <- EitherT.right[BuildError]( + cleanupIntermediateFiles(List(paths.grin)) + ) + result <- EitherT.right[BuildError]( + executor(paths.output, args) + .guarantee(cleanupIntermediateFiles(List(paths.output))) + ) + } yield result + pipeline.value + } + + /** Executor for `run` that forwards stdin/stdout/stderr to the child and + * returns the child's exit code as the CLI exit code. + */ + def executeInherited(exe: Path, args: List[String]): IO[ExitCode] = + IO.blocking { + val code = new ProcessBuilder((exe.toString +: args)*) + .inheritIO() + .start() + .waitFor() + ExitCode(code) + } + /** Format build error for display. */ def formatBuildError(error: BuildError): String = error match { case FuseCompileError(e) => e.toString diff --git a/src/main/scala/code/Grin.scala b/src/main/scala/code/Grin.scala index e1a8d67..24c6b38 100644 --- a/src/main/scala/code/Grin.scala +++ b/src/main/scala/code/Grin.scala @@ -46,6 +46,33 @@ object Grin { case _ => false } + /** Loophole N — Single Canonical Type Lookup (SCTL). + * + * Single source of truth for an unannotated closure's full arrow type at + * GRIN-gen time. Used by path 1 (`toClosureAbs` None arm) and as a + * last-resort by path 2 (`toClosureValue` cascade). Paths 3 and 4 inherit + * transitively from paths 1 and 2. + * + * Reads from `closureTypesFromBind` (Mono-populated for generic binds) + * first, then `closureTypesGlobal` (cross-bind aggregate with + * name-uniqueness drop). Filters the result to clean arrows — no unsolved + * EVars, no unresolved TypeVars (referent-encoded TypeVars with `Some(name)` + * count as resolved per Loophole K Option 1). + * + * Deliberately skips `closureTypesFallback`: the per-bind map keeps + * colliding names (last-write-wins) for path 2's type-class continuation + * dispatch, but that semantics is unsuitable here — SCTL must return either + * a precise type or None, never an arbitrary collision winner. + */ + def canonicalClosureType(name: String)(implicit env: Env): Option[Type] = + env.closureTypesFromBind + .get(name) + .orElse(env.closureTypesGlobal.get(name)) + .filter(ty => + !TypeChecker.containsUnsolvedEVar(ty) + && !TypeChecker.hasUnresolvedTypeVar(ty) + ) + def isClosureTerm(t: Term): Boolean = t match { case _: TermClosure => true case TermFix(_, c: TermClosure) => true @@ -81,13 +108,13 @@ object Grin { (lambdaBindings, partialFunctions) = values.flatten.unzip applyFunction <- buildApply(partialFunctions.flatten) } yield { - val grinCode = (lambdaBindings.flatten.map(_.show) :+ applyFunction) + val raw = (lambdaBindings.flatten.map(_.show) :+ applyFunction) .mkString("\n\n") .replaceAll( "([a-zA-Z0-9])#", "$1'" ) // Replace # with ' in type specializations only - .replaceAll("[\\[\\]]", "") // Remove brackets from TypeApp names + val grinCode = stripBracketsOutsideStrings(raw) val ffi = generateMissingFFI(grinCode) ffi.isEmpty match { case true => grinCode @@ -97,6 +124,25 @@ object Grin { s.runEmptyA.value } + // Strips `[` and `]` from TypeApp identifier names without touching the + // bytes inside GRIN string literals of the form `#"..."`. The unguarded + // global `replaceAll("[\\[\\]]", "")` previously deleted brackets from + // user string content too. The lit regex tolerates `\"` escapes so it + // stays correct for source strings containing escaped quotes. + def stripBracketsOutsideStrings(code: String): String = { + val grinStringLit = """#"(?:[^"\\]|\\.)*"""".r + val (acc, last) = grinStringLit + .findAllMatchIn(code) + .foldLeft(("", 0)) { case ((s, prev), m) => + ( + s + code.substring(prev, m.start).replaceAll("[\\[\\]]", "") + + m.matched, + m.end + ) + } + acc + code.substring(last).replaceAll("[\\[\\]]", "") + } + /** Build Env by collecting partial functions from bindings and constructing * all closure-related maps. * @@ -110,9 +156,36 @@ object Grin { def buildEnv(bindings: List[Bind]): Env = { val typeInstancesMap = buildTypeInstancesMap(bindings) val typeInstanceMethodsMap = buildTypeInstanceMethodsMap(bindings) + // Aggregate concrete-arrow closure types from all binds' insts. Used + // as a last-resort fallback in `toClosureValue` when both `pureInfer` + // and `getClosureTypeWithFallback` fail to type a closure at GRIN-gen + // — typically when the closure's body references identifiers that + // don't resolve in the post-Mono Context. Each bind contributes its + // own closure-Resolution insts, keyed by closure parameter name. Names + // that occur multiple times across binds are excluded — caller cannot + // disambiguate by name alone. + val closureTypesGlobal: Map[String, Type] = { + val candidates = bindings.flatMap(bind => + bind.insts + .filter(_.r == core.Instantiations.Resolution.Closure) + .flatMap(cInst => + cInst.tys.headOption.collect { + case ty: TypeArrow + if !TypeChecker.containsUnsolvedEVar(ty) + && !TypeChecker.hasUnresolvedTypeVar(ty) => + cInst.i -> ty + } + ) + ) + val nameCounts = candidates.groupBy(_._1).view.mapValues(_.size).toMap + candidates.collect { + case (name, ty) if nameCounts.getOrElse(name, 0) == 1 => name -> ty + }.toMap + } val primingEnv = Env( typeInstances = typeInstancesMap, - typeInstanceMethods = typeInstanceMethodsMap + typeInstanceMethods = typeInstanceMethodsMap, + closureTypesGlobal = closureTypesGlobal ) val primingFunctions = collectPartialFunctions(bindings, primingEnv) val primingEnvWithFunctions = @@ -141,10 +214,10 @@ object Grin { env: Env, partialFunctions: List[PartialFunValue] ): Env = { - val closureMap: Map[String, List[String]] = partialFunctions + val closureMap: Map[ClosureSig, List[String]] = partialFunctions .filter(_.typeKey.nonEmpty) - .groupBy(_.typeKey) - .map { case (typeKey, pfs) => (typeKey, pfs.map(_.f).distinct) } + .groupBy(pf => parseTypeKey(pf.typeKey)) + .map { case (sig, pfs) => (sig, pfs.map(_.f).distinct) } // Per-closure facts store raw parameter and return types. Merging of // unresolved parameter types into concrete sibling groups happens in // buildGroupedApply (via duplication) and in applyFnForClosure (at @@ -160,7 +233,8 @@ object Grin { }.toMap env.copy( closureMap = closureMap, - arityFactsMap = arityFactsMap + arityFactsMap = arityFactsMap, + closureTypesGlobal = env.closureTypesGlobal ) } @@ -343,9 +417,45 @@ object Grin { case true => pickConcreteParam(remaining, returnType).getOrElse(raw) } case None => - siblingParam(remaining, returnType).getOrElse { - val p = extractParamTypeAt(typeKeyFallback, 0) - p.isEmpty match { case true => "unknown"; case false => p } + // Loophole N circumvent: when typeKeyFallback supplies a + // concrete first-parameter type AND a closure is actually + // registered with that paramType in the same + // (arity, returnType) bucket, prefer it over the + // non-deterministic `siblingParam` heuristic. + // `siblingParam` iterates `env.arityFactsMap.values` + // unsorted; its `.headOption` pick can flip when unrelated + // closures register/unregister, poisoning apply keys for + // call sites whose closure type is already known. + // The `concreteRegistered` check guarantees the picked + // paramType corresponds to an apply table that + // `buildGroupedApply` actually generated — without this + // guard, the fix could emit a dangling reference to a + // table that wouldn't exist (e.g. + // `apply1_unit_to_unit` when no concrete-unit-param + // closure for return=unit exists; only the unresolved + // sibling registered under `apply1_TypeEVar_to_unit`). + val raw = extractParamTypeAt(typeKeyFallback, 0) + val concreteRegistered = + (raw.nonEmpty && !isUnresolvedParamType(raw)) match { + case false => false + case true => + env.arityFactsMap.values.exists(af => + af.totalParams >= remaining + && af.returnTypeKey == returnType + && af.paramTypeKeys + .lift(af.totalParams - remaining) + .contains(raw) + ) + } + concreteRegistered match { + case true => raw + case false => + siblingParam(remaining, returnType).getOrElse { + raw.isEmpty match { + case true => "unknown" + case false => raw + } + } } } applyFnName(remaining, paramType, returnType) @@ -394,7 +504,7 @@ object Grin { * was created. */ def resolveTypeConstructors(ty: Type): ContextState[Type] = ty match { - case TypeApp(info, tv @ TypeVar(tvInfo, idx, n), arg) => + case TypeApp(info, tv @ TypeVar(tvInfo, idx, n, _), arg) => for { currentCtxLen <- State.inspect { (ctx: Context) => ctx._1.length } adjustedIdx = currentCtxLen - n + idx @@ -506,14 +616,17 @@ object Grin { ): ContextState[Option[(List[LambdaBinding], List[PartialFunValue])]] = { // Pass closure types from the Bind to GRIN generation. // Also pass resolved closure types from insts as fallback for unannotated closures. + // Use the symmetric guard already used by `closureTypesGlobal` — + // `containsUnsolvedEVar` + `hasUnresolvedTypeVar`. The previous + // `getTypeContextLength` filter dropped legitimate types whose + // constructors are referent-encoded TypeVars (Loophole K Option 1). val closureTypesFallback = binding.insts .filter(_.r == core.Instantiations.Resolution.Closure) .flatMap(cInst => cInst.tys.headOption.collect { case ty: TypeArrow - if !MonoSpecialize - .getTypeContextLength(ty) - .isDefined => + if !TypeChecker.containsUnsolvedEVar(ty) + && !TypeChecker.hasUnresolvedTypeVar(ty) => cInst.i -> ty } ) @@ -1162,7 +1275,7 @@ object Grin { } yield expr } - // Helper: Extract closure name from P-tag (e.g., "P2c18" -> "c18") + // Helper: Extract closure name from P-tag (e.g., "P2foo" -> "foo") def closureFromTag(tag: String): String = tag.dropWhile(c => c == 'P' || c.isDigit) @@ -1200,13 +1313,13 @@ object Grin { arity: Int, tKey: String = "" )(implicit env: Env): AppExpr = { - // Extract closure name from tag (e.g., "P2c18" -> "c18") + // Extract closure name from tag (e.g., "P2foo" -> "foo") val closureName = closureFromTag(tag) // If tKey is a closure name (no arrows), look up its typeKey from closureMap val closureToTypeKey: Map[String, String] = env.closureMap.flatMap { - case (typeKey, closureNames) => - closureNames.map(cn => cn -> typeKey) + case (sig, closureNames) => + closureNames.map(cn => cn -> sig.toTypeKey) } // Use closure name to look up typeKey, which contains the accurate type info @@ -1235,8 +1348,8 @@ object Grin { // For each closure, look up its actual return type from closureMap (typeKey -> closures) // by finding which typeKey contains this closure, then extracting the return type val closureToTypeKey: Map[String, String] = env.closureMap.flatMap { - case (typeKey, closureNames) => - closureNames.map(cn => cn -> typeKey) + case (sig, closureNames) => + closureNames.map(cn => cn -> sig.toTypeKey) } // For each closure, dispatch to the appropriate apply function @@ -1286,7 +1399,10 @@ object Grin { case (name, fact) if fact.totalParams == arity => name }.toList - // Helper: Lookup closures with fallback (no specialized variant expansion) + // Helper: Lookup closures with fallback. closureMap is keyed by + // structural `ClosureSig` (so `parseTypeKey(tKey)` becomes the lookup + // key); query semantics are exact match first, all-arity fallback when + // missing. def lookupClosures(tKey: String, arity: Int)(implicit env: Env ): List[String] = @@ -1294,7 +1410,10 @@ object Grin { tKey.contains(",") match { case true => tKey.split(",").toList case false => - env.closureMap.getOrElse(tKey, getClosuresByArity(arity)) + env.closureMap.getOrElse( + parseTypeKey(tKey), + getClosuresByArity(arity) + ) } // Helper: Lookup single closure (for direct P-tag inference) @@ -1374,9 +1493,21 @@ object Grin { // (avoids De Bruijn index issues when closure has stale indices from type-checking phase) arity = getClosureArity(c) // Compute typeKey from closure type - try multiple sources: - // 1. First try closureTypesFromBind (from monomorphization) - has accurate full arrow types - // 2. Then try type-checking the closure - // 3. Finally fall back to structure extraction + // 1. First try closureTypesFromBind. Two seeding paths populate this: + // (a) `MonoSpecialize.finalizeSpecializedBind` for specialized + // generic binds (legacy seed; carries spec-substituted arrow). + // (b) `TypeChecker.closureTypesFromInsts` for monomorphic binds + // whose closure is identity-shape (`λx. x`). Identity closures + // are exactly the case where pureInfer provably leaks unsolved + // EVars; non-identity closures are deliberately excluded so the + // existing pureInfer path (step 2) keeps handling them with its + // current behavior. + // 2. Then `pureInfer`. Returns EVar-laden arrows for closures the + // initial bidirectional check left under-constrained. Pre-existing + // dispatch tolerates the `"TypeEVar->T"` shape for `λ_. + // concrete_body` style closures whose unused parameter stays + // unsolved, so this path is kept as the second-class fallback. + // 3. Finally structure extraction, then `closureTypesFallback`. closureTypeFromBind = c match { case TermClosure(_, variable, _, _) => env.closureTypesFromBind.get(variable) @@ -1393,10 +1524,20 @@ object Grin { getClosureTypeWithFallback(c).map { case Some(ty) => Some(ty) case None => - // Last resort: check closure types from insts + // Last-resort lookup: per-bind closureTypesFallback (built + // from current binding's own insts) and the global aggregate + // closureTypesGlobal (built across all binds). The global + // map covers the case where the closure's body references + // identifiers that don't resolve in the post-Mono Context — + // pureInfer errors out and the per-bind fallbacks miss + // because the closure has been lambda-lifted out of its + // parent's bind, but the parent bind's insts still carry + // the bidirectionally-checked arrow. c match { case TermClosure(_, variable, _, _) => - env.closureTypesFallback.get(variable) + env.closureTypesFallback + .get(variable) + .orElse(env.closureTypesGlobal.get(variable)) case _ => None } } @@ -1443,11 +1584,19 @@ object Grin { b <- toClosureAbs(body) } yield Abs(toParamVariable(variable2), b) case TermClosure(_, variable, None, body) => - // TODO: Remove this simplification of variable type for closure, - // once all term closures are populated with var types during - // monomorphization phase. This logic is primarly used to test - // building inline lambdas without type annotations. - val variableType = TypeUnit(UnknownInfo) + // Route the unannotated parameter binding through SCTL. + // `canonicalClosureType(variable)` returns `Some(arrow)` when one + // of the trustworthy maps (closureTypesFromBind, + // closureTypesGlobal) has a clean arrow type for this closure. + // Extract the parameter type from the arrow's left side; fall + // back to TypeUnit when SCTL misses (preserves baseline for + // unrelated closures). Safe because `applyFnForClosure` is + // deterministic (typeKeyFallback takes priority over + // `siblingParam` when concrete and registered). + val variableType: Type = canonicalClosureType(variable) match { + case Some(TypeArrow(_, paramTy, _)) => paramTy + case _ => TypeUnit(UnknownInfo) + } for { variable1 <- includeFunctionSuffix(variable, variableType) variable2 <- Context.addBinding(variable1, VarBind(variableType)) diff --git a/src/main/scala/code/GrinPrelude.scala b/src/main/scala/code/GrinPrelude.scala index 896239f..88b6d23 100644 --- a/src/main/scala/code/GrinPrelude.scala +++ b/src/main/scala/code/GrinPrelude.scala @@ -6,14 +6,16 @@ object GrinPrelude { ("_prim_int_print", "_prim_int_print :: T_Int64 -> T_Int64"), ("_prim_usleep", "_prim_usleep :: T_Int64 -> T_Int64"), ("_prim_string_print", "_prim_string_print :: T_String -> T_Int64"), - ("_prim_read_string", "_prim_read_string :: T_String"), + ("_prim_read_string", "_prim_read_string :: T_Int64 -> T_String"), ("_prim_error", "_prim_error :: T_String -> T_Int64"), ("_prim_ffi_file_eof", "_prim_ffi_file_eof :: T_Int64 -> T_Int64"), ("_prim_file_read", "_prim_file_read :: T_String -> T_String"), ( "_prim_file_write", "_prim_file_write :: T_String -> T_String -> T_Int64" - ) + ), + ("_prim_args_count", "_prim_args_count :: T_Int64 -> T_Int64"), + ("_prim_args_get", "_prim_args_get :: T_Int64 -> T_String") ) val ffiPure: List[(String, String)] = List( @@ -45,7 +47,15 @@ object GrinPrelude { ("_prim_str_int", "_prim_str_int :: T_String -> T_Int64"), ("_prim_int_float", "_prim_int_float :: T_Int64 -> T_Float"), ("_prim_float_string", "_prim_float_string :: T_Float -> T_String"), - ("_prim_char_int", "_prim_char_int :: T_Char -> T_Int64") + ("_prim_char_int", "_prim_char_int :: T_Char -> T_Int64"), + ( + "_prim_string_char_at", + "_prim_string_char_at :: T_String -> T_Int64 -> T_Int64" + ), + ( + "_prim_string_substring", + "_prim_string_substring :: T_String -> T_Int64 -> T_Int64 -> T_String" + ) ) val primopPure: List[(String, String)] = List( diff --git a/src/main/scala/code/GrinUtils.scala b/src/main/scala/code/GrinUtils.scala index a55f13e..de93849 100644 --- a/src/main/scala/code/GrinUtils.scala +++ b/src/main/scala/code/GrinUtils.scala @@ -106,11 +106,14 @@ object GrinUtils { val constructorKey = typeToKey(constructor) val argKey = typeToKey(arg) s"$constructorKey[$argKey]" - case TypeVar(_, idx, _) => - // TypeVars are problematic - just use a generic placeholder based on index - // This will likely cause mismatches, but better than crashing - s"T$idx" - case TypeId(_, name) => + case TypeVar(_, idx, _, Some(name)) => + // Loophole K Option 1: prefer the referent name (a stable + // type-constructor or generic-param identifier) over the + // De-Bruijn idx — idx drifts across mono epochs and can stamp + // unrelated names into apply-fn group keys. + name + case TypeVar(_, idx, _, _) => s"T$idx" + case TypeId(_, name) => name case TypeInt(_) => "i32" case TypeFloat(_) => "f32" @@ -198,6 +201,25 @@ object GrinUtils { } ._2 + /** Parse a typeKey string into a structural `ClosureSig`. Splits on top-level + * `->` (nested arrows inside `[...]` are preserved as atomic). For non-arrow + * keys (no top-level `->`), returns an empty paramKeys list with the whole + * key as the return key. + */ + def parseTypeKey(typeKey: String): ClosureSig = { + val arrows = topLevelArrowIndices(typeKey) + arrows.isEmpty match { + case true => ClosureSig(Nil, typeKey) + case false => + val paramKeys = + (0 until arrows.length).toList.map(i => + extractParamTypeAt(typeKey, i) + ) + val returnKey = typeKey.substring(arrows.last + 2) + ClosureSig(paramKeys, returnKey) + } + } + /** Get closure arity without type-checking (avoids De Bruijn index issues) */ def getClosureArity(c: Term): Int = c match { case TermClosure(_, _, _, body) => 1 + getClosureArity(body) @@ -331,14 +353,14 @@ object GrinUtils { } def typeArgToString(ty: Type): ContextState[String] = ty match { - case TypeInt(_) => State.pure("i32") - case TypeFloat(_) => State.pure("f32") - case TypeString(_) => State.pure("str") - case TypeBool(_) => State.pure("bool") - case TypeUnit(_) => State.pure("Unit") - case TypeId(_, name) => State.pure(name) - case TypeVar(_, idx, _) => getNameFromIndex(idx) - case TypeApp(_, t1, t2) => + case TypeInt(_) => State.pure("i32") + case TypeFloat(_) => State.pure("f32") + case TypeString(_) => State.pure("str") + case TypeBool(_) => State.pure("bool") + case TypeUnit(_) => State.pure("Unit") + case TypeId(_, name) => State.pure(name) + case TypeVar(_, idx, _, _) => getNameFromIndex(idx) + case TypeApp(_, t1, t2) => for { s1 <- typeArgToString(t1) s2 <- typeArgToString(t2) @@ -401,6 +423,43 @@ object GrinUtils { }) ) } yield result + // Loophole L Option B (extended): non-specialized method projection + // (`exec`, `map`, …) on a TermVar receiver. The cached generic + // `!{m}#{IO}` impl method is *gone* after monomorphization (mono + // replaces it with type-arg-suffixed specs like `!exec#IO#i32`). + // Look up the type-arg-suffixed impl method using the receiver's + // concrete type-arg name; if found, use its TermAbbBind.ty (already + // post-substitution at mono time) to skip `typeCheck(t1)`. + case TermMethodProj(_, TermVar(_, selfIdx, _), m) + if !fuse.SpecializedMethodUtils.isSpecializedMethod(m) => + for { + selfTypeOpt <- toContextStateOption( + Context.getType(UnknownInfo, selfIdx) + ) + result <- selfTypeOpt.flatMap(typeConstructorNameDirect) match { + case None => State.pure[Context, Option[Type]](None) + case Some(typeName) => + val baseMethodID = Desugar.toMethodID(m, typeName) + val argSuffix = selfTypeOpt + .flatMap(typeAppArg) + .flatMap(typeArgSuffixName) + val methodID = + argSuffix.fold(baseMethodID)(suf => s"$baseMethodID#$suf") + for { + methodIdxOpt <- State + .inspect[Context, Option[Int]](nameToIndex(_, methodID)) + methodTyOpt <- methodIdxOpt match { + case None => State.pure[Context, Option[Type]](None) + case Some(idx) => + toContextStateOption(getBinding(UnknownInfo, idx)) + .map(_.flatMap { + case TermAbbBind(_, Some(ty)) => Some(ty) + case _ => None + }) + } + } yield methodTyOpt + } + } yield result case TermVar(_, idx, _) => // Handle specialized data constructors (e.g., Cons#i32, Nil#i32) for { @@ -418,57 +477,37 @@ object GrinUtils { case _ => State.pure(None) } - // CallCategory: Categorize function calls following GHC-GRIN pattern - // This enables explicit handling of saturated, under-saturated, and over-saturated calls - sealed trait CallCategory - case class SaturatedCall(fn: String, args: List[String]) extends CallCategory - case class UndersaturatedCall(fn: String, args: List[String], remaining: Int) - extends CallCategory - case class OversaturatedCall( - fn: String, - initialArgs: List[String], - extraArgs: List[String] - ) extends CallCategory - - /** Categorize a function call based on arity facts. - * @param fn - * Function name - * @param args - * Arguments provided - * @param expectedArity - * Expected number of arguments (from ArityFact) - * @return - * CallCategory indicating saturated, undersaturated, or oversaturated - */ - def categorizeCall( - fn: String, - args: List[String], - expectedArity: Int - ): CallCategory = { - val providedCount = args.length - providedCount.compare(expectedArity) match { - case 0 => SaturatedCall(fn, args) - case x if x < 0 => - UndersaturatedCall(fn, args, expectedArity - providedCount) - case _ => - OversaturatedCall( - fn, - args.take(expectedArity), - args.drop(expectedArity) - ) - } + def typeConstructorNameDirect(ty: Type): Option[String] = ty match { + case TypeApp(_, TypeVar(_, _, _, Some(name)), _) => Some(name) + case TypeApp(_, TypeId(_, name), _) => Some(name) + case _ => None + } + + def typeAppArg(ty: Type): Option[Type] = ty match { + case TypeApp(_, _, arg) => Some(arg) + case _ => None } - /** Categorize a closure call using ArityFact. For closures, expectedArity = - * totalParams - capturedVars (dispatch arity) + /** Render a type as its mono spec-name suffix component (e.g. `i32`, `str`, + * `Unit`, `List[str]`). Mirrors how `MonoSpecialize` encodes type arguments + * into bind names. Returns `None` for non-canonical shapes (e.g. open + * `TypeVar` without referent or type-app without a constructor name) so + * callers fall through. */ - def categorizeClosureCall( - fn: String, - args: List[String], - fact: ArityFact - ): CallCategory = { - val dispatchArity = fact.totalParams - fact.capturedVars - categorizeCall(fn, args, dispatchArity) + def typeArgSuffixName(ty: Type): Option[String] = ty match { + case TypeInt(_) => Some("i32") + case TypeFloat(_) => Some("f32") + case TypeBool(_) => Some("bool") + case TypeString(_) => Some("str") + case TypeUnit(_) => Some("Unit") + case TypeId(_, name) => Some(name) + case TypeVar(_, _, _, Some(name)) => Some(name) + case TypeApp(_, head, arg) => + for { + h <- typeArgSuffixName(head) + a <- typeArgSuffixName(arg) + } yield s"$h[$a]" + case _ => None } object PrimOp: diff --git a/src/main/scala/code/MonoDriver.scala b/src/main/scala/code/MonoDriver.scala index f832871..3eb857e 100644 --- a/src/main/scala/code/MonoDriver.scala +++ b/src/main/scala/code/MonoDriver.scala @@ -61,7 +61,7 @@ object MonoDriver { val newAllBindNames = binds.map(_.i).toSet ++ newSavedGenericBinds.keySet for { - specializedBinds <- toSpecializedBinds( + (specializedBinds, shiftedGenericBinds) <- toSpecializedBinds( binds, resolvableInsts, newAllBindNames @@ -70,9 +70,29 @@ object MonoDriver { specializedBinds, newSavedGenericBinds ) + // `createMissingDataConstrSpecs` inserts new data-constructor specs + // into the bind list and shifts every bind after the insertion via + // `shiftBindAfterInsert`. The `shiftedGenericBinds` map produced by + // `toSpecializedBinds` predates these insertions — its entries are + // stale copies whose stored TypeVar `n` stamps lag the post-insert + // ctx length by the number of data-constructor specs inserted before + // the entry's source position. Refresh the map by re-locating each + // generic in the NEW bind list, so deferred specs in + // `createMissingGenericSpecs` are built from a generic whose body has + // ALL pre-insert and data-constructor-insert shifts applied. Without + // this, `typeShiftOnContextDiff` at GRIN time over-compensates by the + // missing data-constructor delta, drifting e.g. `List` references in + // a Nil-armed bind's transitive spec to `read`. + refreshedShiftedGenericBinds = shiftedGenericBinds.map { + case (name, staleBind) => + name -> bindsWithDataConstrSpecs + .find(_.i == name) + .getOrElse(staleBind) + } bindsWithDeferredSpecs <- createMissingGenericSpecs( bindsWithDataConstrSpecs, - newSavedGenericBinds + newSavedGenericBinds, + refreshedShiftedGenericBinds ) modifiedBinds <- bindsWithDeferredSpecs.traverse( replaceInstantiations(_, bindsWithDeferredSpecs) diff --git a/src/main/scala/code/MonoRewrite.scala b/src/main/scala/code/MonoRewrite.scala index a2f22b2..bb435b8 100644 --- a/src/main/scala/code/MonoRewrite.scala +++ b/src/main/scala/code/MonoRewrite.scala @@ -30,22 +30,46 @@ object MonoRewrite { /** Insert specialized binds into the bind list. Iterates over all binds and * creates specializations for those that have matching instantiations. * Accumulates De Bruijn index shifts as new binds are inserted. + * + * Returns the specialized bind list and a map from generic-bind name to the + * shifted source bind. The map is populated for BOTH the i-arm (binds + * specialised in this pass) and the Nil-arm (generic binds whose bodies are + * shifted by the accumulator but not specialised here) — + * `createMissingGenericSpecs` consumes the map to build deferred specs from + * the iter-shifted source rather than the unshifted `savedGenericBinds`. + * Without the Nil-arm coverage (Loophole D), a deferred spec built from an + * unshifted source produces TypeVars whose `n` is δ units below the GRIN + * context length, drifting names by δ slots at lookup time. */ def toSpecializedBinds( binds: List[Bind], insts: List[Instantiation], allBindNames: Set[String] = Set.empty - ): ContextState[List[Bind]] = + ): ContextState[(List[Bind], Map[String, Bind])] = for { result <- binds - .foldLeftM((List[Bind](), List[Shift]())) { - case ((binds, shifts), bind) => + .foldLeftM((List[Bind](), List[Shift](), Map[String, Bind]())) { + case ((binds, shifts, shiftedBinds), bind) => getBindInstantiations(bind, insts).flatMap { bindInsts => bindInsts match { case Nil => val sbind = shifts.foldLeft(bind) { (b, s) => bindShift(s.d, b, s.c) } + // Record the shifted version of every GENERIC bind whose + // body gets bumped by the iteration-local accumulator + // even when it is not itself specialised in this pass. + // `createMissingGenericSpecs` later uses this map to + // build deferred specs from the iter-shifted source bind + // rather than the unshifted `savedGenericBinds` — without + // this, the deferred spec's body retains pre-shift indices + // while the bind list around it grew, producing the + // cross-iteration drift class (Loophole D). + val updatedShiftedBinds = bind.b match { + case TermAbbBind(_: TermTAbs, _) => + shiftedBinds + (bind.i -> sbind) + case _ => shiftedBinds + } for { _ <- addBinding(sbind.i, sbind.b) sInsts <- sbind.insts @@ -56,14 +80,17 @@ object MonoRewrite { sInsts, sbind.closureTypes ) - } yield (b, incrShifts(shifts)) + } yield (b, incrShifts(shifts), updatedShiftedBinds) case i => + val shiftedSourceBind = shifts.foldLeft(bind) { (b, s) => + bindShift(s.d, b, s.c) + } for { specializedBindsList <- i.zipWithIndex .traverse((inst, idx) => for { bind <- buildSpecializedBind( - bind, + shiftedSourceBind, inst, idx, allBindNames @@ -76,14 +103,15 @@ object MonoRewrite { binds ::: specializedBindsList, incrShifts(shifts) :+ Shift( specializedBindsList.length - 1, - 0 - ) + 1 + ), + shiftedBinds + (bind.i -> shiftedSourceBind) ) } } } - (specializedBinds, _) = result - } yield specializedBinds + (specializedBinds, _, shiftedGenericBinds) = result + } yield (specializedBinds, shiftedGenericBinds) /** Find instantiations from the given list that match a specific bind. For * ADTs, all instantiations for that type are associated. For type instance @@ -198,25 +226,71 @@ object MonoRewrite { case _ => false } + /** Create missing specializations for generic functions discovered via + * transitive instantiation resolution (i.e., after the main + * `toSpecializedBinds` pass has run). Each missing spec is inserted right + * after the last existing specialization of the same function. + * + * The `shiftedGenericBinds` map (produced by `toSpecializedBinds`) contains + * the SHIFTED version of each expanded generic bind — i.e., the version with + * accumulated De Bruijn index shifts applied. These shifted versions carry + * TypeVar `n` (context-length) fields that are consistent with the position + * those binds occupied in the context during the main specialization pass. + * Using the shifted version here is critical: the unshifted original from + * `savedGenericBinds` has TypeVar `n` values that are one or more units too + * small, causing `typeShiftOnContextDiff` in the GRIN backend to compute a + * wrong index offset and resolve type-name references to the wrong binding. + */ def createMissingGenericSpecs( binds: List[Bind], - savedGenericBinds: Map[String, Bind] - ): ContextState[List[Bind]] = { - val unresolvedConcreteInsts = binds.flatMap(b => - b.insts.filter(i => - i.r == Resolution.Unresolved && - !i.tys.exists(ty => getTypeContextLength(ty).isDefined) && - i.tys.exists(isConcreteType(_)) + savedGenericBinds: Map[String, Bind], + shiftedGenericBinds: Map[String, Bind] = Map.empty + ): ContextState[List[Bind]] = for { + // Bind-aware unresolvedness: keep insts whose remaining `TypeVar`s + // refer to top-level concrete types (TypeAbbBind), drop those that + // refer to still-unresolved generic params (TypeVarBind). See + // `MonoSpecialize.isFreeOfGenericTypeVar` for the rationale. + unresolvedConcreteInsts <- binds + .flatMap(_.insts) + .filterA(i => + i.r match { + case Resolution.Unresolved => + i.tys.forallM(ty => isFreeOfGenericTypeVar(ty)).map { freeOfGen => + freeOfGen && i.tys.exists(isConcreteType(_)) + } + case _ => false.pure[ContextState] + } ) - ) - val uniqueInsts = Instantiations.distinct(unresolvedConcreteInsts) - uniqueInsts.foldLeftM(binds) { (currentBinds, inst) => + uniqueInsts = Instantiations.distinct(unresolvedConcreteInsts) + result <- uniqueInsts.foldLeftM(binds) { (currentBinds, inst) => for { specName <- toContextState(inst.bindName()) exists <- State.inspect { (ctx: Context) => nameToIndex(ctx, specName).isDefined } - result <- (exists, savedGenericBinds.get(inst.i)) match { + // Each iteration of the foldLeft inserts a spec into `currentBinds` + // and shifts every bind after the insertion via `shiftBindAfterInsert`. + // The `shiftedGenericBinds` map's values are static snapshots taken + // BEFORE this loop began, so when a later iteration looks up the same + // (or another) generic, the snapshot lacks the prior iteration's + // insertion shift. Re-locate the generic in `currentBinds` (filtered to + // `TermAbbBind(TermTAbs, _)` so we don't accidentally match a non- + // generic same-named bind like a record constructor) so the newly-built + // spec body inherits all accumulated shifts. + currentGenericBind = currentBinds.find(b => + b.i == inst.i && (b.b match { + case TermAbbBind(_: TermTAbs, _) => true + case _ => false + }) + ) + result <- ( + exists, + currentGenericBind.orElse( + shiftedGenericBinds + .get(inst.i) + .orElse(savedGenericBinds.get(inst.i)) + ) + ) match { case (false, Some(genericBind)) => for { newBind <- buildSpecializedBind(genericBind, inst, 0) @@ -231,7 +305,7 @@ object MonoRewrite { } } yield result } - } + } yield result /** Find the insertion position for a new generic specialization. Returns the * index right after the last existing specialization of this function, or @@ -282,36 +356,36 @@ object MonoRewrite { bind: Bind, dataConstrInst: Instantiation ): Bind = { - val matchingInst = bind.insts.find(i => - i.r == Resolution.DeferredDataConstr && i.i == dataConstrInst.i && i.tys == dataConstrInst.tys - ) - matchingInst match { - case None => bind - case Some(inst) => - val termInfo = inst.term match { - case tv: TermVar => Some(tv.info) - case _ => None - } - termInfo match { - case None => bind - case Some(info) => - val redirectedBinding = bind.b match { - case TermAbbBind(term, ty) => - val newTerm = TermFold.findVarByInfo(term, info) match { - case Some((rawIdx, _, depth)) => - val baseIdx = rawIdx - depth - redirectTermVar(term, info, baseIdx, baseIdx - 1) - case None => term - } - TermAbbBind(newTerm, ty) - case _ => bind.b - } - val updatedInsts = bind.insts.filterNot(i => - i.r == Resolution.DeferredDataConstr && i.i == dataConstrInst.i && i.tys == dataConstrInst.tys - ) - Bind(bind.i, redirectedBinding, updatedInsts, bind.closureTypes) - } + def isMatching(i: Instantiation): Boolean = + i.r == Resolution.DeferredDataConstr && + i.i == dataConstrInst.i && + i.tys == dataConstrInst.tys + val rewritten = for { + inst <- bind.insts.find(isMatching) + info <- inst.term match { + case tv: TermVar => Some(tv.info) + case _ => None + } + } yield { + val redirectedBinding = bind.b match { + case TermAbbBind(term, ty) => + val newTerm = TermFold.findVarByInfo(term, info) match { + case Some((rawIdx, _, depth)) => + val baseIdx = rawIdx - depth + redirectTermVar(term, info, baseIdx, baseIdx - 1) + case None => term + } + TermAbbBind(newTerm, ty) + case other => other + } + Bind( + bind.i, + redirectedBinding, + bind.insts.filterNot(isMatching), + bind.closureTypes + ) } + rewritten.getOrElse(bind) } def redirectTermVar( @@ -321,11 +395,13 @@ object MonoRewrite { newIdx: Int ): Term = termMap( - (info, c, k, n) => - (info == targetInfo && k == oldIdx + c) match { - case true => TermVar(info, newIdx + c, n) - case false => TermVar(info, k, n) - }, + (info, c, k, n) => { + val resolvedK = (info == targetInfo && k == oldIdx + c) match { + case true => newIdx + c + case false => k + } + TermVar(info, resolvedK, n) + }, (c, ty) => ty, 0, term diff --git a/src/main/scala/code/MonoSpecialize.scala b/src/main/scala/code/MonoSpecialize.scala index 4c5dae7..5e6303c 100644 --- a/src/main/scala/code/MonoSpecialize.scala +++ b/src/main/scala/code/MonoSpecialize.scala @@ -71,8 +71,8 @@ object MonoSpecialize { } def hasNegativeTypeVarIndex(ty: Type): Boolean = ty match { - case TypeVar(_, idx, _) => idx < 0 - case TypeApp(_, t1, t2) => + case TypeVar(_, idx, _, _) => idx < 0 + case TypeApp(_, t1, t2) => hasNegativeTypeVarIndex(t1) || hasNegativeTypeVarIndex(t2) case TypeArrow(_, t1, t2) => hasNegativeTypeVarIndex(t1) || hasNegativeTypeVarIndex(t2) @@ -81,8 +81,8 @@ object MonoSpecialize { } def getTypeContextLength(ty: Type): Option[Int] = ty match { - case TypeVar(_, _, n) => Some(n) - case TypeApp(_, t1, t2) => + case TypeVar(_, _, n, _) => Some(n) + case TypeApp(_, t1, t2) => getTypeContextLength(t1).orElse(getTypeContextLength(t2)) case TypeArrow(_, t1, t2) => getTypeContextLength(t1).orElse(getTypeContextLength(t2)) @@ -95,7 +95,7 @@ object MonoSpecialize { ty: Type, numBinds: Int ): ContextState[Boolean] = ty match { - case TypeVar(info, index, ctxLen) => + case TypeVar(info, index, ctxLen, _) => (index < 0) match { case true => false.pure[ContextState] case false => @@ -116,8 +116,8 @@ object MonoSpecialize { } def extractTypeVarIndices(ty: Type): List[Int] = ty match { - case TypeVar(_, idx, _) => List(idx) - case TypeApp(_, t1, t2) => + case TypeVar(_, idx, _, _) => List(idx) + case TypeApp(_, t1, t2) => extractTypeVarIndices(t1) ::: extractTypeVarIndices(t2) case TypeArrow(_, t1, t2) => extractTypeVarIndices(t1) ::: extractTypeVarIndices(t2) @@ -224,7 +224,7 @@ object MonoSpecialize { ty: Type, inCtor: Boolean = false ): List[((Int, Int), Boolean)] = ty match { - case TypeVar(_, idx, n) => + case TypeVar(_, idx, n, _) => List(((idx.intValue, n.intValue), inCtor)) case TypeApp(_, t1, t2) => collectTypeParamPairs(t1, true) ::: collectTypeParamPairs(t2, false) @@ -240,7 +240,7 @@ object MonoSpecialize { ty: Type, mapping: Map[Int, Type] ): Type = ty match { - case TypeVar(_, idx, _) => + case TypeVar(_, idx, _, _) => mapping.getOrElse(idx.intValue, ty) case TypeApp(info, t1, t2) => TypeApp( @@ -486,9 +486,13 @@ object MonoSpecialize { term, specializedClosureInsts ) - specializedTerm = shiftedInst.tys.foldLeft( - termWithClosureTypes: Term - )(specializeTerm(_, idx, _)) + specializedTerm = renameSelfRecursiveBinding( + shiftedInst.tys.foldLeft(termWithClosureTypes: Term)( + specializeTerm(_, idx, _) + ), + bind.i, + name + ) deferredDataConstrInsts = bind.insts.filter( _.r == Resolution.DeferredDataConstr ) @@ -535,20 +539,64 @@ object MonoSpecialize { insts, shiftedInst.tys ) - } yield finalizeSpecializedBind( - name, - binding, - insts ::: syntheticDataConstrInsts, - bind.i, - inst.tys, - specializedClosureInsts - ) + finalized <- finalizeSpecializedBind( + name, + binding, + insts ::: syntheticDataConstrInsts, + bind.i, + inst.tys, + specializedClosureInsts + ) + } yield finalized case _ => throw new RuntimeException( s"can't build specialized binding ${inst.i}" ) } + /** Rename the recursion-combinator parameter inside the specialized term so + * its self-references resolve to the specialized name at GRIN codegen time. + * + * Top-level generic `fun` declarations are desugared (see + * `Desugar.withFixCombinator`) to `TermFix(TermAbs(prefix+name, ty, body))`, + * where `prefix+name` is the synthetic recursion-combinator parameter. + * Self-calls inside `body` are TermVars pointing to that binding. After + * monomorphization, the outer `Bind.i` becomes the type-suffixed spec name + * but the inner TermAbs binding name is unchanged. Without renaming, + * `Grin.toVariable` would resolve the recursive TermVar to the original + * prefixed name, strip the prefix, and emit the un-suffixed function name — + * which has no GRIN definition, breaking the linker. + * + * Renaming the inner binding to the prefixed spec name makes + * `Grin.toVariable` emit the spec name, matching the function header emitted + * from `Bind.i` after the standard `#`→`'` post-processing. + * + * `Grin.toParamVariable` already returns `""` for any prefix-marked name, so + * the recursion parameter remains invisible in the function header + * regardless of the rename. + * + * Methods inside `impl` blocks self-call via qualified names which produce + * `TermAssocProj` and are covered by the `containsAssocProjInBinding` path + * in `finalizeSpecializedBind`; this rename specifically handles bare-name + * self-calls in top-level `fun` bodies. + */ + def renameSelfRecursiveBinding( + term: Term, + originalBindName: String, + specializedBindName: String + ): Term = { + val oldRecName = + s"${Desugar.RecursiveFunctionParamPrefix}$originalBindName" + val newRecName = + s"${Desugar.RecursiveFunctionParamPrefix}$specializedBindName" + term match { + case TermFix(info, TermAbs(absInfo, name, ty, body, retTy)) + if name == oldRecName => + TermFix(info, TermAbs(absInfo, newRecName, ty, body, retTy)) + case _ => term + } + } + /** Synthesize insts for bare TermVars in the specialized body that resolve * (via ctx lookup) to an already-specialized data-constructor bind but have * no existing `Instantiation` covering them. @@ -627,36 +675,98 @@ object MonoSpecialize { originalBindName: String, instTys: List[Type], specializedClosureInsts: List[Instantiation] - ): Bind = { - val filteredInsts = insts.filter(inst => - !inst.tys.exists(ty => getTypeContextLength(ty).isDefined) + ): ContextState[Bind] = for { + // Drop insts whose tys still carry a generic-parameter TypeVar + // (binding = TypeVarBind). TypeVars whose binding is a TypeAbbBind + // (a concrete top-level type constructor) are real references and must + // be kept. The previous `!getTypeContextLength(ty).isDefined` + // predicate dropped both indiscriminately and stripped legitimate + // insts whose type arguments contained nested concrete constructors, + // leaving the corresponding spec ungenerated. + filteredInsts <- insts.filterA(inst => + inst.tys.forallM(ty => isFreeOfGenericTypeVar(ty)) ) - // Preserve multiple call sites (different term.info) for the same - // method+types — each one needs its own replacement in - // replaceInstantiations (which matches by info). - val dedupedInsts = Instantiations.distinct(filteredInsts, byTerm = true) - val selfInstantiation = - containsAssocProjInBinding(binding, originalBindName) match { - case true => - List( - Instantiation( - originalBindName, - TermAssocProj(UnknownInfo, instTys.head, originalBindName), - instTys, - List(), - Resolution.Resolved(0) - ) + dedupedInsts = Instantiations.distinct(filteredInsts, byTerm = true) + selfInstantiation = containsAssocProjInBinding( + binding, + originalBindName + ) match { + case true => + List( + Instantiation( + originalBindName, + TermAssocProj(UnknownInfo, instTys.head, originalBindName), + instTys, + List(), + Resolution.Resolved(0) ) - case false => List() - } - val finalInsts = (selfInstantiation ::: dedupedInsts) + ) + case false => List() + } + finalInsts = (selfInstantiation ::: dedupedInsts) .distinctBy(inst => (inst.i, inst.tys, inst.term)) - val closureTypesMap = specializedClosureInsts + closureTypesMap = specializedClosureInsts .flatMap(cInst => cInst.tys.headOption.map(ty => cInst.i -> ty)) .toMap - Bind(name, binding, finalInsts, closureTypesMap) + } yield Bind(name, binding, finalInsts, closureTypesMap) + + /** True iff `ty` contains no `TypeVar` whose binding is a `TypeVarBind` (a + * generic parameter). `TypeVar`s whose binding is a top-level `TypeAbbBind` + * (a concrete top-level type constructor) are kept. Drifted `TypeVar`s + * pointing at a non-type binding (`TermAbbBind`, `VarBind`, etc.) are + * treated as non-concrete and thus filtered — accepting them causes + * `bindName()` to render method/term names into spec suffixes, producing + * pathological recursive spec names. + */ + def isFreeOfGenericTypeVar(ty: Type): ContextState[Boolean] = ty match { + case TypeVar(info, idx, _, referent) => + // Loophole K Option 1: the `referent` carries the original + // type-name. Prefer name-based lookup (via the current ctx) + // over the cached `idx` — the idx is a snapshot from when the + // type was created, and mono routinely adds binds that shift + // it, causing a name to resolve to the wrong (term) binding + // via stale idx. + referent match { + case Some(name) => + State.inspect[Context, Option[Int]](nameToIndex(_, name)).flatMap { + case None => + // Name not in ctx → fall back to idx-based lookup so a + // not-yet-added type still resolves through its captured + // idx if that idx happens to land on a TypeAbbBind. + isFreeViaIdx(info, idx) + case Some(curIdx) => + GrinUtils.toContextStateOption(getBinding(info, curIdx)).map { + case Some(_: TypeAbbBind) => true + case _ => false + } + } + case None => isFreeViaIdx(info, idx) + } + case TypeApp(_, t1, t2) => + isFreeOfGenericTypeVar(t1).flatMap { + case false => false.pure[ContextState] + case true => isFreeOfGenericTypeVar(t2) + } + case TypeArrow(_, t1, t2) => + isFreeOfGenericTypeVar(t1).flatMap { + case false => false.pure[ContextState] + case true => isFreeOfGenericTypeVar(t2) + } + case TypeAll(_, _, _, _, t) => isFreeOfGenericTypeVar(t) + case _: TypeEVar => false.pure[ContextState] + case _ => true.pure[ContextState] } + def isFreeViaIdx(info: parser.Info.Info, idx: Int): ContextState[Boolean] = + idx < 0 match { + case true => false.pure[ContextState] + case false => + GrinUtils.toContextStateOption(getBinding(info, idx)).map { + case Some(_: TypeAbbBind) => true + case _ => false + } + } + def containsAssocProjInBinding( binding: Binding, methodName: String diff --git a/src/main/scala/code/Syntax.scala b/src/main/scala/code/Syntax.scala index 9371857..17b4956 100644 --- a/src/main/scala/code/Syntax.scala +++ b/src/main/scala/code/Syntax.scala @@ -21,7 +21,7 @@ object Syntax { f: String, arity: Int, ptag: Option[String] = - None, // Track which P-tag this variable holds (e.g., "P2c18") + None, // Track which P-tag this variable holds (e.g., "P2foo") typeKey: String = "" // Type signature for closureMap lookup ) extends Expr case class PartialFunValue( @@ -66,14 +66,35 @@ object Syntax { paramTypeKeys: List[String] // Per-level param types for grouping ) + /** Structural signature for closure dispatch. Replaces the prior string-keyed + * `closureMap` so polymorphic queries (e.g. `A->B`) can unify with concrete + * entries (e.g. `i32->i32`) via `unifiesWith`, and concrete queries can + * still match exactly. The string form (recoverable via `toTypeKey`) is + * preserved for downstream consumers (e.g. `applyFnForClosure`, + * `closureToTypeKey`) that still take typeKey strings. + */ + case class ClosureSig(paramKeys: List[String], returnKey: String) { + def toTypeKey: String = paramKeys match { + case Nil => returnKey + case keys => (keys :+ returnKey).mkString("->") + } + } + /** Environment containing all closure-related maps for GRIN code generation. * Used as a single implicit parameter instead of multiple scattered maps. */ case class Env( - closureMap: Map[String, List[String]] = Map.empty, + closureMap: Map[ClosureSig, List[String]] = Map.empty, arityFactsMap: Map[String, ArityFact] = Map.empty, closureTypesFromBind: Map[String, Type] = Map.empty, closureTypesFallback: Map[String, Type] = Map.empty, + // Global last-resort map: aggregates concrete-arrow closure-Resolution + // insts across all binds. Used when `pureInfer(c)`, + // `getClosureTypeWithFallback(c)`, and `closureTypesFallback` all fail + // to produce a closure's type at GRIN-gen — typically when a closure + // body references identifiers that don't resolve in the post-Mono + // Context, leaving its `typeKey` empty without this seed. + closureTypesGlobal: Map[String, Type] = Map.empty, // typeInstances maps typeName -> set of trait class names it impls. // Built once over the full bind list so forward references work during // sequential rendering (trait default method specializations can call @@ -184,6 +205,13 @@ object Syntax { case args => s"pure ($tag ${args.mkString(" ")})" } show"$variable =\n${indent(1, body)}" + // A bare nullary FunctionValue body (e.g. `empty'i32 _76 = MyNone'i32`) + // emits the function call directly with proper indentation. Without + // this branch the value falls through to `PureExpr`, whose `case _` + // calls `expr.show` and drops the leading indent — producing GRIN + // input that `grin` rejects with "incorrect indentation". + case FunctionValue(f, 0, _, _) => + show"$variable =\n${indent(1, f)}" case _ => show"$variable =\n${PureExpr(e)}" } diff --git a/src/main/scala/core/Context.scala b/src/main/scala/core/Context.scala index 002a952..3fbc644 100644 --- a/src/main/scala/core/Context.scala +++ b/src/main/scala/core/Context.scala @@ -151,7 +151,7 @@ object Context { node: String ): ContextState[Option[TypeVar]] = findAlgebraicDataType(info, node).map( - _.map(v => TypeVar(info, v._2, v._3)) + _.map(v => TypeVar(info, v._2, v._3, Some(v._1._1))) ) def getAlgebraicDataTypeName( @@ -199,12 +199,35 @@ object Context { * Required to do so when an existential variables gets resolved _deep_ in * the context but doesn't get shifted once _out_ of it. Usually that happens * in closures. + * + * Loophole K Option 1: TypeVars stamped with `referent = Some(name)` use + * **name-based recovery** instead of uniform shift. The uniform shift is + * correct for closure entry/exit (referent didn't move, ctx grew/shrunk + * uniformly) but over-shifts when mid-list inserts in the bind list displace + * the referent's absolute position by less than the total ctx growth. Name + * lookup is invariant under mid-list inserts: the binding's name doesn't + * move when adjacent specs get inserted. When the lookup succeeds, return a + * fresh TypeVar with up-to-date `(idx, length)` and the same `referent` + * carried forward. When the lookup fails (name was removed or shadowed), + * fall through to the legacy uniform shift. */ def typeShiftOnContextDiff(ty: Type): ContextState[Type] = ty match { - case TypeVar(_, _, varContextLength) => + case tv: TypeVar => State.inspect { ctx => - val d = varContextLength - getNotes(ctx).toList.length - typeShift(-d, ty) + val notes = getNotes(ctx).toList + val curLen = notes.length + // Verify the recovered binding is still a type abbreviation. If a + // same-named term shadowed the type (unlikely but possible after + // rewrites), the `collect` filters it out and we fall through to + // uniform shift. + val recovered = tv.referent.flatMap(name => + nameToIndex(ctx, name).flatMap(newIdx => + notes.lift(newIdx).collect { case (_, _: TypeAbbBind) => + TypeVar(tv.info, newIdx, curLen, Some(name)) + } + ) + ) + recovered.getOrElse(typeShift(curLen - tv.length, tv)) } case TypeApp(info, ty1, ty2) => for { diff --git a/src/main/scala/core/Desugar.scala b/src/main/scala/core/Desugar.scala index d799217..211cb59 100644 --- a/src/main/scala/core/Desugar.scala +++ b/src/main/scala/core/Desugar.scala @@ -92,7 +92,9 @@ object Desugar { } yield List(typeBind) case FFuncDecl(sig @ FFuncSig(_, FIdentifier(i), tp, _, _), exprs) => { for { - func <- Context.runE(buildFunc(tp, sig, exprs)) + rewritten <- rewriteMainIOReturn(i, sig, exprs) + (effSig, effExprs) = rewritten + func <- Context.runE(buildFunc(tp, effSig, effExprs)) funcBind <- bindTermAbb(i, func) } yield List(funcBind) } @@ -182,6 +184,45 @@ object Desugar { DesugarError.format(DeclarationNotSupportedDesugarError(d.info)) } + def rewriteMainIOReturn( + i: String, + sig: FFuncSig, + exprs: Seq[FExpr] + ): StateEither[(FFuncSig, Seq[FExpr])] = sig.r match { + case ioTy @ FSimpleType(info, FIdentifier("IO"), Some(Seq(innerT))) + if i == TypeChecker.MainFunction => + innerT match { + case FSimpleType(_, FIdentifier("IO"), _) => + DesugarError.format(MainIONestedDesugarError(info)) + case _ => + val tmpName = "__main_io_action" + val ioLet = FLetExpr( + info, + FIdentifier(tmpName), + Some(ioTy), + exprs.takeRight(1) + ) + val execLet = FLetExpr( + info, + FIdentifier("_"), + None, + Seq( + FMethodApp( + info, + FProj(info, FVar(info, tmpName), Seq(FVar(info, "exec"))), + None, + Seq(None) + ) + ) + ) + val newSig = sig.copy(r = FSimpleType(info, FIdentifier("i32"), None)) + val newExprs = + exprs.dropRight(1) :+ ioLet :+ execLet :+ FInt(info, 0) + (newSig, newExprs).pure[StateEither] + } + case _ => (sig, exprs).pure[StateEither] + } + def bindTypeAbb(i: String, t: Type): StateEither[Bind] = EitherT.liftF(Context.addName(i).map(Bind(_, TypeAbbBind(t)))) @@ -898,7 +939,17 @@ object Desugar { }) typeVar <- value match { case (ctx, Some(index)) => - TypeVar(info, index, ctx._1.length).pure[StateEither] + // Loophole K Option 1: stamp every TypeVar with the original + // name as `referent`. At consumption time, + // `typeShiftOnContextDiff` checks the *current* binding kind + // for that name: top-level type abbreviations + // (TypeAbbBind) get name-based recovery to sidestep mid-list + // insert drift; local type-parameter binders (TypeVarBind) + // fall through to uniform shift. The discrimination cannot + // happen here at desugar time because top-level types are + // staged as NameBind first and only promoted to TypeAbbBind + // during type-checking (`checkBindings`). + TypeVar(info, index, ctx._1.length, Some(i)).pure[StateEither] case _ => DesugarError.format(TypeVariableNotFoundDesugarError(info, i)) } } yield typeVar diff --git a/src/main/scala/core/DesugarError.scala b/src/main/scala/core/DesugarError.scala index 3913710..8b5ea5b 100644 --- a/src/main/scala/core/DesugarError.scala +++ b/src/main/scala/core/DesugarError.scala @@ -25,6 +25,7 @@ case class NestedPatternNotSupportedDesugarError(info: Info) extends DesugarError case class DoRequiresYieldExprDesugarError(info: Info) extends DesugarError case class DoExpectsAssignmentDesugarError(info: Info) extends DesugarError +case class MainIONestedDesugarError(info: Info) extends DesugarError object DesugarError { def format[T](error: DesugarError): StateEither[T] = @@ -56,5 +57,10 @@ object DesugarError { consoleError("yield expression not found", info) case DoExpectsAssignmentDesugarError(info) => consoleError("assignment expression expected", info) + case MainIONestedDesugarError(info) => + consoleError( + "`main` cannot return a nested `IO` type; unwrap to a single `IO[T]`", + info + ) }) } diff --git a/src/main/scala/core/Instantiations.scala b/src/main/scala/core/Instantiations.scala index 65e842f..82a4202 100644 --- a/src/main/scala/core/Instantiations.scala +++ b/src/main/scala/core/Instantiations.scala @@ -12,7 +12,11 @@ import core.Terms.* import core.TypeChecker.* import core.Types.* import parser.Info.Info -import core.Desugar.{MethodNamePrefix, SelfTypeName} +import core.Desugar.{ + MethodNamePrefix, + RecursiveFunctionParamPrefix, + SelfTypeName +} import parser.Info.ShowInfo.ShowInfoOps import scala.annotation.tailrec import parser.Info.UnknownInfo @@ -118,7 +122,15 @@ object Instantiations { typeNames <- tys .traverse(Representation.typeToString(_)) .map(_.mkString(BindTypeSeparator)) - baseName = s"$i$BindTypeSeparator$typeNames" + baseName = tys.isEmpty match { + // Empty `tys` => no specialization suffix. Without this guard the + // renderer emitted `$i#` (trailing separator), producing names + // like `!exec#IO#` that GRIN rejects with "non-defined function". + // The bare `$i` form aligns with the original impl-method id from + // `Desugar.toMethodID`, so the call resolves to that bind. + case true => i + case false => s"$i$BindTypeSeparator$typeNames" + } name = cls match { // NOTE: When class is set on the instantiation it points to a type // instances's method. Hence we need to add a method prefix. @@ -159,10 +171,13 @@ object Instantiations { acc.pure[StateEither] case (assocProj @ TermAssocProj(_, ty, method), _, _) if tys.nonEmpty => // Direct static method term (not yet wrapped in TermApp) - // Captures type solutions from innermost TypeAll unwrapping + // Captures type solutions from innermost TypeAll unwrapping. + // See the comment on the `(_: TermApp, assocProj, _)` case below for + // why trait resolution is required here. for { typeName <- EitherT.liftF(getNameFromType(ty)) - methodID = Desugar.toMethodID(method, typeName) + traitOpt <- findMethodTrait(typeName, method) + (methodID, _) <- resolveMethodId(typeName, method, traitOpt) resolvedTys <- tys.traverse(resolveTypeConstructorsInType(_)) } yield acc :+ Instantiation(methodID, assocProj, resolvedTys, cls) case ( @@ -171,10 +186,11 @@ object Instantiations { _ ) if tys.nonEmpty => // Direct method projection with type solutions (not wrapped in - // TermApp). Only fires for method calls with a single value - // argument: multi-argument method calls are caught by the outer - // `TermApp`/`TermMethodProj` case below, so collecting here would - // double-add them. + // TermApp). For trait methods the arity-1 guard avoids + // double-collection with downstream trait-method paths; for + // non-trait impl methods no other path collects, so the guard + // is relaxed there. Without the relaxation, multi-arg impl + // methods would produce a missing specialized bind at GRIN-gen. for { (objType, _) <- pureInfer(obj) typeName <- EitherT.liftF(getNameFromType(objType)) @@ -192,7 +208,16 @@ object Instantiations { case Some(idx) => getType(UnknownInfo, idx).map(t => Some(getValueArity(t))) } - result <- methodArity.contains(1) match { + arityOk = (methodArity, traitOpt) match { + // Non-trait impl method: any non-zero arity collects + // (Loophole M). + case (Some(a), None) => a >= 1 + // Trait method or unknown arity: preserve legacy arity-1 + // guard to avoid double-collection. + case (Some(a), _) => a == 1 + case _ => false + } + result <- arityOk match { case false => acc.pure[StateEither] case true => for { @@ -204,7 +229,8 @@ object Instantiations { } yield acc :+ Instantiation(methodID, methodTerm, finalTys, cls) } } yield result - case (_: TermApp, methodTerm @ TermMethodProj(_, obj, method), _) => + case (_: TermApp, methodTerm @ TermMethodProj(_, obj, method), _) + if tys.nonEmpty => for { (objType, _) <- pureInfer(obj) // type of the object, not the method typeName <- EitherT.liftF(getNameFromType(objType)) @@ -256,10 +282,22 @@ object Instantiations { if tys.nonEmpty => // Handle static method calls with type solutions // Type solutions preserve their full structure (including TypeApp wrappers) - // so that distinct specializations get distinct monomorphized names + // so that distinct specializations get distinct monomorphized names. + // + // Trait-method resolution: when `Type::method` refers to a method + // provided via a `impl Trait for Type` block, the bind id is + // `toTypeInstanceMethodID(method, typeName, traitName) = + // "!method#Type#Trait"`, not `toMethodID(method, typeName) = + // "!method#Type"`. Without consulting `findMethodTrait`, the AssocProj + // path produces an inst whose `i` lacks the trait qualifier, and + // `bindName()` then resolves to a non-existent bind. The closure + // calling that bind in GRIN renders e.g. `unitIO' p143` instead of + // `unitIOMonadi32' p143`, which the GRIN linker rejects as + // "illegal code". for { typeName <- EitherT.liftF(getNameFromType(ty)) - methodID = Desugar.toMethodID(method, typeName) + traitOpt <- findMethodTrait(typeName, method) + (methodID, _) <- resolveMethodId(typeName, method, traitOpt) resolvedTys <- tys.traverse(resolveTypeConstructorsInType(_)) existingInst = acc.find(_.i == methodID) // Look up the method's type arity to limit accumulation @@ -306,6 +344,15 @@ object Instantiations { .traverse(isWellFormed(_)) .map(_.reduce(_ && _)) ) + // Resolve bare TypeVars that point to top-level TypeAbbBinds into + // stable TypeId references. Without this, types captured here + // (e.g. the `T = List` solution of a trait-bounded generic like + // `fmap[A, B, T: Functor]`) carry De Bruijn indices that become + // invalid once monomorphization reshapes the binding context, + // producing names like `fmap#i32#i32#_T125` and dropping the + // inst from `discoverItems` because the stale TypeVar fails + // `isTypeFullyResolved`. + resolvedTys <- tys.traverse(resolveTypeConstructorsInType(_)) fInsts <- acc.filterA(_.isTypeResolved().map(_ == false)) r = fInsts.find(_.term == rootTerm) match { case Some(existing) => @@ -316,7 +363,7 @@ object Instantiations { case _ => acc.map { case i if i.term == rootTerm && isFormedSolution => - Instantiation(i.i, i.term, i.tys ::: tys, i.cls) + Instantiation(i.i, i.term, i.tys ::: resolvedTys, i.cls) case e => e } } @@ -326,10 +373,17 @@ object Instantiations { case (termVar @ TermVar(info, idx, c), _, _) => for { optionName <- EitherT.liftF(State.inspect(indexToName(_, idx))) - name <- optionName match { + rawName <- optionName match { case Some(name) => name.pure[StateEither] case None => TypeError.format(NotFoundTypeError(info)) } + // Strip the recursion-combinator parameter prefix so the inst names + // the top-level generic bind (which is what gets specialized), not + // the inner synthetic ^foo parameter from + // `Desugar.withFixCombinator`. Without this, `bindName()` would + // yield "^foo#i32" while the actual spec bind is "foo#i32", + // leaving the recursive self-call unrewritten by MonoRewrite. + name = rawName.stripPrefix(RecursiveFunctionParamPrefix) // Data constructors are uppercase and don't contain specialized suffix isDataConstructor = isDataConstrName(name) // Skip closure parameters - these are runtime values, not generic functions @@ -424,7 +478,7 @@ object Instantiations { * monomorphization the constructor name is stable across context changes. */ def resolveTypeConstructorsInType(ty: Type): StateEither[Type] = ty match { - case TypeApp(info, tv @ TypeVar(tvInfo, idx, _), ty2) => + case TypeApp(info, tv @ TypeVar(tvInfo, idx, _, _), ty2) => for { binding <- getBinding(tvInfo, idx).recover { case _ => TermAbbBind(TermUnit(UnknownInfo), None) @@ -441,7 +495,7 @@ object Instantiations { resolvedTy1 <- resolveTypeConstructorsInType(ty1) resolvedTy2 <- resolveTypeConstructorsInType(ty2) } yield TypeApp(info, resolvedTy1, resolvedTy2) - case tv @ TypeVar(tvInfo, idx, _) => + case tv @ TypeVar(tvInfo, idx, _, _) => // Resolve bare TypeVars pointing to concrete TypeAbbBind (user-defined ADTs) // to stable TypeId names so they survive context changes in monomorphization. for { diff --git a/src/main/scala/core/Primops.scala b/src/main/scala/core/Primops.scala index 7e9b0bf..514c7e5 100644 --- a/src/main/scala/core/Primops.scala +++ b/src/main/scala/core/Primops.scala @@ -47,11 +47,47 @@ object Primops { paramTypes = List(TypeString(i), TypeString(i)), returnType = TypeUnit(i) ), + PrimopSpec( + fuseName = "_read_stdin", + grinName = "_prim_read_string", + paramTypes = List(TypeUnit(i)), + returnType = TypeString(i) + ), PrimopSpec( fuseName = "int_to_str", grinName = "_prim_int_str", paramTypes = List(TypeInt(i)), returnType = TypeString(i) + ), + PrimopSpec( + fuseName = "_args_count", + grinName = "_prim_args_count", + paramTypes = List(TypeUnit(i)), + returnType = TypeInt(i) + ), + PrimopSpec( + fuseName = "_args_get", + grinName = "_prim_args_get", + paramTypes = List(TypeInt(i)), + returnType = TypeString(i) + ), + PrimopSpec( + fuseName = "_string_char_at", + grinName = "_prim_string_char_at", + paramTypes = List(TypeString(i), TypeInt(i)), + returnType = TypeInt(i) + ), + PrimopSpec( + fuseName = "_string_substring", + grinName = "_prim_string_substring", + paramTypes = List(TypeString(i), TypeInt(i), TypeInt(i)), + returnType = TypeString(i) + ), + PrimopSpec( + fuseName = "_string_len", + grinName = "_prim_string_len", + paramTypes = List(TypeString(i)), + returnType = TypeInt(i) ) ) diff --git a/src/main/scala/core/Representation.scala b/src/main/scala/core/Representation.scala index 2104023..2eeb7c0 100644 --- a/src/main/scala/core/Representation.scala +++ b/src/main/scala/core/Representation.scala @@ -45,10 +45,22 @@ object Representation { buildContext: Boolean = false ): StateEither[String] = t match { - case TypeVar(_, idx, n) => + case TypeVar(_, idx, n, referent) => EitherT(State.inspect { ctx => - Context - .indexToName(ctx, idx) + // Loophole K Option 1: when the referent is set and resolves to a + // top-level TypeAbbBind, prefer it over indexToName(idx). Drifted + // idx may point to an unrelated bind (e.g. a primop or method + // spec) — using the referent yields the user-facing type name. + val notes = Context.getNotes(ctx).toList + val referentName: Option[String] = referent.flatMap { name => + Context.nameToIndex(ctx, name).flatMap { resolvedIdx => + notes.lift(resolvedIdx).collect { case (_, _: TypeAbbBind) => + name + } + } + } + referentName + .orElse(Context.indexToName(ctx, idx)) .orElse( // Fallback for TypeVars with out-of-range indices (e.g., negative // from over-shifting). Use placeholder to avoid crashing during diff --git a/src/main/scala/core/Shifting.scala b/src/main/scala/core/Shifting.scala index eb910b9..5a4ff62 100644 --- a/src/main/scala/core/Shifting.scala +++ b/src/main/scala/core/Shifting.scala @@ -66,7 +66,7 @@ object Shifting { termMap( (info, c, k, n) => if (k >= c) TermVar(info, k - 1, n - 1) else TermVar(info, k, n - 1), - (c, tyT) => typeSubstitute(tyS, c, tyT), + (c, tyT) => typeSubstituteTopAt(tyS, c, tyT), c, t ) @@ -196,6 +196,17 @@ object Shifting { def typeSubstituteTop(tyS: Type, tyT: Type): Type = typeShift(-1, typeSubstitute(typeShift(1, tyS), 0, tyT)) + /** Generalization of `typeSubstituteTop` to substitute at arbitrary depth + * `c`, applying the closing downshift to indices strictly greater than `c`. + * The standard `typeSubstituteTop` is the special case `c = 0`. + * + * Used when a term-level traversal has crossed `c` binders before reaching + * the embedded type, so the eliminated `TermTAbs` slot lives at unified + * index `c` rather than `0`. + */ + def typeSubstituteTopAt(tyS: Type, c: Int, tyT: Type): Type = + typeShiftAbove(-1, c + 1, typeSubstitute(typeShift(1, tyS), c, tyT)) + /** Substitutes `TypeVar` instances having `c` index with `tyS` type in * provided `tyT` type . */ @@ -231,7 +242,12 @@ object Shifting { */ def typeMap(onVar: ShiftVarFunc[Type], c: Int, t: Type): Type = { def iter(c: Int, tyT: Type): Type = tyT match { - case TypeVar(info, x, n) => onVar(info, c, x, n) + case input @ TypeVar(info, x, n, _) => + onVar(info, c, x, n) match { + case out @ TypeVar(outInfo, _, _, None) if outInfo eq info => + out.copy(referent = input.referent) + case other => other + } case TypeEVar(_, _, _) => tyT case TypeAny(_) => tyT case TypeId(_, _) => tyT diff --git a/src/main/scala/core/Syntax.scala b/src/main/scala/core/Syntax.scala index 6a3ec6b..0389b39 100644 --- a/src/main/scala/core/Syntax.scala +++ b/src/main/scala/core/Syntax.scala @@ -16,7 +16,26 @@ object Types { def containsEVar(eV: TypeEVar): Boolean } - case class TypeVar(info: Info, index: Integer, length: Integer) extends Type { + /** A type-variable reference. Carries a De Bruijn `index` plus the `length` + * (context size) at the moment of stamping, so + * `Context.typeShiftOnContextDiff` can compensate for closure entry/exit + * drift via uniform shift. + * + * The optional `referent` field is the original binding **name** at stamping + * time, populated only when the TypeVar references a top-level type + * abbreviation (`TypeAbbBind`). When set, `typeShiftOnContextDiff` consults + * the current context for that name instead of applying the uniform shift — + * sidestepping the Loophole K over-shift caused by mid-list inserts + * displacing the referent's absolute position. Local type-parameter binders + * (`TypeVarBind`) keep `referent = None` because their referents are scoped + * and uniform shift is correct for closure transitions. + */ + case class TypeVar( + info: Info, + index: Integer, + length: Integer, + referent: Option[String] = None + ) extends Type { def containsEVar(eV: TypeEVar): Boolean = false override def isPrimitive: Boolean = true } @@ -93,7 +112,7 @@ object Types { } implicit val showTypeInfo: ShowInfo[Type] = ShowInfo.info(_ match { - case TypeVar(info, _, _) => info + case TypeVar(info, _, _, _) => info case TypeClass(info, _) => info case TypeEVar(info, _, _) => info case TypeId(info, _) => info diff --git a/src/main/scala/core/TypeChecker.scala b/src/main/scala/core/TypeChecker.scala index c62cfe7..cbdffcc 100644 --- a/src/main/scala/core/TypeChecker.scala +++ b/src/main/scala/core/TypeChecker.scala @@ -37,7 +37,7 @@ object TypeChecker { _ <- checkTypeInstanceMethodBinding(bind.i, binding) _ <- checkTypeInstance(bind.i, binding) id <- EitherT.liftF(addBinding(bind.i, binding)) - } yield Bind(id, binding, insts) + } yield Bind(id, binding, insts, closureTypesFromInsts(insts)) ) .flatMap(binds => binds.exists { bind => bind.i == MainFunction } match { @@ -46,6 +46,109 @@ object TypeChecker { } ) + /** Build `Bind.closureTypes` from `Resolution.Closure` insts. + * + * Without this seed, monomorphic binds reach GRIN-gen with + * `closureTypesFromBind = Map.empty`, and the closure-type lookup falls + * through to `pureInfer`. Two failure modes surface there: + * + * 1. **Identity-shape closures** (`λr. r`): the `TermClosure(_, _, None, + * _)` infer arm mints two fresh EVars and the body's `check (TermVar + * 0)` only unifies them with each other. Neither side gets pinned to a + * concrete base, producing a fully-EVar arrow as the apply-dispatch + * key. + * 2. **Body-uses-stale-context closures**: `pureInfer` errors out when the + * body references identifiers whose De-Bruijn indices don't resolve in + * the GRIN-gen Context. The closure's `typeKey` ends up empty, + * `closureMap` lacks an entry, and the dispatch site falls back to the + * all-arity bucket producing a heterogeneous dispatch. + * + * Both cases are addressed by seeding `Bind.closureTypes` with the concrete + * arrow types that the bidirectional `check` already produced for every + * closure inst. The seed is keyed by `inst.i` (the closure parameter name). + * Multiple closures within the same bind that share a parameter name cannot + * be distinguished by name alone — those are filtered out and left to + * `pureInfer`, which is sufficient for the closures it can type. + * + * Only `TypeArrow` heads with no unsolved EVars are kept; partial/EVar + * arrows fall through to the `pureInfer` path which still has the existing + * structure-extraction recovery on top. + */ + def closureTypesFromInsts( + insts: List[Instantiation] + ): Map[String, Type] = + insts + .filter(_.r == Resolution.Closure) + .filter(isIdentityClosureInst) + .flatMap(cInst => + cInst.tys.headOption.collect { + case ty: TypeArrow if !containsUnsolvedEVar(ty) => cInst.i -> ty + } + ) + .toMap + + def isIdentityClosureInst(inst: Instantiation): Boolean = inst.term match { + case TermClosure(_, _, None, TermVar(_, 0, _)) => true + case _ => false + } + + def containsUnsolvedEVar(ty: Type): Boolean = ty match { + case TypeEVar(_, _, _) => true + case TypeArrow(_, t1, t2) => + containsUnsolvedEVar(t1) || containsUnsolvedEVar(t2) + case TypeApp(_, t1, t2) => + containsUnsolvedEVar(t1) || containsUnsolvedEVar(t2) + case TypeAll(_, _, _, _, t) => containsUnsolvedEVar(t) + case TypeRec(_, _, _, t) => containsUnsolvedEVar(t) + case _ => false + } + + /** Structural test for raw generic TypeVars (referent = None). A TypeVar with + * `Some(name)` is a referent-encoded reference (Loophole K Option 1) to a + * type identifier (e.g. `List`) and is treated as resolved — its `typeToKey` + * produces the referent name. A `None` referent indicates a generic + * parameter without a fixed identifier; treat as unresolved. Used in + * `Grin.buildEnv` to filter `closureTypesGlobal` entries. + */ + def hasUnresolvedTypeVar(ty: Type): Boolean = ty match { + case TypeVar(_, _, _, None) => true + case TypeVar(_, _, _, Some(_)) => false + case TypeArrow(_, t1, t2) => + hasUnresolvedTypeVar(t1) || hasUnresolvedTypeVar(t2) + case TypeApp(_, t1, t2) => + hasUnresolvedTypeVar(t1) || hasUnresolvedTypeVar(t2) + case TypeAll(_, _, _, _, t) => hasUnresolvedTypeVar(t) + case TypeRec(_, _, _, t) => hasUnresolvedTypeVar(t) + case _ => false + } + + /** Structural test for an already-monomorphized ADT shape. Used by the + * `TermTApp` infer arm to recognize that `tyT1 = TypeRec(TypeVariant)` with + * no free type variables and no unsolved EVars is a residual + * type-application on a concrete constructor (e.g. `Nil[Node]` after `Nil` + * was already specialized to `List[Node]`). The `[T]` is then a no-op rather + * than an error. + */ + def isMonomorphicAdt(ty: Type): Boolean = ty match { + case _: TypeRec => !containsUnsolvedEVar(ty) && !hasUnresolvedTypeVar(ty) + case _ => false + } + + /** True iff `ty` is an arrow whose param AND return are both unsolved EVars + * (the `λr. r` identity-closure pattern that produces the pathological + * `"TypeEVar->TypeEVar"` typeKey). Used in `Grin.scala`'s closure-type + * resolution to detect when `pureInfer`'s result is the "fully unsolved" + * shape and the bidirectionally-checked seed should be preferred. + * + * Does *not* match arrows with one concrete side (e.g. `EVar -> i32`), + * preserving pre-existing dispatch behavior for `λ_. concrete_body` closures + * whose parameter stays unsolved because it's unused. + */ + def isFullyEVarArrow(ty: Type): Boolean = ty match { + case TypeArrow(_, _: TypeEVar, _: TypeEVar) => true + case _ => false + } + def checkBinding( b: Binding ): StateEither[(Binding, List[Instantiation])] = b match { @@ -63,13 +166,17 @@ object TypeChecker { def pureInfer(t: Term)(implicit checking: Boolean = true - ): StateEither[(Type, List[Instantiation])] = for { - m <- EitherT.liftF(addMark("p")) - (iT, insts) <- infer(t) - aT <- apply(iT)(shift = false) - aI <- insts.traverse(applyInst(_)(shift = false)) - _ <- EitherT.liftF(peel(m)) - } yield (aT, aI) + ): StateEither[(Type, List[Instantiation])] = { + val body: StateEither[(Type, List[Instantiation])] = for { + (iT, insts) <- infer(t) + aT <- apply(iT)(shift = false) + aI <- insts.traverse(applyInst(_)(shift = false)) + } yield (aT, aI) + for { + m <- EitherT.liftF(addMark("p")) + result <- EitherT(body.value.flatMap(r => peel(m).map(_ => r))) + } yield result + } /** Infers a type for `exp` with input context `ctx`. * @return @@ -119,11 +226,21 @@ object TypeChecker { resolvedType <- EitherT.liftF(State.inspect { (ctx: Context) => resolveTypeConstructors(eAS1, ctx) }) + // Store the closure's full arrow type (param + return) in the + // instantiation's `tys`. Storing only the param type loses the + // return-type witness, which downstream `closureTypesMap` builds + // and the GRIN-side `closureTypesFromBind` consumer need to + // compute a concrete-arrow typeKey. Without it, the lookup query + // falls through to the all-arity branch and the GRIN HPT analyzer + // rejects the merged type-env on shared closure bodies. Existing + // consumers already match on `TypeArrow` and extract what they + // need, so the broader form is structurally backward-compatible. closureInst = Instantiation( i = variable, // Use variable name as identifier term = TermClosure(info, variable, None, expr), - tys = - List(resolvedType), // The resolved parameter type with TypeIds + tys = List( + TypeArrow(info, resolvedType, typeShift(-1, eCS)) + ), cls = List(), r = Resolution.Closure ) @@ -209,13 +326,34 @@ object TypeChecker { case TermAssocProj(info, ty, method) => for { tyS <- EitherT.liftF(simplifyType(ty)) - rootTypeVarOption = findRootTypeVar(ty) - typeBounds <- EitherT.liftF( - getTypeBounds(rootTypeVarOption.getOrElse(tyS)) - ) - // If method is already specialized, extract the base method name - baseMethod = SpecializedMethodUtils.extractBaseMethodName(method) - assocMethodType <- inferMethod(ty, tyS, typeBounds, baseMethod, info) + specializedType <- SpecializedMethodUtils.isSpecializedMethod( + method + ) match { + case false => + EitherT.rightT[ContextState, Error](Option.empty[Type]) + case true => + for { + idxOpt <- EitherT.liftF( + State.inspect((ctx: Context) => nameToIndex(ctx, method)) + ) + ty <- idxOpt.traverse(getType(info, _)) + } yield ty + } + assocMethodType <- specializedType match { + case Some(ty) => ty.pure[StateEither] + case None => + for { + rootTypeVarOption <- EitherT.rightT[ContextState, Error]( + findRootTypeVar(ty) + ) + typeBounds <- EitherT.liftF( + getTypeBounds(rootTypeVarOption.getOrElse(tyS)) + ) + baseMethod = + SpecializedMethodUtils.extractBaseMethodName(method) + ty <- inferMethod(ty, tyS, typeBounds, baseMethod, info) + } yield ty + } } yield (assocMethodType, Nil) case TermFix(info, t1) => for { @@ -268,7 +406,7 @@ object TypeChecker { (TypeAll(info, v, kind, cls, t), insts) } } yield ty - case TermTApp(info, expr, ty2) => + case t @ TermTApp(info, expr, ty2) => for { k2 <- kindOf(ty2) (ty1, insts) <- pureInfer(expr) @@ -283,10 +421,20 @@ object TypeChecker { // After monomorphization, a `TermTApp(TermVar(f), T)` may // reference an already-specialized bind whose type has already // had the TypeAll unwrapped — the TermTApp wrapper is residual. - // Return the already-specialized type unchanged instead of - // erroring. + // Return the un-simplified `ty1` so a recursive ADT type stays + // in compact `TypeApp(TypeId, T)` form rather than the unfolded + // `TypeRec(TypeVariant(...))` produced by `simplifyType`. The + // unfolded form would propagate up and fail subtyping against + // declared compact-form types. case _ if !checking => - tyT1.pure[StateEither] + ty1.pure[StateEither] + // A monomorphized constructor reaches here with `tyT1` = + // `TypeRec(TypeVariant(...))` rather than a `TypeAll`. The + // TermTApp wrapper is residual — passing `ty1` through preserves + // the concrete type. Guarded on `isMonomorphicAdt` so a + // type-applied primitive (e.g. `tyT1 = TypeInt`) still errors. + case _ if isMonomorphicAdt(tyT1) => + ty1.pure[StateEither] case _ => TypeError.format( TypeArgumentsNotAllowedTypeError( @@ -295,7 +443,20 @@ object TypeChecker { ) ) } - } yield (ty, insts) + // A bare `TermTApp(TermVar, T)` in argument or value position + // carries no inst by default — `infer` for application sites only + // calls `Instantiations.build` on the function position, so a + // TermTApp elsewhere falls through. Without an inst, + // `replaceInstantiations` cannot redirect the TermVar to its + // specialised bind, leaving its De Bruijn index stale. As the + // bind list grows (insertions from other specialisations), the + // stale index resolves to an unrelated binding. + // + // Calling `Instantiations.build` directly on the TermTApp here + // generates the inst. Build is a no-op for non-TermVar heads so + // this is safe in all callsites. + buildInsts <- Instantiations.build(t, Nil, insts) + } yield (ty, buildInsts) case TermMatch(info, exprTerm, cases) => for { // Get the type of the expression to match. @@ -654,8 +815,27 @@ object TypeChecker { } def computeType(ty: Type): StateOption[Type] = ty match { - case TypeVar(_, idx, ctxLen) => - getTypeAbb(idx).orElse(recoverStaleTypeVar(idx, ctxLen)) + case TypeVar(_, idx, ctxLen, referent) => + // Loophole K Option 1: prefer name-based recovery via referent before + // falling back to uniform-shift `recoverStaleTypeVar`. The referent is + // the original top-level type name stamped at desugar time. After + // mid-list bind inserts (monomorphization specs), the stored idx may + // point to a wrong slot — name lookup recovers the current idx. + val nameRecovered: StateOption[Type] = referent match { + case Some(name) => + OptionT + .liftF[ContextState, Option[Int]](State.inspect { (ctx: Context) => + nameToIndex(ctx, name) + }) + .flatMap { + case Some(newIdx) if newIdx != idx => getTypeAbb(newIdx) + case _ => OptionT.none + } + case None => OptionT.none + } + getTypeAbb(idx) + .orElse(nameRecovered) + .orElse(recoverStaleTypeVar(idx, ctxLen)) case TypeApp(info, TypeAbs(_, _, tyT12), tyT2) => OptionT.some[ContextState](typeSubstituteTop(tyT2, tyT12)) case TypeApp(info, TypeId(_, name), tyT2) => @@ -680,9 +860,9 @@ object TypeChecker { @tailrec def findRootType(ty: Type): Type = ty match { - case v @ TypeVar(_, _, _) => v - case TypeApp(_, ty1, ty2) => findRootType(ty1) - case _ => ty + case v @ TypeVar(_, _, _, _) => v + case TypeApp(_, ty1, ty2) => findRootType(ty1) + case _ => ty } @tailrec @@ -702,7 +882,7 @@ object TypeChecker { }) def getTypeBounds(ty: Type): ContextState[List[TypeClass]] = ty match { - case TypeVar(info, index, _) => + case TypeVar(info, index, _, _) => Context .getBinding(UnknownInfo, index) .getOrElse(List()) @@ -754,8 +934,8 @@ object TypeChecker { _ <- checkKindStar(ty1) _ <- checkKindStar(ty1) } yield KindStar - case TypeVar(info, idx, _) => getKind(info, idx) - case TypeId(info, name) => + case TypeVar(info, idx, _, _) => getKind(info, idx) + case TypeId(info, name) => // TypeId might be a monomorphized type or a type constructor // Try to look it up in the context for { @@ -905,6 +1085,8 @@ object TypeChecker { def check( exp: Term, t: Type + )(implicit + checking: Boolean = true ): StateEither[(List[Instantiation], List[TypeESolutionBind])] = (exp, t) match { // 1I :: ((), 1) @@ -1286,11 +1468,31 @@ object TypeChecker { } yield b1 && b2 case (TypeRec(_, x1, k1, tyS1), TypeRec(_, _, k2, tyT1)) if k1 == k2 => Context.addName(x1).flatMap(_ => isTypeEqual(tyS1, tyT1)) - case (TypeVar(_, idx1, len1), TypeVar(_, idx2, len2)) => + case ( + TypeVar(_, idx1, _, Some(name1)), + TypeVar(_, idx2, _, Some(name2)) + ) if name1 == name2 => + // Loophole K Option 1: when both TypeVars carry the same referent + // and resolve to a top-level type abbreviation in the current + // context, treat them as equal regardless of idx drift. This + // sidesteps mid-list-insert displacement that the uniform shift + // can't compensate for. Local type-parameter binders + // (TypeVarBind) still require idx match because their names can + // shadow across nested scopes. + State.inspect { ctx => + val notes = Context.getNotes(ctx).toList + val resolvesToTypeAbb = (n: Int) => + (n >= 0 && n < notes.length) && (notes(n)._2 match { + case _: TypeAbbBind => true + case _ => false + }) + (idx1 == idx2) || (resolvesToTypeAbb(idx1) && resolvesToTypeAbb(idx2)) + } + case (TypeVar(_, idx1, _, _), TypeVar(_, idx2, _, _)) => (idx1 == idx2).pure - case (TypeVar(_, idx, _), TypeId(_, id)) => + case (TypeVar(_, idx, _, _), TypeId(_, id)) => State.inspect(ctx => Context.indexToName(ctx, idx).contains(id)) - case (TypeId(_, id), TypeVar(_, idx, _)) => + case (TypeId(_, id), TypeVar(_, idx, _, _)) => State.inspect(ctx => Context.indexToName(ctx, idx).contains(id)) case (TypeRecord(_, f1), TypeRecord(_, f2)) if f1.length == f2.length => f1.traverse { case (l1, tyT1) => @@ -1346,7 +1548,7 @@ object TypeChecker { case TypeApp(info, ctor, param) => // Resolve constructor if it's a TypeVar val resolvedCtor = ctor match { - case TypeVar(_, idx, _) => + case TypeVar(_, idx, _, _) => // Look up the binding at this index Context.indexToName(ctx, idx) match { case Some(name) => diff --git a/src/main/scala/parser/Expressions.scala b/src/main/scala/parser/Expressions.scala index 94e7b10..ca6d3b1 100644 --- a/src/main/scala/parser/Expressions.scala +++ b/src/main/scala/parser/Expressions.scala @@ -348,7 +348,13 @@ abstract class Expressions(fileName: String) extends Types(fileName) { ) } def String = { - def Raw = rule(!'\"' ~ ANY) + // Two-char escape units come first so the parser commits both bytes + // before the single-byte fallback runs. `\\` (escaped backslash) must + // be its own alternative — without it, `\\"` would be parsed as one + // raw `\` followed by an `\"` escape, swallowing the closing quote. + def Raw = rule { + ('\\' ~ '\\') | ('\\' ~ '\"') | (!'\"' ~ ANY) + } rule { info ~ '"' ~ capture(Raw.*) ~ '"' ~> FString.apply } diff --git a/src/test/scala/CompilerTests.scala b/src/test/scala/CompilerTests.scala index 339f3ab..5508e2e 100644 --- a/src/test/scala/CompilerTests.scala +++ b/src/test/scala/CompilerTests.scala @@ -5,7 +5,6 @@ import scala.concurrent.duration.Duration import cats.effect.{IO, Resource} import java.nio.file.{Files, Path, Paths} import scala.sys.process.* -import cats.effect.ExitCode abstract class CompilerTests extends FunSuite { import CompilerTests.* @@ -14,14 +13,18 @@ abstract class CompilerTests extends FunSuite { /** Asserts fuse code is type checked. */ def fuse(code: String, expected: Output = CheckOutput(None)) = expected match { - case CheckOutput(s) => assertCheck(code, s) - case BuildOutput(s) => assertBuild(code, s) - case ExecutableOutput(stdout, exitCode, stdlib) => - assertExecutable(code, stdout, exitCode, stdlib) + case CheckOutput(s, stdlib) => assertCheck(code, s, stdlib) + case BuildOutput(s, stdlib) => assertBuild(code, s, stdlib) + case ExecutableOutput(stdout, exitCode, stdlib, args) => + assertExecutable(code, stdout, exitCode, stdlib, args) } - def assertCheck(code: String, expectedError: Option[String]) = - (check(code), expectedError) match { + def assertCheck( + code: String, + expectedError: Option[String], + includeStdlib: Boolean = true + ) = + (check(code, includeStdlib = includeStdlib), expectedError) match { case (t, None) => assert(t.isRight, s"\n${t.merge}") case (t, Some(error)) if t.isLeft => assert(t.merge.contains(error), s"\n${error} not in:\n${t.merge}") @@ -29,8 +32,12 @@ abstract class CompilerTests extends FunSuite { assert(false, s"\ncheck passed, error not thrown: '${error}'") } - def assertBuild(code: String, expectedGrinCode: String) = - build(code) match { + def assertBuild( + code: String, + expectedGrinCode: String, + includeStdlib: Boolean = false + ) = + build(code, includeStdlib = includeStdlib) match { case Right(grinCode) => assertNoDiff( grinCode, @@ -45,11 +52,12 @@ abstract class CompilerTests extends FunSuite { code: String, expectedStdout: String, expectedExitCode: Int = 0, - includeStdlib: Boolean = false + includeStdlib: Boolean = false, + args: List[String] = Nil ) = { import cats.effect.unsafe.implicits.global - execute(code, includeStdlib).unsafeRunSync() match { + execute(code, includeStdlib, args).unsafeRunSync() match { case Right(result) => assertEquals( result.exitCode, @@ -211,6 +219,25 @@ fun main() -> Unit """ ) } + test("check top-level generic recursive list reverse with stdlib") { + fuse( + """ +type List[T]: + Cons(h: T, t: List[T]) + Nil + +fun list_reverse_acc[T](xs: List[T], acc: List[T]) -> List[T] + match xs: + Cons(h, t) => list_reverse_acc(t, Cons(h, acc)) + Nil => acc + +fun main() -> i32 + let l = Cons(1, Cons(2, Nil)) + let r = list_reverse_acc(l, Nil[i32]) + 0 + """ + ) + } test("check generic list with map_2 non-recursive") { fuse(""" type List[A]: @@ -653,6 +680,27 @@ fun main() -> i32 Nil => 1 """) } + test( + "check non-generic top-level fn with self-recursive let-bound closure on List[str]" + ) { + fuse(""" +type List[A]: + Cons(h: A, t: List[A]) + Nil + +fun join(lines_: List[str]) -> str + let iter = (lines: List[str], acc: str) => { + match lines: + Nil => acc + Cons(h, t) => iter(t, acc + h) + } + iter(lines_, "") + +fun main() -> i32 + let r = join(Nil[str]) + 0 + """) + } test( "check recursive closure inference with match statement for list type annotations required" ) { @@ -1006,7 +1054,8 @@ fun main() -> i32 } test("check generic trait monad with default implementation") { - fuse(""" + fuse( + """ trait Monad[A]: fun unit[T](a: T) -> Self[T]; @@ -1032,7 +1081,9 @@ fun main() -> i32 let o = Some(5) o.map(a => a + 1) 0 - """) + """, + CheckOutput(None, includeStdlib = false) + ) } test("check generic trait monad for state") { @@ -1076,7 +1127,8 @@ fun main() -> i32 """) } test("check generic traits monad + show with default implementation") { - fuse(""" + fuse( + """ trait Monad[A]: fun unit[B](a: B) -> Self[B]; @@ -1110,7 +1162,9 @@ fun main() -> i32 let o = Some(5) o.map(a => a + 1) 0 - """) + """, + CheckOutput(None, includeStdlib = false) + ) } test("check invalid type classes used for type param for different kinds") { fuse( @@ -1171,7 +1225,8 @@ fun main() -> i32 } test("check do expr") { - fuse(""" + fuse( + """ trait Monad[A]: fun unit[A](a: A) -> Self[A]; @@ -1207,7 +1262,9 @@ fun main() -> i32 match d: Some(v) => v _ => 0 - """) + """, + CheckOutput(None, includeStdlib = false) + ) } test("check do expr invalid do expr") { @@ -1484,7 +1541,8 @@ fun main() -> i32 CheckOutput( Some( "expected type of `i32`, found `str`" - ) + ), + includeStdlib = false ) ) @@ -1503,6 +1561,40 @@ fun main() -> Unit """) } + test("check main returning IO[Unit] with do notation") { + fuse(""" +fun main() -> IO[Unit] + do: + _ <- print("a\n") + _ <- print("b\n") + () + """) + } + test("check main returning IO[Unit] with single expression") { + fuse(""" +fun main() -> IO[Unit] + print("hi\n") + """) + } + test("check main returning IO[i32]") { + fuse(""" +fun main() -> IO[i32] + do: + _ <- print("answer:\n") + 42 + """) + } + test("reject main returning nested IO") { + fuse( + """ +fun main() -> IO[IO[Unit]] + unit(print("nope")) + """, + CheckOutput( + Some("`main` cannot return a nested `IO` type") + ) + ) + } test("check generic list map with unit return type") { fuse( """ @@ -2020,6 +2112,33 @@ fun main() -> i32 """) } + test("check escaped double-quote in string literal") { + fuse(""" +fun main() -> str + "\"" + """) + } + + test("check Nil[T] nested inside Cons literal is accepted") { + fuse(""" +type Token: + TID(s: str) + +type Node: + NTok(h: Token) + NSep + +fun foo(xs: List[Node]) -> i32 + 0 + +fun bar(h: Token) -> i32 + foo(Cons(NTok(h), Nil[Node])) + +fun main() -> i32 + bar(TID("a")) + """) + } + } class CompilerBuildTests extends CompilerTests { @@ -2381,6 +2500,263 @@ grinMain _6 = """) ) } + test("build top-level generic recursive list reverse") { + fuse( + """ +type List[T]: + Cons(h: T, t: List[T]) + Nil + +fun list_reverse_acc[T](xs: List[T], acc: List[T]) -> List[T] + match xs: + Cons(h, t) => list_reverse_acc(t, Cons(h, acc)) + Nil => acc + +fun main() -> i32 + let l = Cons(1, Cons(2, Nil)) + let r = list_reverse_acc(l, Nil[i32]) + 0 + """, + BuildOutput("""Cons'i32 h0 t1 = + store (CConsi32 h0 t1) + +Nil'i32 = store (CNili32) + +list_reverse_acc'i32 xs4 acc5 = + p8 <- fetch xs4 + case p8 of + (CConsi32 h8 t'9) -> + p11 <- Cons'i32 h8 acc5 + p12 <- list_reverse_acc'i32 t'9 p11 + pure p12 + #default -> + pure acc5 + +grinMain _12 = + p14 <- Nil'i32 + p15 <- Cons'i32 2 p14 + l15 <- Cons'i32 1 p15 + p17 <- Nil'i32 + r17 <- list_reverse_acc'i32 l15 p17 + pure 0""") + ) + } + test("build non-recursive top-level generic head_or with stdlib") { + fuse( + """ +type List[T]: + Cons(h: T, t: List[T]) + Nil + +fun head_or[T](xs: List[T], dflt: T) -> T + match xs: + Cons(h, _) => h + Nil => dflt + +fun main() -> i32 + let l = Cons(42, Nil) + let r = head_or(l, 0) + r + """, + BuildOutput( + """ffi pure + _prim_bool_and :: T_Bool -> T_Bool -> T_Bool + _prim_bool_or :: T_Bool -> T_Bool -> T_Bool + +char_at s0 i1 = + _prim_string_char_at s0 i1 + +substring s2 start3 end4 = + _prim_string_substring s2 start3 end4 + +str_len s5 = + _prim_string_len s5 + +is_whitespace c6 = + p8 <- _prim_int_eq c6 32 + p9 <- _prim_int_eq c6 9 + p10 <- _prim_bool_or p8 p9 + p11 <- _prim_int_eq c6 10 + p12 <- _prim_bool_or p10 p11 + p13 <- _prim_int_eq c6 13 + _prim_bool_or p12 p13 + +is_digit c13 = + p15 <- _prim_int_ge c13 48 + p16 <- _prim_int_le c13 57 + _prim_bool_and p15 p16 + +is_lower c16 = + p18 <- _prim_int_ge c16 97 + p19 <- _prim_int_le c16 122 + _prim_bool_and p18 p19 + +is_upper c19 = + p21 <- _prim_int_ge c19 65 + p22 <- _prim_int_le c19 90 + _prim_bool_and p21 p22 + +is_alpha c22 = + p24 <- is_lower c22 + p25 <- is_upper c22 + _prim_bool_or p24 p25 + +is_alnum c25 = + p27 <- is_alpha c25 + p28 <- is_digit c25 + _prim_bool_or p27 p28 + +is_ident_start c28 = + p30 <- is_alpha c28 + p31 <- _prim_int_eq c28 95 + _prim_bool_or p30 p31 + +is_ident_cont c31 = + p33 <- is_alnum c31 + p34 <- _prim_int_eq c31 95 + _prim_bool_or p33 p34 + +repeat s34 n35 = + p38 <- _prim_int_le n35 0 + case p38 of + #True -> + pure #"" + #False -> + p39 <- _prim_int_sub n35 1 + p40 <- repeat s34 p39 + p41 <- _prim_string_concat s34 p40 + pure p41 + +Cons'str h41 t42 = + store (CConsstr h41 t42) + +Cons'i32 h45 t46 = + store (CConsi32 h45 t46) + +Nil'str = store (CNilstr) + +Nil'i32 = store (CNili32) + +MkIO'Unit run''49 = + p52 <- store run''49 + store (CMkIOUnit p52) + +MkIO'str run''52 = + p55 <- store run''52 + store (CMkIOstr p55) + +MkIO'Liststr run''55 = + p58 <- store run''55 + store (CMkIOListstr p58) + +print s58 = + p67 <- pure (P1c60 s58) + MkIO'Unit p67 + +c60 s5861 _61 = + _prim_string_print s5861 + +read path67 = + p76 <- pure (P1c69 path67) + MkIO'str p76 + +c69 path6770 _70 = + _prim_file_read path6770 + +write path76 content77 = + p87 <- pure (P1c79 path76 content77) + MkIO'Unit p87 + +c79 path7680 content7781 _81 = + _prim_file_write path7680 content7781 + +get_args _87 = + p139 <- pure (P1c130 ) + MkIO'Liststr p139 + +collect_args89 i89 acc90 = + p93 <- _prim_int_lt i89 0 + case p93 of + #True -> + pure acc90 + #False -> + p94 <- _prim_int_sub i89 1 + p95 <- _prim_args_get i89 + p96 <- Cons'str p95 acc90 + p97 <- collect_args89 p94 p96 + pure p97 + +c130 _130 = + p132 <- _prim_args_count 0 + p133 <- _prim_int_sub p132 1 + p134 <- Nil'i32 + collect_args89 p133 p134 + +read_stdin_acc acc139 = + line141 <- _prim_read_string 0 + p144 <- str_len line141 + p145 <- _prim_int_eq p144 0 + case p145 of + #True -> + pure acc139 + #False -> + p146 <- _prim_string_concat acc139 line141 + p147 <- read_stdin_acc p146 + pure p147 + +read_stdin _147 = + p155 <- pure (P1c149 ) + MkIO'str p155 + +c149 _149 = + read_stdin_acc #"" + +Cons'str h155 t156 = + store (CConsstr h155 t156) + +Cons'i32 h159 t160 = + store (CConsi32 h159 t160) + +Nil'str = store (CNilstr) + +Nil'i32 = store (CNili32) + +head_or'i32 xs163 dflt164 = + p167 <- fetch xs163 + case p167 of + (CConsi32 h167 _'168) -> + pure h167 + #default -> + pure dflt164 + +grinMain _169 = + p171 <- Nil'i32 + l171 <- Cons'i32 42 p171 + r172 <- head_or'i32 l171 0 + pure r172 + +apply1_TypeEVar_to_str p174 p175 = + case p174 of + (P1c69 p176) -> + c69 p176 p175 + (P1c149) -> + c149 p175 + +apply1_TypeEVar_to_unit p177 p178 = + case p177 of + (P1c60 p179) -> + c60 p179 p178 + (P1c79 p180 p181) -> + c79 p180 p181 p178 + +apply1_unit_to_List_str p182 p183 = + case p182 of + (P1c130) -> + c130 p183""", + includeStdlib = true + ) + ) + } test("build function with params using generics") { fuse( """ @@ -3407,6 +3783,8 @@ Cons'Unit h4 t5 = Nil'i32 = store (CNili32) +Nil'Unit = store (CNilUnit) + foldrightListi32Listi32' as8 z9 f''10 = p13 <- fetch as8 case p13 of @@ -3441,7 +3819,7 @@ c32 f''2933 h34 t35 = Cons'i32 p37 t35 mapListi32Unit' self39 f''40 = - p42 <- Nil'i32 + p42 <- Nil'Unit p49 <- store f''40 p50 <- pure (P2c43 p49) foldrightListi32ListUnit' self39 p42 p50 @@ -4014,7 +4392,7 @@ execIOUnit' self21 = case p24 of (CMkIOUnit p26) -> f24 <- fetch p26 - p27 <- apply1_unit_to_unit f24 0 + p27 <- apply1_TypeEVar_to_unit f24 0 pure p27 mapMonadIOUniti32' self27 f''28 = @@ -4039,72 +4417,92 @@ c39 self3640 f''3742 _43 = execIOi32' b46 flatmapIOMonadUniti32' self49 f''50 = - p61 <- store f''50 - p62 <- pure (P1c52 self49 p61) - MkIO'i32 p62 + p62 <- store f''50 + p63 <- pure (P1c52 self49 p62) + MkIO'i32 p63 c52 self4953 f''5055 _56 = f''505556 <- fetch f''5055 a57 <- execIOUnit' self4953 - b59 <- apply1_unit_to_IO_i32 f''505556 a57 - execIOi32' b59 - -unitIOMonadi32' a62 = - p67 <- pure (P1c64 a62) - MkIO'i32 p67 - -c64 a6265 _65 = - pure a6265 - -grinMain _67 = - p75 <- pure (P1c69 ) - p76 <- MkIO'str p75 - p105 <- pure (P1c77 ) - program105 <- flatmapIOMonadstri32' p76 p105 - execIOi32' program105 - -c69 _69 = + p59 <- case f''505556 of + (P1c89) -> + apply1_unit_to_i32 f''505556 a57 + (P1c78) -> + apply1_str_to_IO_i32 f''505556 a57 + (P1c70) -> + apply1_TypeEVar_to_str f''505556 a57 + (P1c30 _f''505556_c30_0) -> + apply1_IO_to_IO_i32 f''505556 a57 + (P1c52 _f''505556_c52_0 _f''505556_c52_1) -> + apply1_unit_to_i32 f''505556 a57 + (P1c65 _f''505556_c65_0) -> + apply1_unit_to_i32 f''505556 a57 + (P1c80 _f''505556_c80_0) -> + apply1_TypeEVar_to_unit f''505556 a57 + (P1c39 _f''505556_c39_0 _f''505556_c39_1) -> + apply1_unit_to_i32 f''505556 a57 + b60 <- pure p59 + execIOi32' b60 + +unitIOMonadi32' a63 = + p68 <- pure (P1c65 a63) + MkIO'i32 p68 + +c65 a6366 _66 = + pure a6366 + +grinMain _68 = + p76 <- pure (P1c70 ) + p77 <- MkIO'str p76 + p106 <- pure (P1c78 ) + program106 <- flatmapIOMonadstri32' p77 p106 + execIOi32' program106 + +c70 _70 = pure #"hi" -c77 s77 = - p86 <- pure (P1c79 s77) - p87 <- MkIO'Unit p86 - p94 <- pure (P1c88 ) - mapMonadIOUniti32' p87 p94 +c78 s78 = + p87 <- pure (P1c80 s78) + p88 <- MkIO'Unit p87 + p95 <- pure (P1c89 ) + mapMonadIOUniti32' p88 p95 -c79 s7780 _80 = - _prim_string_print s7780 +c80 s7881 _81 = + _prim_string_print s7881 -c88 _88 = +c89 _89 = pure 0 -apply1_TypeEVar_to_str p107 p108 = - case p107 of - (P1c69) -> - c69 p108 +apply1_IO_to_IO_i32 p108 p109 = + case p108 of + (P1c30 p110) -> + c30 p110 p109 -apply1_str_to_IO_i32 p109 p110 = - case p109 of - (P1c77) -> - c77 p110 - -apply1_unit_to_IO_i32 p111 p112 = +apply1_TypeEVar_to_str p111 p112 = case p111 of - (P1c30 p113) -> - c30 p113 p112 + (P1c70) -> + c70 p112 + +apply1_TypeEVar_to_unit p113 p114 = + case p113 of + (P1c80 p115) -> + c80 p115 p114 -apply1_unit_to_i32 p114 p115 = - case p114 of - (P1c39 p116 p117) -> - c39 p116 p117 p115 - (P1c52 p118 p119) -> - c52 p118 p119 p115 - (P1c64 p120) -> - c64 p120 p115 - (P1c79 p121) -> - c79 p121 p115 - (P1c88) -> - c88 p115""") +apply1_str_to_IO_i32 p116 p117 = + case p116 of + (P1c78) -> + c78 p117 + +apply1_unit_to_i32 p118 p119 = + case p118 of + (P1c39 p120 p121) -> + c39 p120 p121 p119 + (P1c52 p122 p123) -> + c52 p122 p123 p119 + (P1c65 p124) -> + c65 p124 p119 + (P1c89) -> + c89 p119""") ) } @@ -4163,18 +4561,24 @@ fun main() -> i32 _print(describe(o)) 0 """, - BuildOutput("""describe o0 = - p3 <- fetch o0 - case p3 of - (CMySomei32 v3) -> + BuildOutput("""MySome'i32 t10 = + store (CMySomei32 t10) + +MyNone'i32 = store (CMyNonei32) + +describe o2 = + p5 <- fetch o2 + case p5 of + (CMySomei32 v5) -> pure #"some" #default -> pure #"none" -grinMain _4 = - o7 <- store (CMyNonei32) - p9 <- describe o7 - _10 <- _prim_string_print p9 +grinMain _6 = + p8 <- MyNone'i32 + o10 <- pure p8 + p12 <- describe o10 + _13 <- _prim_string_print p12 pure 0""") ) } @@ -4584,6 +4988,29 @@ fun main() -> i32 ExecutableOutput("Hello World\n5") ) } + test("execute top-level generic recursive list reverse") { + fuse( + """ +type List[T]: + Cons(h: T, t: List[T]) + Nil + +fun list_reverse_acc[T](xs: List[T], acc: List[T]) -> List[T] + match xs: + Cons(h, t) => list_reverse_acc(t, Cons(h, acc)) + Nil => acc + +fun main() -> i32 + let l = Cons(1, Cons(2, Nil)) + let r = list_reverse_acc(l, Nil[i32]) + match r: + Cons(h, _) => _print(int_to_str(h)) + Nil => _print("nil") + 0 + """, + ExecutableOutput("2") + ) + } test("execute function with params using generics") { fuse( """ @@ -5559,6 +5986,27 @@ fun main() -> i32 ExecutableOutput("Hello IO!", includeStdlib = true) ) } + test("execute main returning IO[Unit] single print") { + fuse( + """ +fun main() -> IO[Unit] + print("Hello IO!") + """, + ExecutableOutput("Hello IO!", includeStdlib = true) + ) + } + test("execute main returning IO[Unit]") { + fuse( + """ +fun main() -> IO[Unit] + do: + _ <- print("a\n") + _ <- print("b\n") + () + """, + ExecutableOutput("a\nb", includeStdlib = true) + ) + } test("execute io unit and flat_map") { fuse( """ @@ -5699,6 +6147,101 @@ fun main() -> i32 ) } + test("execute fusefmt formats messy file") { + import scala.io.Source + val fmtSource = + Source.fromFile("examples/fusefmt.fuse").mkString + val tempInput = + java.nio.file.Files.createTempFile("fusefmt-input-", ".fuse") + val messy = + "fun main() -> i32 \n let x = 1 \n\n\n\n x\n" + java.nio.file.Files.write(tempInput, messy.getBytes) + try + fuse( + fmtSource, + ExecutableOutput( + "fun main() -> i32\n let x = 1\n\n\n x", + includeStdlib = true, + args = List(tempInput.toString) + ) + ) + finally java.nio.file.Files.deleteIfExists(tempInput) + } + + test("execute stdlib string helpers") { + fuse( + """ +fun main() -> i32 + _print(int_to_str(char_at("hello", 1))) + _print(" ") + _print(substring("hello world", 6, 11)) + _print(" ") + _print(int_to_str(str_len("abc"))) + _print(" ") + match is_whitespace(32): + true => _print("ok") + false => _print("fail") + 0 + """, + ExecutableOutput("101 world 3 ok", includeStdlib = true) + ) + } + + test("execute args count returns user arg count") { + fuse( + """ +fun main() -> i32 + _print(int_to_str(_args_count(()))) + 0 + """, + ExecutableOutput("3", args = List("a", "b", "c")) + ) + } + + test("execute args count returns 0 with no extra args") { + fuse( + """ +fun main() -> i32 + _print(int_to_str(_args_count(()))) + 0 + """, + ExecutableOutput("0") + ) + } + + test("execute args get returns first user arg") { + fuse( + """ +fun main() -> i32 + _print(_args_get(0)) + 0 + """, + ExecutableOutput("hello", args = List("hello")) + ) + } + + test("execute string char_at returns codepoint at index") { + fuse( + """ +fun main() -> i32 + _print(int_to_str(_string_char_at("hello", 1))) + 0 + """, + ExecutableOutput("101") + ) + } + + test("execute string substring extracts range") { + fuse( + """ +fun main() -> i32 + _print(_string_substring("hello world", 6, 11)) + 0 + """, + ExecutableOutput("world") + ) + } + test("execute io flat_map with non-unit binding") { fuse( """ @@ -5726,18 +6269,101 @@ fun main() -> i32 ExecutableOutput("43") ) } + test("execute fmap List with trait bound") { + fuse( + """ +type List[A]: + Cons(h: A, t: List[A]) + Nil + +trait Functor[A]: + fun map[B](self, f: A -> B) -> Self[B]; + +impl List[A]: + fun fold_right[A, B](as: List[A], z: B, f: (A, B) -> B) -> B + match as: + Cons(x, xs) => f(x, List::fold_right(xs, z, f)) + Nil => z + + fun fold_left[A, B](l: List[A], acc: B, f: (B, A) -> B) -> B + match l: + Cons(h, t) => List::fold_left(t, f(acc, h), f) + Nil => acc + + fun append[A](l1: List[A], l2: List[A]) -> List[A] + List::fold_right(l1, l2, (h, t) => Cons(h, t)) + + fun sum(l: List[i32]) -> i32 + List::fold_right(l, 0, (acc, b) => acc + b) + + fun product(l: List[i32]) -> i32 + List::fold_left(l, 1, (acc, b) => acc * b) + + fun filter[A](self, f: A -> bool) -> List[A] + List::fold_right(self, Nil[A], (h, t) => { + match f(h): + true => Cons(h, t) + false => t + }) + +impl Functor[A] for List[A]: + fun map[B](self, f: A -> B) -> List[B] + List::fold_right(self, Nil[B], (h, t) => Cons(f(h), t)) + +fun fmap[A, B, T: Functor](f: A -> B, c: T[A]) -> T[B] + c.map(f) + +fun main() -> i32 + let l = Cons(2, Cons(3, Nil)) + let l1 = fmap(v => v + 1, l) + let l2 = Cons(7, Nil) + let l3 = List::append(l1, l2) + let l4 = l3.filter(e => e > 3) + let s = List::sum(l4) + let p = List::product(l4) + _print(int_to_str(s + p)) + 0 + """, + ExecutableOutput("39") + ) + } + + test("execute brackets in string literals survive codegen") { + fuse( + """ +fun main() -> i32 + _print("[a]b[c]") + 0 + """, + ExecutableOutput("[a]b[c]") + ) + } + + test("execute escaped double-quote round-trips to a single quote byte") { + fuse( + """ +fun main() -> i32 + _print("\"") + 0 + """, + ExecutableOutput("\"") + ) + } } object CompilerTests { import Fuse.* sealed trait Output - case class CheckOutput(s: Option[String]) extends Output - case class BuildOutput(s: String) extends Output + case class CheckOutput(s: Option[String], includeStdlib: Boolean = true) + extends Output + case class BuildOutput(s: String, includeStdlib: Boolean = false) + extends Output case class ExecutableOutput( expectedStdout: String, expectedExitCode: Int = 0, - includeStdlib: Boolean = false + includeStdlib: Boolean = false, + args: List[String] = Nil ) extends Output case class ExecutionResult(stdout: String, stderr: String, exitCode: Int) @@ -5765,50 +6391,33 @@ object CompilerTests { Resource.make(acquire)(release) } - private def executeProcess(exePath: Path): IO[ExecutionResult] = { + def execute( + code: String, + includeStdlib: Boolean = false, + args: List[String] = Nil + ): IO[Either[String, ExecutionResult]] = + createTempFuseFile(code).use { fusePath => + Fuse.runFile(fusePath.toString, args, includeStdlib, executeCapture).map { + case Right(result) => Right(result) + case Left(_) => Left("compilation failed: fuse or grin error") + } + } + + def executeCapture(exe: Path, args: List[String]): IO[ExecutionResult] = IO.blocking { val stdout = new StringBuilder val stderr = new StringBuilder - val logger = ProcessLogger( out => stdout.append(out).append("\n"), err => stderr.append(err).append("\n") ) - - val exitCode = Process(exePath.toString).!(logger) - + val exitCode = Process(exe.toString +: args).!(logger) ExecutionResult( stdout.toString.trim, stderr.toString.trim, exitCode ) } - } - - def execute( - code: String, - includeStdlib: Boolean = false - ): IO[Either[String, ExecutionResult]] = { - createTempFuseFile(code).use { fusePath => - for { - buildExitCode <- Fuse.build( - BuildFile(fusePath.toString, includeStdlib) - ) - result <- buildExitCode match { - case ExitCode.Success => - val outPath = - Paths.get( - fusePath.toString.stripSuffix( - s".$FuseFileExtension" - ) + s".$FuseOutputExtension" - ) - executeProcess(outPath).map(Right(_)) - case _ => - IO.pure(Left("compilation failed: fuse or grin error")) - } - } yield result - } - } /** Synchronously load the pre-parsed library module for the test helpers if * the command requests it. Mirrors `Compiler.run`'s load-then-compile flow @@ -5823,8 +6432,12 @@ object CompilerTests { } } - def check(code: String, fileName: String = s"test.$FuseFileExtension") = { - val command = CheckFile(fileName) + def check( + code: String, + fileName: String = s"test.$FuseFileExtension", + includeStdlib: Boolean = true + ) = { + val command = CheckFile(fileName, includeStdlib) loadStdlibSync(command).flatMap(stdlib => Compiler.compile(command, code.trim, fileName, stdlib) ) diff --git a/stdlib/io.fuse b/stdlib/io.fuse index a7a67e9..bf3e109 100644 --- a/stdlib/io.fuse +++ b/stdlib/io.fuse @@ -25,3 +25,20 @@ fun read(path: str) -> IO[str] fun write(path: str, content: str) -> IO[Unit] MkIO(_ => _file_write(path, content)) + +fun get_args() -> IO[List[str]] + let collect_args = (i: i32, acc: List[str]) => { + match i < 0: + true => acc + false => collect_args(i - 1, Cons(_args_get(i), acc)) + } + MkIO(_ => collect_args(_args_count(()) - 1, Nil)) + +fun read_stdin_acc(acc: str) -> str + let line = _read_stdin(()) + match str_len(line) == 0: + true => acc + false => read_stdin_acc(acc + line) + +fun read_stdin() -> IO[str] + MkIO(_ => read_stdin_acc("")) diff --git a/stdlib/list.fuse b/stdlib/list.fuse new file mode 100644 index 0000000..d772403 --- /dev/null +++ b/stdlib/list.fuse @@ -0,0 +1,39 @@ +type List[T]: + Cons(h: T, t: List[T]) + Nil + +impl List[T]: + fun reverse(self) -> List[T] + let iter = (xs: List[T], acc: List[T]) => { + match xs: + Cons(h, t) => iter(t, Cons(h, acc)) + Nil => acc + } + iter(self, Nil[T]) + + fun head_or[T](self, dflt: T) -> T + match self: + Cons(h, _) => h + Nil => dflt + + fun length(self) -> i32 + let iter = (xs: List[T], acc: i32) => { + match xs: + Cons(_, t) => iter(t, acc + 1) + Nil => acc + } + iter(self, 0) + + fun append(self, other: List[T]) -> List[T] + match self: + Cons(h, t) => Cons(h, t.append(other)) + Nil => other + +impl Monad for List[T]: + fun flat_map[B](self, f: T -> List[B]) -> List[B] + match self: + Cons(h, t) => f(h).append(t.flat_map(f)) + Nil => Nil[B] + + fun unit[T](a: T) -> List[T] + Cons(a, Nil) diff --git a/stdlib/option.fuse b/stdlib/option.fuse new file mode 100644 index 0000000..dc3b1a2 --- /dev/null +++ b/stdlib/option.fuse @@ -0,0 +1,37 @@ +type Option[T]: + None + Some(T) + +impl Option[T]: + fun is_some(self) -> bool + match self: + Some(v) => true + _ => false + + fun is_none(self) -> bool + match self: + Some(v) => false + _ => true + + fun get_or_else(self, default: T) -> T + match self: + Some(a) => a + None => default + + fun filter(self, f: T -> bool) -> Option[T] + match self: + Some(a) => { + match f(a): + true => self + _ => None + } + _ => None + +impl Monad for Option[T]: + fun flat_map[B](self, f: T -> Option[B]) -> Option[B] + match self: + Some(a) => f(a) + None => None + + fun unit[T](a: T) -> Option[T] + Some(a) diff --git a/stdlib/string.fuse b/stdlib/string.fuse new file mode 100644 index 0000000..ed59e62 --- /dev/null +++ b/stdlib/string.fuse @@ -0,0 +1,37 @@ +fun char_at(s: str, i: i32) -> i32 + _string_char_at(s, i) + +fun substring(s: str, start: i32, end: i32) -> str + _string_substring(s, start, end) + +fun str_len(s: str) -> i32 + _string_len(s) + +fun is_whitespace(c: i32) -> bool + c == 32 || c == 9 || c == 10 || c == 13 + +fun is_digit(c: i32) -> bool + c >= 48 && c <= 57 + +fun is_lower(c: i32) -> bool + c >= 97 && c <= 122 + +fun is_upper(c: i32) -> bool + c >= 65 && c <= 90 + +fun is_alpha(c: i32) -> bool + is_lower(c) || is_upper(c) + +fun is_alnum(c: i32) -> bool + is_alpha(c) || is_digit(c) + +fun is_ident_start(c: i32) -> bool + is_alpha(c) || c == 95 + +fun is_ident_cont(c: i32) -> bool + is_alnum(c) || c == 95 + +fun repeat(s: str, n: i32) -> str + match n <= 0: + true => "" + false => s + repeat(s, n - 1)