diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c5db424..3844a59 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,5 +24,4 @@ jobs: - uses: actions/checkout@v4 - uses: cachix/install-nix-action@v27 - run: nix build -L .#ocaml - - run: nix build -L .#legacy - run: nix flake check -L diff --git a/docs/language.md b/docs/language.md index 38ffcdc..58f59f9 100644 --- a/docs/language.md +++ b/docs/language.md @@ -71,6 +71,56 @@ pub fn vector_add(fvec3 a, fvec3 b) -> fvec3 { > [!TIP] > When integrating Haven with C, `fvecN` is the equivalent of ([non-standard](https://gcc.gnu.org/onlinedocs/gcc/Vector-Extensions.html)) `typedef float floatN __attribute__((vector_size(sizeof(float)) * N))`. +#### Specialized Vector Functions + +Functions may accept vectors of any concrete dimension using `fvec?`. + +These are not runtime-sized vectors. Instead, the compiler specializes the function at each call site using the +concrete argument types from that call. + +``` +fn vadd(fvec? a, fvec? b) { + @assert a.dim == b.dim, "vector dimensions must match"; + a + b +} +``` + +Inside such a function: + +- `a.dim` is a compile-time property of the specialized vector type +- omitted return types are inferred after specialization +- `@assert` conditions must reduce to compile-time constants after specialization + +If a compile-time assertion fails, the compiler reports both the original condition and the specialized one: + +``` +semantic: error: vector dimensions must match + compile-time assertion failed: a.dim == b.dim + specialized as: 3 == 2 +``` + +### Matrices + +Haven offers `matMxN` matrix types of floating point numbers. + +Like vectors, matrices support specialization holes in function signatures through `mat?`. + +``` +fn mat-width(mat? m) { + m.cols +} + +fn get-mat-row(mat? m, u32 row) { + m[row] +} +``` + +Inside specialized matrix functions: + +- `m.rows` and `m.cols` are compile-time properties +- indexing a concrete `matMxN` yields an `fvecN` +- specialized functions are cloned before LLVM lowering, so hole types do not reach IR + ### Strings The `str` type carries string data. It is essentially a `const char *` under the hood. diff --git a/examples/specialization.hv b/examples/specialization.hv new file mode 100644 index 0000000..cd7e816 --- /dev/null +++ b/examples/specialization.hv @@ -0,0 +1,19 @@ +fn vadd(fvec? a, fvec? b) { + @assert a.dim == b.dim, "vector dimensions must match"; + a + b +} + +fn mat-width(mat? m) { + m.cols +} + +fn get-mat-row(mat? m, u32 row) { + m[row] +} + +pub fn main() -> fvec3 { + let row = get-mat-row(Mat, Vec<4.0, 5.0, 6.0>>, 1); + let sum = vadd(row, Vec<10.0, 20.0, 30.0>); + let cols = as(mat-width(Mat, Vec<4.0, 5.0, 6.0>>)); + sum + Vec +} diff --git a/examples/vec/index.hv b/examples/vec/index.hv index e753b1e..4af904e 100644 --- a/examples/vec/index.hv +++ b/examples/vec/index.hv @@ -1,26 +1,38 @@ pub fn __builtin_sqrtf(float x) -> float intrinsic "llvm.sqrt" float; -fn vadd(fvec3 a, fvec3 b) -> fvec3 { +fn vadd(fvec? a, fvec? b) { + @assert a.dim == b.dim, "vadd requires both vectors to be the same width"; + a + b } -fn vscale(fvec3 a, float b) -> fvec3 { +fn vscale(fvec? a, float b) { a * b } -fn vdot(fvec3 a, fvec3 b) -> float { +fn vdot(fvec? a, fvec? b) -> float { + @assert a.dim == b.dim, "vdot requires both vectors to be the same width"; + let mult = a * b; - mult.x + mult.y + mult.z + let mut accum = 0.0; + + iter 0:(a.dim - 1) i { + accum = accum + mult[i]; + }; + + accum } fn vcross(fvec3 a, fvec3 b) -> fvec3 { let x = a.y * b.z - a.z * b.y; let y = a.z * b.x - a.x * b.z; let z = a.x * b.y - a.y * b.x; + Vec } -impure fn vnorm(fvec3 v) -> fvec3 { +impure fn vnorm(fvec? v) { let denom = 1.0 / __builtin_sqrtf(vdot(v, v)); + v * denom } diff --git a/flake.nix b/flake.nix index a924bd4..aacc0ef 100644 --- a/flake.nix +++ b/flake.nix @@ -42,7 +42,6 @@ checks.default = havenOcaml; checks.ocaml = havenOcaml; - checks.legacy = havenLegacy; devShells.default = pkgs.mkShell { nativeBuildInputs = with pkgs; [ diff --git a/src/bin/haven.ml b/src/bin/haven.ml index 32e3f77..5ffefcc 100644 --- a/src/bin/haven.ml +++ b/src/bin/haven.ml @@ -70,6 +70,7 @@ let collect_pipeline_diagnostics (pipeline : Analysis.Pipeline.result) = pipeline.typing.diagnostics @ pipeline.verify.diagnostics @ pipeline.semantic.diagnostics + @ pipeline.asserts.diagnostics @ pipeline.purity.diagnostics @ pipeline.ownership.diagnostics diff --git a/src/bin/lsp/document_store.ml b/src/bin/lsp/document_store.ml index d16d4b6..df791a8 100644 --- a/src/bin/lsp/document_store.ml +++ b/src/bin/lsp/document_store.ml @@ -85,11 +85,23 @@ let location_from_error_message ~filename message = in Option.bind location_text parse_line_col -let analyze_document uri text = +let get_text_by_path (store : t) path = + Hashtbl.fold + (fun _ (doc : document) acc -> + match acc with + | Some _ -> acc + | None -> + if String.equal (DocumentUri.to_path doc.uri) path then Some doc.text else None) + store None + +let analyze_document (store : t) uri text = let filename = DocumentUri.to_path uri in try let cst = Haven.Parser.parse_string ~filename text in - let pipeline = Analysis.Pipeline.run_cst cst in + let pipeline = + Analysis.Pipeline.run_cst ~import_text_resolver:(get_text_by_path store) + cst + in (Some cst, Some pipeline, None) with | Failure message @@ -111,12 +123,15 @@ let analyze_document uri text = in (None, None, Some parse_error) -let reanalyze doc = - let cst, pipeline, parse_error = analyze_document doc.uri doc.text in +let reanalyze store doc = + let cst, pipeline, parse_error = analyze_document store doc.uri doc.text in doc.cst <- cst; doc.pipeline <- pipeline; doc.parse_error <- parse_error +let reanalyze_all (store : t) = + Hashtbl.iter (fun _ doc -> reanalyze store doc) store + let open_doc (store : t) (td : TextDocumentItem.t) = let uri = td.uri in let doc = @@ -129,10 +144,12 @@ let open_doc (store : t) (td : TextDocumentItem.t) = parse_error = None; } in - reanalyze doc; - Hashtbl.replace store uri doc + Hashtbl.replace store uri doc; + reanalyze_all store -let close_doc (store : t) (uri : DocumentUri.t) = Hashtbl.remove store uri +let close_doc (store : t) (uri : DocumentUri.t) = + Hashtbl.remove store uri; + reanalyze_all store let line_offsets text = let offsets = ref [ 0 ] in @@ -178,7 +195,29 @@ let change_doc (store : t) (d : VersionedTextDocumentIdentifier.t) | Some doc -> doc.version <- Some d.version; List.iter (apply_change doc) evs; - reanalyze doc + reanalyze_all store + +let save_doc (store : t) (id : TextDocumentIdentifier.t) text = + match (Hashtbl.find_opt store id.uri, text) with + | Some doc, Some text -> + doc.text <- text; + reanalyze_all store + | Some _doc, None -> + reanalyze_all store + | None, Some text -> + let doc = + { + uri = id.uri; + version = None; + text; + cst = None; + pipeline = None; + parse_error = None; + } + in + Hashtbl.replace store id.uri doc; + reanalyze_all store + | None, None -> () let get_text (store : t) (uri : DocumentUri.t) : string option = Hashtbl.find_opt store uri |> Option.map (fun d -> d.text) diff --git a/src/bin/lsp/document_symbols.ml b/src/bin/lsp/document_symbols.ml index daad023..be8a401 100644 --- a/src/bin/lsp/document_symbols.ml +++ b/src/bin/lsp/document_symbols.ml @@ -24,12 +24,12 @@ let function_detail (fn : Cst.function_decl) = "fn"; ]) in - let return_type = + let return_suffix = match fn.value.return_type with - | Some ty -> type_text ty - | None -> "void" + | Some ty -> " -> " ^ type_text ty + | None -> "" in - Printf.sprintf "%s (%s) -> %s" prefix (String.concat ", " params) return_type + Printf.sprintf "%s (%s)%s" prefix (String.concat ", " params) return_suffix let variable_detail (decl : Cst.var_decl) = Printf.sprintf "%s%s" diff --git a/src/bin/lsp/haven_lsp.ml b/src/bin/lsp/haven_lsp.ml index 4d23129..9ac6641 100644 --- a/src/bin/lsp/haven_lsp.ml +++ b/src/bin/lsp/haven_lsp.ml @@ -23,6 +23,7 @@ let collect_pipeline_diagnostics (pipeline : Analysis.Pipeline.result) = pipeline.typing.diagnostics @ pipeline.verify.diagnostics @ pipeline.semantic.diagnostics + @ pipeline.asserts.diagnostics @ pipeline.purity.diagnostics @ pipeline.ownership.diagnostics @@ -59,12 +60,21 @@ let publish_diagnostics_params (state : state) (uri : DocumentUri.t) = (PublishDiagnosticsParams.create ~diagnostics:(diagnostics_for_doc doc) ~uri ?version:doc.version ()) +let publish_all_diagnostics_params (state : state) = + Hashtbl.fold + (fun uri _ acc -> + match publish_diagnostics_params state uri with + | None -> acc + | Some params -> params :: acc) + state.docs [] + let server_capabilities () : ServerCapabilities.t = ServerCapabilities.create ~textDocumentSync: (`TextDocumentSyncOptions (TextDocumentSyncOptions.create ~openClose:true - ~change:TextDocumentSyncKind.Incremental ())) + ~change:TextDocumentSyncKind.Incremental + ~save:(`SaveOptions (SaveOptions.create ~includeText:true ())) ())) ~definitionProvider:(`Bool true) ~documentHighlightProvider:(`Bool true) ~documentFormattingProvider:(`Bool true) @@ -98,6 +108,9 @@ let on_did_change (state : state) (doc : VersionedTextDocumentIdentifier.t) (evs : TextDocumentContentChangeEvent.t list) = Document_store.change_doc state.docs doc evs +let on_did_save (state : state) (params : DidSaveTextDocumentParams.t) = + Document_store.save_doc state.docs params.textDocument params.text + let with_doc state uri f = Option.bind (Document_store.get_doc state.docs uri) f @@ -187,19 +200,51 @@ let on_code_lenses (state : state) (uri : DocumentUri.t) = let on_execute_command (_state : state) command = if String.equal command Code_lenses.command_name then Some `Null else None +let parse_cst text uri = + let filename = DocumentUri.to_path uri in + try Some (Haven.Parser.parse_string ~filename text) with _ -> None + +let read_file_text path = + try + let ch = open_in_bin path in + Fun.protect + ~finally:(fun () -> close_in_noerr ch) + (fun () -> + let len = in_channel_length ch in + really_input_string ch len) + |> Option.some + with _ -> None + let format_document (state : state) (uri : DocumentUri.t) : TextEdit.t list option = - let full_range = - let start_pos = { Position.line = 0; character = 0 } in - let end_pos = { Position.line = max_int; character = 0 } in - { Range.start = start_pos; end_ = end_pos } - in - match Document_store.get_cst state.docs uri with - | None -> None - | Some cst -> - let newText = Haven.Cst.Emit.emit_program_to_string cst in - let edit = TextEdit.create ~range:full_range ~newText in - Some [ edit ] + match Document_store.get_doc state.docs uri with + | None -> + let path = DocumentUri.to_path uri in + Option.bind (read_file_text path) (fun text -> + Option.bind (parse_cst text uri) (fun cst -> + let newText = Haven.Cst.Emit.emit_program_to_string cst in + let edit = + TextEdit.create ~range:(Lsp_helpers.full_document_range text) + ~newText + in + Some [ edit ])) + | Some doc -> ( + match doc.cst with + | Some cst -> + let newText = Haven.Cst.Emit.emit_program_to_string cst in + let edit = + TextEdit.create ~range:(Lsp_helpers.full_document_range doc.text) + ~newText + in + Some [ edit ] + | None -> + Option.bind (parse_cst doc.text uri) (fun cst -> + let newText = Haven.Cst.Emit.emit_program_to_string cst in + let edit = + TextEdit.create ~range:(Lsp_helpers.full_document_range doc.text) + ~newText + in + Some [ edit ])) let on_formatting (state : state) (params : DocumentFormattingParams.t) : TextEdit.t list option = diff --git a/src/bin/lsp/hvlsp.ml b/src/bin/lsp/hvlsp.ml index 79e58e2..5e533a6 100644 --- a/src/bin/lsp/hvlsp.ml +++ b/src/bin/lsp/hvlsp.ml @@ -22,20 +22,30 @@ module Server = struct Logs.info (fun m -> m "didOpen: %s (version=%d)" (Lsp.Uri.to_string doc.uri) doc.version); Haven_lsp.on_did_open state doc; - (match Haven_lsp.publish_diagnostics_params state doc.uri with - | None -> Lwt.return_unit - | Some params -> + Lwt_list.iter_s + (fun params -> notify_back#send_notification (Lsp.Server_notification.PublishDiagnostics params)) + (Haven_lsp.publish_all_diagnostics_params state) method on_notif_doc_did_close ~notify_back (id : Lsp.Types.TextDocumentIdentifier.t) = Logs.info (fun m -> m "didClose: %s" (Lsp.Uri.to_string id.uri)); Haven_lsp.on_did_close state id; - notify_back#send_notification - (Lsp.Server_notification.PublishDiagnostics - (Lsp.Types.PublishDiagnosticsParams.create ~uri:id.uri ~diagnostics:[] - ())) + let clear_closed = + notify_back#send_notification + (Lsp.Server_notification.PublishDiagnostics + (Lsp.Types.PublishDiagnosticsParams.create ~uri:id.uri + ~diagnostics:[] ())) + in + let publish_open = + Lwt_list.iter_s + (fun params -> + notify_back#send_notification + (Lsp.Server_notification.PublishDiagnostics params)) + (Haven_lsp.publish_all_diagnostics_params state) + in + Lwt.bind clear_closed (fun () -> publish_open) method on_notif_doc_did_change ~notify_back (id : Lsp.Types.VersionedTextDocumentIdentifier.t) @@ -45,11 +55,22 @@ module Server = struct m "didChange: %s (version=%d, %d changes)" (Lsp.Uri.to_string id.uri) id.version (List.length changes)); Haven_lsp.on_did_change state id changes; - (match Haven_lsp.publish_diagnostics_params state id.uri with - | None -> Lwt.return_unit - | Some params -> + Lwt_list.iter_s + (fun params -> + notify_back#send_notification + (Lsp.Server_notification.PublishDiagnostics params)) + (Haven_lsp.publish_all_diagnostics_params state) + + method! on_notif_doc_did_save ~notify_back + (params : Lsp.Types.DidSaveTextDocumentParams.t) = + Logs.info (fun m -> + m "didSave: %s" (Lsp.Uri.to_string params.textDocument.uri)); + Haven_lsp.on_did_save state params; + Lwt_list.iter_s + (fun params -> notify_back#send_notification (Lsp.Server_notification.PublishDiagnostics params)) + (Haven_lsp.publish_all_diagnostics_params state) method! on_req_hover ~notify_back:_ ~id:_ ~uri ~pos ~workDoneToken (_doc_state : Linol_lwt.Jsonrpc2.doc_state) = diff --git a/src/bin/lsp/inlay_hints.ml b/src/bin/lsp/inlay_hints.ml index 643acad..100619c 100644 --- a/src/bin/lsp/inlay_hints.ml +++ b/src/bin/lsp/inlay_hints.ml @@ -84,6 +84,8 @@ and walk_statement typing type_env query_range acc (stmt : Core.statement) = match stmt.value with | Core.Expression expr -> walk_expression typing type_env query_range acc expr + | CompileAssert compile_assert -> + walk_expression typing type_env query_range acc compile_assert.value.cond | Let binding -> let acc = match (binding.value.ty, binding_annotation typing binding) with diff --git a/src/bin/lsp/lsp_helpers.ml b/src/bin/lsp/lsp_helpers.ml index 7ecfc7d..bb0898b 100644 --- a/src/bin/lsp/lsp_helpers.ml +++ b/src/bin/lsp/lsp_helpers.ml @@ -16,6 +16,16 @@ let loc_to_range (loc : Haven_core.Loc.t) = Range.create ~start:(position_of_lex_position loc.start_pos) ~end_:(position_of_lex_position loc.end_pos) +let end_position_of_text text = + let offsets = line_offsets text in + let last_line = max 0 (Array.length offsets - 1) in + let bol = offsets.(last_line) in + Position.create ~line:last_line ~character:(String.length text - bol) + +let full_document_range text = + Range.create ~start:(Position.create ~line:0 ~character:0) + ~end_:(end_position_of_text text) + let lex_position_of_lsp_position ~filename ~text (position : Position.t) = let offsets = line_offsets text in let max_line = max 0 (Array.length offsets - 1) in diff --git a/src/bin/lsp/semantic_tokens.ml b/src/bin/lsp/semantic_tokens.ml index 022ab74..a31c8f1 100644 --- a/src/bin/lsp/semantic_tokens.ml +++ b/src/bin/lsp/semantic_tokens.ml @@ -219,8 +219,12 @@ let collect_lexical_tokens (parsed : CST.parsed_program) = (loc_of_raw_tok { tok = entry.token; startp = entry.startp; endp = entry.endp }) | Lexer.Raw.Ident _ -> acc - | Lexer.Raw.Numeric_type _ | Vec_type _ | Mat_type _ | Float_type - | Void_type | Str_type -> + | Lexer.Raw.Directive Assert -> + add_token_for_loc acc ~token_type:"keyword" + (loc_of_raw_tok + { tok = entry.token; startp = entry.startp; endp = entry.endp }) + | Lexer.Raw.Numeric_type _ | Vec_type _ | Mat_type _ | Vec_hole_type + | Mat_hole_type | Float_type | Void_type | Str_type -> add_token_for_loc acc ~token_type:"type" (loc_of_raw_tok { tok = entry.token; startp = entry.startp; endp = entry.endp }) diff --git a/src/bin/lsp/symbol_resolution.ml b/src/bin/lsp/symbol_resolution.ml index 42d2a01..722608a 100644 --- a/src/bin/lsp/symbol_resolution.ml +++ b/src/bin/lsp/symbol_resolution.ml @@ -30,6 +30,7 @@ type state = { typing : Analysis.typing_result; type_env : Analysis.type_env; root_env : binding String_map.t; + function_decls : Core.function_decl String_map.t; type_decls : Core.type_decl String_map.t; mutable best : candidate option; } @@ -90,6 +91,11 @@ let expr_annotation typing (expr : Core.expression) = let binding_annotation typing (binding : Core.let_stmt) = Hashtbl.find_opt typing.Analysis.annotations.bindings (Analysis.binding_id binding) +let expr_resolved_type typing (expr : Core.expression) = + match expr_annotation typing expr with + | None -> None + | Some annotation -> annotation.resolved_type + let type_summary loc inferred resolved = match (inferred, resolved) with | Some inferred, Some resolved -> @@ -122,13 +128,37 @@ let function_signature (fn : Core.function_decl) = "fn"; ]) in - let return_type = + let return_suffix = match fn.value.return_type with - | Some ty -> format_core_type ty - | None -> "void" + | Some ty -> " -> " ^ format_core_type ty + | None -> "" in - Printf.sprintf "%s %s(%s) -> %s" prefix fn.value.name.value - (String.concat ", " params) return_type + Printf.sprintf "%s %s(%s)%s" prefix fn.value.name.value + (String.concat ", " params) return_suffix + +let function_signature_with_types ~name ~public ~impure ~vararg ~params + ~return_type = + let params = + let params = + List.map + (fun (param_name, ty) -> + Printf.sprintf "%s: %s" param_name (format_core_type ty)) + params + in + if vararg then params @ [ "..." ] else params + in + let prefix = + String.concat " " + (List.filter + (fun part -> not (String.equal part "")) + [ if public then "pub" else ""; if impure then "impure" else ""; "fn" ]) + in + let return_suffix = + match return_type with + | Some ty -> " -> " ^ format_core_type ty + | None -> "" + in + Printf.sprintf "%s %s(%s)%s" prefix name (String.concat ", " params) return_suffix let type_decl_summary (decl : Core.type_decl) = match decl.value.data with @@ -266,6 +296,21 @@ let type_decls (typing : Analysis.typing_result) = decls) String_map.empty typing.program.program.value.decls +let function_decls (typing : Analysis.typing_result) = + List.fold_left + (fun decls (decl : Core.top_decl) -> + match decl.value with + | Core.FDecl fn -> + String_map.add fn.value.name.value fn decls + | Core.Foreign foreign -> + List.fold_left + (fun decls (fn : Core.function_decl) -> + String_map.add fn.value.name.value fn decls) + decls foreign.value.decls + | Core.TDecl _ | Core.VDecl _ | Core.Import _ | Core.CImport _ -> + decls) + String_map.empty typing.program.program.value.decls + let maybe_pick_type_decl state id priority = Option.iter (fun decl -> @@ -287,7 +332,8 @@ let rec walk_type state (ty : Core.haven_type) = List.iter (walk_type state) templ.value.inner | Core.CustomType custom -> maybe_pick_type_decl state custom.name 30 - | NumericType _ | VecType _ | MatrixType _ | FloatType | VoidType | StringType -> + | NumericType _ | VecType _ | MatrixType _ | VecHoleType | MatrixHoleType + | FloatType | VoidType | StringType -> () and walk_match_pattern state (pattern : Core.match_pattern) = @@ -339,20 +385,74 @@ and bind_pattern state env scrutinee_resolved (pattern : Core.match_pattern) = env and enum_literal_hover state expr (enum_lit : Core.enum_literal) = - let expr_resolved_type = - match expr_annotation state.typing expr with - | None -> None - | Some annotation -> annotation.resolved_type - in maybe_pick_type_decl state enum_lit.value.enum_name 35; Option.iter (fun (variant, _inner_ty) -> maybe_pick_binding state enum_lit.value.enum_variant.loc 35 (make_binding (enum_variant_contents variant) variant.value.name.loc)) - (Option.bind expr_resolved_type (fun resolved -> + (Option.bind (expr_resolved_type state.typing expr) (fun resolved -> Analysis.lookup_enum_variant state.type_env expr.loc resolved enum_lit.value.enum_variant.value)) +and specialized_call_hover state (expr : Core.expression) (call : Core.call) = + match call.value.target.value with + | Core.Identifier id -> ( + match String_map.find_opt id.value state.function_decls with + | None -> () + | Some fn -> + if + (not (Analysis.function_has_specialization_param fn)) + || List.length fn.value.params.value.params <> List.length call.value.params + then () + else + let specialized_params = + List.map2 + (fun (param : Core.param) (arg : Core.expression) -> + if Analysis.type_has_specialization_hole param.value.ty then + Option.map + (fun resolved -> + ( param.value.name.value, + Analysis.core_type_of_resolved_ty arg.loc resolved )) + (expr_resolved_type state.typing arg) + else Some (param.value.name.value, param.value.ty)) + fn.value.params.value.params call.value.params + in + match + List.fold_right + (fun item acc -> + match (item, acc) with + | Some item, Some acc -> Some (item :: acc) + | _ -> None) + specialized_params (Some []) + with + | None -> () + | Some params -> + let specialized = + function_signature_with_types ~name:fn.value.name.value + ~public:fn.value.public ~impure:fn.value.impure + ~vararg:fn.value.vararg ~params + ~return_type: + (Option.map + (Analysis.core_type_of_resolved_ty expr.loc) + (expr_resolved_type state.typing expr)) + in + let original = function_signature fn in + let hover_text = + if String.equal specialized original then hover_block specialized + else + hover_block + (Printf.sprintf "%s\nspecialized from %s" specialized + original) + in + maybe_pick state + { + loc = id.loc; + priority = 45; + hover_text = Some hover_text; + definition_loc = Some fn.value.name.loc; + }) + | _ -> () + and walk_expression state env (expr : Core.expression) = Option.iter (fun contents -> maybe_pick state (hover_candidate expr.loc 10 contents)) @@ -385,17 +485,14 @@ and walk_expression state env (expr : Core.expression) = walk_type state ty | Match match_expr -> walk_expression state env match_expr.value.expr; - let scrutinee_resolved = - match expr_annotation state.typing match_expr.value.expr with - | Some annotation -> annotation.resolved_type - | None -> None - in + let scrutinee_resolved = expr_resolved_type state.typing match_expr.value.expr in List.iter (fun (arm : Core.match_arm) -> let arm_env = bind_pattern state env scrutinee_resolved arm.value.pattern in walk_expression state arm_env arm.value.expr) match_expr.value.arms | Call call -> + specialized_call_hover state expr call; walk_expression state env call.value.target; List.iter (walk_expression state env) call.value.params | Index index -> @@ -415,6 +512,9 @@ and walk_statement state env (stmt : Core.statement) = | Core.Expression expr -> walk_expression state env expr; env + | CompileAssert compile_assert -> + walk_expression state env compile_assert.value.cond; + env | Let binding -> let binding_info = make_binding (binding_contents state.typing binding) binding.value.name.loc @@ -514,6 +614,7 @@ let resolve_at (typing : Analysis.typing_result) position : resolution option = typing; type_env = Analysis.type_env_of_program typing.program.program; root_env = root_bindings typing; + function_decls = function_decls typing; type_decls = type_decls typing; best = None; } @@ -601,7 +702,8 @@ let rec highlight_type state (ty : Core.haven_type) = List.iter (highlight_type state) templ.value.inner | Core.CustomType custom -> maybe_add_type_decl_highlight state `Text custom.name - | NumericType _ | VecType _ | MatrixType _ | FloatType | VoidType | StringType -> + | NumericType _ | VecType _ | MatrixType _ | VecHoleType | MatrixHoleType + | FloatType | VoidType | StringType -> () let highlight_match_pattern state env scrutinee_resolved @@ -715,6 +817,9 @@ and highlight_statement state env (stmt : Core.statement) = | Core.Expression expr -> highlight_expression state env expr; env + | CompileAssert compile_assert -> + highlight_expression state env compile_assert.value.cond; + env | Let binding -> let binding_info = make_binding (binding_contents state.typing binding) binding.value.name.loc diff --git a/src/lib/ast/analysis.ml b/src/lib/ast/analysis.ml index 2a62b8c..e21d998 100644 --- a/src/lib/ast/analysis.ml +++ b/src/lib/ast/analysis.ml @@ -3,6 +3,8 @@ include Analysis_types module Typing = Analysis_typing.Typing module Verify = Analysis_verify.Verify module Semantic = Analysis_semantic.Semantic +module Assert = Analysis_asserts.Assert +module Specialize = Analysis_specialize.Specialize module ConstantFold = Analysis_cfold.ConstantFold module Purity = Analysis_purity.Purity module Cleanup = Analysis_cleanup.Cleanup @@ -14,24 +16,86 @@ module Pipeline = struct typing : typing_result; verify : verify_result; semantic : semantic_result; + asserts : semantic_result; purity : purity_result; ownership : ownership_result; cfold : Core.parsed_program; cleaned : Core.parsed_program; } - let run_core core = + let run_analyses ?(check_asserts = true) core = let typing = Typing.run core in - let verify : verify_result = Verify.run typing in - let semantic : semantic_result = Semantic.run typing in - let purity : purity_result = Purity.run typing in - let ownership : ownership_result = Ownership.run typing in - let cfold = ConstantFold.run typing in - let cleaned = Cleanup.run typing in - { core; typing; verify; semantic; purity; ownership; cfold; cleaned } + let assert_result = + if check_asserts then Assert.run typing + else ({ program = typing.program; diagnostics = [] } : Assert.result) + in + let asserted_typed = { typing with program = assert_result.program } in + let assert_failed = + List.exists + (fun (diagnostic : diagnostic) -> diagnostic.level = Error) + assert_result.diagnostics + in + let verify : verify_result = + if assert_failed then { diagnostics = [] } else Verify.run asserted_typed + in + let semantic : semantic_result = + if assert_failed then { diagnostics = [] } else Semantic.run asserted_typed + in + let purity : purity_result = + if assert_failed then { diagnostics = [] } else Purity.run asserted_typed + in + let ownership : ownership_result = + if assert_failed then + { actions = []; index = Ownership.make_index (); diagnostics = [] } + else Ownership.run asserted_typed + in + let cfold = + if assert_failed then assert_result.program + else ConstantFold.run ~program:assert_result.program asserted_typed + in + let cleaned = + if assert_failed then assert_result.program + else Cleanup.run ~program:cfold asserted_typed + in + { + core; + typing = asserted_typed; + verify; + semantic; + asserts = { diagnostics = assert_result.diagnostics }; + purity; + ownership; + cfold; + cleaned; + } + + let has_errors diagnostics = + List.exists (fun (diagnostic : diagnostic) -> diagnostic.level = Error) diagnostics + + let analysis_diagnostics result = + result.typing.diagnostics @ result.verify.diagnostics @ result.semantic.diagnostics + @ result.asserts.diagnostics @ result.purity.diagnostics @ result.ownership.diagnostics - let run_cst ?(search_dirs = []) ?sysroot parsed = - let expanded = Imports.expand_cst ~search_dirs ?sysroot parsed in + let run_core core = + let initial = run_analyses ~check_asserts:false core in + if has_errors (analysis_diagnostics initial) then initial + else + let specialized = Specialize.run initial.typing in + if has_errors specialized.diagnostics then + { + initial with + typing = + { + initial.typing with + diagnostics = initial.typing.diagnostics @ specialized.diagnostics; + }; + } + else run_analyses specialized.program + + let run_cst ?(search_dirs = []) ?sysroot ?import_text_resolver parsed = + let expanded = + Imports.expand_cst ~search_dirs ?sysroot ?import_text_resolver parsed + in let result = run_core (Convert.core_of_expanded_cst expanded.parsed) in let typing = { diff --git a/src/lib/ast/analysis_asserts.ml b/src/lib/ast/analysis_asserts.ml new file mode 100644 index 0000000..e09b443 --- /dev/null +++ b/src/lib/ast/analysis_asserts.ml @@ -0,0 +1,401 @@ +open Analysis_types + +module ConstantFold = Analysis_cfold.ConstantFold + +module Assert = struct + type result = { + program : Core.parsed_program; + diagnostics : diagnostic list; + } + + type state = { + typed : typing_result; + mutable diagnostics_rev : diagnostic list; + } + + let expr_annotation state (expr : Core.expression) = + Hashtbl.find_opt state.typed.annotations.exprs (expr_id expr) + + let add_diagnostic state loc message = + state.diagnostics_rev <- + { category = Semantic; level = Error; loc; message } :: state.diagnostics_rev + + let exact_integer_of_annotation (ann : expr_annotation) = + Option.bind ann.metavar.integer (fun integer -> integer.exact_value) + + let exact_integer state (expr : Core.expression) = + Option.bind (expr_annotation state expr) exact_integer_of_annotation + + let binary_op_string = function + | Core.Add -> "+" + | Core.Subtract -> "-" + | Core.Multiply -> "*" + | Core.Divide -> "/" + | Core.Modulo -> "%" + | Core.LeftShift -> "<<" + | Core.RightShift -> ">>" + | Core.IsEqual -> "==" + | Core.NotEqual -> "!=" + | Core.LessThan -> "<" + | Core.LessThanOrEqual -> "<=" + | Core.GreaterThan -> ">" + | Core.GreaterThanOrEqual -> ">=" + | Core.BitwiseAnd -> "&" + | Core.BitwiseXor -> "^" + | Core.BitwiseOr -> "|" + | Core.LogicAnd -> "&&" + | Core.LogicOr -> "||" + + let binary_precedence = function + | Core.LogicOr -> 1 + | Core.LogicAnd -> 2 + | Core.BitwiseOr -> 3 + | Core.BitwiseXor -> 4 + | Core.BitwiseAnd -> 5 + | Core.IsEqual | Core.NotEqual -> 6 + | Core.LessThan + | Core.LessThanOrEqual + | Core.GreaterThan + | Core.GreaterThanOrEqual -> + 7 + | Core.LeftShift | Core.RightShift -> 8 + | Core.Add | Core.Subtract -> 9 + | Core.Multiply | Core.Divide | Core.Modulo -> 10 + + let render_literal (literal : Core.literal) = + match literal.value with + | Core.Integer value -> string_of_int value + | Core.Bool value -> if value then "true" else "false" + | Core.Float value -> + let rendered = string_of_float value in + if String.contains rendered '.' then rendered else rendered ^ ".0" + | Core.String value -> Printf.sprintf "%S" value + | Core.Char value -> Printf.sprintf "%C" value + | Core.Vector _ -> "Vec<...>" + | Core.Matrix _ -> "Mat<...>" + | Core.Enum enum -> enum.value.enum_variant.value + + let rec render_expression ?(ctx_prec = 0) (expr : Core.expression) = + let self_prec = + match expr.value with + | Core.Binary binary -> binary_precedence binary.value.op + | Core.Unary _ -> 11 + | Core.Call _ | Core.Index _ | Core.Field _ -> 12 + | _ -> 13 + in + let rendered = + match expr.value with + | Core.Literal literal -> render_literal literal + | Core.Identifier id -> id.value + | Core.Binary binary -> + let prec = binary_precedence binary.value.op in + Printf.sprintf "%s %s %s" + (render_expression ~ctx_prec:prec binary.value.left) + (binary_op_string binary.value.op) + (render_expression ~ctx_prec:(prec + 1) binary.value.right) + | Core.Unary unary -> + let op = + match unary.value.op with + | Core.Not -> "!" + | Core.Negate -> "-" + | Core.Complement -> "~" + in + op ^ render_expression ~ctx_prec:11 unary.value.inner + | Core.ToBool inner -> + Printf.sprintf "bool(%s)" (render_expression inner) + | Core.Call call -> + Printf.sprintf "%s(%s)" + (render_expression ~ctx_prec:12 call.value.target) + (String.concat ", " (List.map render_expression call.value.params)) + | Core.Index index -> + Printf.sprintf "%s[%s]" + (render_expression ~ctx_prec:12 index.value.target) + (render_expression index.value.index) + | Core.Field field -> + Printf.sprintf "%s%s%s" + (render_expression ~ctx_prec:12 field.value.target) + (if field.value.arrow then "->" else ".") + field.value.field.value + | Core.As cast -> + Printf.sprintf "as<...>(%s)" (render_expression cast.value.inner) + | Core.SizeExpr inner -> + Printf.sprintf "size(%s)" (render_expression inner) + | Core.Block _ -> "{ ... }" + | Core.Initializer _ -> "{ ... }" + | Core.Match _ -> "match { ... }" + | Core.BoxExpr inner -> "box " ^ render_expression ~ctx_prec:11 inner + | Core.BoxType _ -> "box " + | Core.Unbox inner -> "unbox " ^ render_expression ~ctx_prec:11 inner + | Core.Ref inner -> "ref " ^ render_expression ~ctx_prec:11 inner + | Core.Load inner -> "load " ^ render_expression ~ctx_prec:11 inner + | Core.Assign write -> + Printf.sprintf "%s = %s" (render_expression write.value.target) + (render_expression write.value.value) + | Core.Mutate write -> + Printf.sprintf "%s := %s" (render_expression write.value.target) + (render_expression write.value.value) + | Core.SizeType _ -> "size" + | Core.Nil -> "nil" + in + if self_prec < ctx_prec then "(" ^ rendered ^ ")" else rendered + + let rec render_specialized_expression ?(ctx_prec = 0) state (expr : Core.expression) = + match expr.value with + | Core.Field _ -> ( + match exact_integer state expr with + | Some value -> string_of_int value + | None -> render_expression ~ctx_prec expr) + | Core.Binary binary -> + let prec = binary_precedence binary.value.op in + let rendered = + Printf.sprintf "%s %s %s" + (render_specialized_expression ~ctx_prec:prec state binary.value.left) + (binary_op_string binary.value.op) + (render_specialized_expression ~ctx_prec:(prec + 1) state binary.value.right) + in + if prec < ctx_prec then "(" ^ rendered ^ ")" else rendered + | Core.Unary unary -> + let op = + match unary.value.op with + | Core.Not -> "!" + | Core.Negate -> "-" + | Core.Complement -> "~" + in + let rendered = op ^ render_specialized_expression ~ctx_prec:11 state unary.value.inner in + if 11 < ctx_prec then "(" ^ rendered ^ ")" else rendered + | Core.Call call -> + Printf.sprintf "%s(%s)" + (render_specialized_expression ~ctx_prec:12 state call.value.target) + (String.concat ", " + (List.map + (fun expr -> render_specialized_expression state expr) + call.value.params)) + | Core.Index index -> + Printf.sprintf "%s[%s]" + (render_specialized_expression ~ctx_prec:12 state index.value.target) + (render_specialized_expression state index.value.index) + | Core.ToBool inner -> + Printf.sprintf "bool(%s)" (render_specialized_expression state inner) + | Core.As cast -> + Printf.sprintf "as<...>(%s)" + (render_specialized_expression state cast.value.inner) + | Core.SizeExpr inner -> + Printf.sprintf "size(%s)" (render_specialized_expression state inner) + | Core.BoxExpr inner -> + "box " ^ render_specialized_expression ~ctx_prec:11 state inner + | Core.Unbox inner -> + "unbox " ^ render_specialized_expression ~ctx_prec:11 state inner + | Core.Ref inner -> + "ref " ^ render_specialized_expression ~ctx_prec:11 state inner + | Core.Load inner -> + "load " ^ render_specialized_expression ~ctx_prec:11 state inner + | Core.Assign write -> + Printf.sprintf "%s = %s" + (render_specialized_expression state write.value.target) + (render_specialized_expression state write.value.value) + | Core.Mutate write -> + Printf.sprintf "%s := %s" + (render_specialized_expression state write.value.target) + (render_specialized_expression state write.value.value) + | (Core.Literal _ | Core.Identifier _ | Core.Block _ | Core.Initializer _ | Core.Match _ + | Core.BoxType _ | Core.SizeType _ | Core.Nil) -> + render_expression ~ctx_prec expr + + let assert_context (compile_assert : Core.compile_assert) suffix = + Printf.sprintf "compile-time assertion '%s' %s" + (render_expression compile_assert.value.cond) + suffix + + let assert_failure_message state (compile_assert : Core.compile_assert) message = + let message = + if String.equal message "" then "compile-time assertion failed" else message + in + let source = render_expression compile_assert.value.cond in + let specialized = render_specialized_expression state compile_assert.value.cond in + if String.equal source specialized then + Printf.sprintf "%s\n compile-time assertion failed: %s" message source + else + Printf.sprintf + "%s\n compile-time assertion failed: %s\n specialized as: %s" + message source specialized + + let constant_of_annotation (ann : expr_annotation) = + match ann.metavar.constant with + | Some constant -> Some constant + | None -> ( + match (ann.resolved_type, ann.metavar.integer) with + | Some (ResolvedInt _), Some { exact_value = Some value; _ } -> + Some (ConstantInt value) + | _ -> None) + + let rec constant_of_expr state (expr : Core.expression) = + match expr_annotation state expr with + | Some ann -> ( + match constant_of_annotation ann with + | Some constant -> Some constant + | None -> constant_of_expr_desc state expr) + | None -> constant_of_expr_desc state expr + + and constant_of_expr_desc state (expr : Core.expression) = + match expr.value with + | Core.Literal literal -> ConstantFold.constant_of_literal literal + | Core.Unary unary -> + Option.bind (constant_of_expr state unary.value.inner) (fun inner -> + ConstantFold.fold_unary unary.value.op inner) + | Core.Binary binary -> + Option.bind (constant_of_expr state binary.value.left) (fun left -> + Option.bind (constant_of_expr state binary.value.right) (fun right -> + ConstantFold.fold_binary binary.value.op left right)) + | Core.ToBool inner -> + Option.bind (constant_of_expr state inner) ConstantFold.truthy_of_constant + |> Option.map (fun value -> ConstantBool value) + | Core.Block block when block.value.statements = [] -> + Option.bind block.value.result (constant_of_expr state) + | _ -> None + + let assert_message (compile_assert : Core.compile_assert) fallback = + let message = compile_assert.value.message.value in + if String.equal message "" then assert_context compile_assert fallback + else Printf.sprintf "%s (%s)" message (assert_context compile_assert fallback) + + let eval_assert_condition state (compile_assert : Core.compile_assert) = + match constant_of_expr state compile_assert.value.cond with + | Some constant -> ( + match ConstantFold.truthy_of_constant constant with + | Some true -> true + | Some false -> + add_diagnostic state compile_assert.loc + (assert_failure_message state compile_assert + compile_assert.value.message.value); + false + | None -> + add_diagnostic state compile_assert.loc + (assert_message compile_assert + "must be a scalar constant"); + false) + | None -> + add_diagnostic state compile_assert.loc + (assert_message compile_assert "must be constant"); + false + + let rec rewrite_statement state (stmt : Core.statement) = + match stmt.value with + | Core.Expression _ + | Core.Return _ + | Core.Defer _ + | Core.Let _ + | Core.Break + | Core.Continue -> + [ stmt ] + | Core.CompileAssert compile_assert -> + ignore (eval_assert_condition state compile_assert); + [] + | Core.Loop loop -> + [ + { + stmt with + value = + Core.Loop + { + loop with + value = + { + loop.value with + init = List.concat_map (rewrite_statement state) loop.value.init; + body = rewrite_block state loop.value.body; + step = List.concat_map (rewrite_statement state) loop.value.step; + }; + }; + }; + ] + + and rewrite_statements state (statements : Core.statement list) = + match statements with + | [] -> [] + | stmt :: rest -> ( + match stmt.value with + | Core.CompileAssert compile_assert -> + if eval_assert_condition state compile_assert then rewrite_statements state rest + else [] + | _ -> + let stmt' = rewrite_statement state stmt in + stmt' @ rewrite_statements state rest) + + and rewrite_block state (block : Core.block) = + { + block with + value = + { + block.value with + statements = rewrite_statements state block.value.statements; + }; + } + + let rewrite_decl state (decl : Core.top_decl) = + let value = + match decl.value with + | Core.FDecl fn -> + Core.FDecl + { + fn with + value = + { + fn.value with + definition = Option.map (rewrite_block state) fn.value.definition; + }; + } + | Core.Foreign foreign -> + Core.Foreign + { + foreign with + value = + { + foreign.value with + decls = + List.map + (fun (fn : Core.function_decl) -> + { + fn with + value = + { + fn.value with + definition = Option.map (rewrite_block state) fn.value.definition; + }; + }) + foreign.value.decls; + }; + } + | Core.VDecl binding -> + Core.VDecl + { + binding with + value = + { + binding.value with + init_expr = Option.map (fun expr -> expr) binding.value.init_expr; + }; + } + | (Core.TDecl _ | Core.Import _ | Core.CImport _) as value -> value + in + { decl with value } + + let run (typed : typing_result) = + let state = { typed; diagnostics_rev = [] } in + let program = + { + Core.program = + { + typed.program.program with + value = + { + Core.decls = + List.map (rewrite_decl state) typed.program.program.value.decls; + }; + }; + } + in + { + program; + diagnostics = List.rev state.diagnostics_rev; + } +end diff --git a/src/lib/ast/analysis_cfold.ml b/src/lib/ast/analysis_cfold.ml index 200b399..312a489 100644 --- a/src/lib/ast/analysis_cfold.ml +++ b/src/lib/ast/analysis_cfold.ml @@ -323,6 +323,16 @@ module ConstantFold = struct let value = match stmt.value with | Core.Expression expr -> Core.Expression (fold_expression expr) + | Core.CompileAssert compile_assert -> + Core.CompileAssert + { + compile_assert with + value = + { + compile_assert.value with + cond = fold_expression compile_assert.value.cond; + }; + } | Core.Return expr -> Core.Return (Option.map fold_expression expr) | Core.Defer expr -> Core.Defer (fold_expression expr) | Core.Let binding -> @@ -385,12 +395,13 @@ module ConstantFold = struct | ({ loc; value = (Core.TDecl _ | Core.Import _ | Core.CImport _) as value } : Core.top_decl) -> ({ loc; value } : Core.top_decl) - let run (typed : typing_result) = + let run ?program (typed : typing_result) = + let program = Option.value ~default:typed.program program in { Core.program = { - loc = typed.program.program.loc; - value = { Core.decls = List.map fold_top_decl typed.program.program.value.decls }; + loc = program.program.loc; + value = { Core.decls = List.map fold_top_decl program.program.value.decls }; }; } end diff --git a/src/lib/ast/analysis_cleanup.ml b/src/lib/ast/analysis_cleanup.ml index b12b7c7..b0d50fb 100644 --- a/src/lib/ast/analysis_cleanup.ml +++ b/src/lib/ast/analysis_cleanup.ml @@ -132,6 +132,16 @@ module Cleanup = struct let value = match stmt.value with | Core.Expression expr -> Core.Expression (clean_expression typed expr) + | Core.CompileAssert compile_assert -> + Core.CompileAssert + { + compile_assert with + value = + { + compile_assert.value with + cond = clean_expression typed compile_assert.value.cond; + }; + } | Core.Return expr -> Core.Return (Option.map (clean_expression typed) expr) | Core.Defer expr -> Core.Defer (clean_expression typed expr) | Core.Let binding -> @@ -220,13 +230,14 @@ module Cleanup = struct in { decl with value } - let run typed = + let run ?program typed = + let program = Option.value ~default:typed.program program in { Core.program = { - typed.program.program with + program.program with value = - { Core.decls = List.map (clean_decl typed) typed.program.program.value.decls }; + { Core.decls = List.map (clean_decl typed) program.program.value.decls }; }; } end diff --git a/src/lib/ast/analysis_ownership.ml b/src/lib/ast/analysis_ownership.ml index f809aa5..24671c9 100644 --- a/src/lib/ast/analysis_ownership.ml +++ b/src/lib/ast/analysis_ownership.ml @@ -330,6 +330,9 @@ module Ownership = struct | Core.Expression expr -> visit_expression state scopes expr; scopes + | Core.CompileAssert compile_assert -> + visit_expression state scopes compile_assert.value.cond; + scopes | Core.Return expr -> Option.iter (visit_expression state scopes) expr; (match (return_expected, expr) with diff --git a/src/lib/ast/analysis_purity.ml b/src/lib/ast/analysis_purity.ml index 81ce65c..99479a5 100644 --- a/src/lib/ast/analysis_purity.ml +++ b/src/lib/ast/analysis_purity.ml @@ -184,6 +184,9 @@ module Purity = struct | Core.Expression expr -> visit_expression state current env expr; env + | Core.CompileAssert compile_assert -> + visit_expression state current env compile_assert.value.cond; + env | Core.Return expr -> Option.iter (visit_expression state current env) expr; env diff --git a/src/lib/ast/analysis_semantic.ml b/src/lib/ast/analysis_semantic.ml index 8ebbbb2..2c476f8 100644 --- a/src/lib/ast/analysis_semantic.ml +++ b/src/lib/ast/analysis_semantic.ml @@ -9,6 +9,7 @@ module Semantic = struct typed : typing_result; mutable diagnostics_rev : diagnostic list; type_env : type_env; + functions : Core.function_decl String_map.t; } let add_diagnostic_with_category state category level loc message = @@ -34,14 +35,19 @@ module Semantic = struct let expr_annotation state expr = Hashtbl.find_opt state.typed.annotations.exprs (expr_id expr) - let initial_scope typed = + let lookup_function state name = String_map.find_opt name state.functions + + let initial_scope (typed : typing_result) = + let type_env = type_env_of_program typed.program.program in let add_decl scope (decl : Core.top_decl) = match decl.value with | Core.FDecl fn -> String_map.add fn.value.name.value { inferred_type = Some (Typing.function_type_of_decl fn); - resolved_type = None; + resolved_type = + resolve_core_type type_env [] [] fn.loc + (Typing.function_type_of_decl fn); metavar = metavar_of_type (Typing.function_type_of_decl fn); is_mutable = false; } @@ -61,7 +67,9 @@ module Semantic = struct String_map.add fn.value.name.value { inferred_type = Some (Typing.function_type_of_decl fn); - resolved_type = None; + resolved_type = + resolve_core_type type_env [] [] fn.loc + (Typing.function_type_of_decl fn); metavar = metavar_of_type (Typing.function_type_of_decl fn); is_mutable = false; } @@ -71,6 +79,18 @@ module Semantic = struct in List.fold_left add_decl String_map.empty typed.program.program.value.decls + let collect_functions (program : Core.program) = + let add_fn map (fn : Core.function_decl) = + String_map.add fn.value.name.value fn map + in + List.fold_left + (fun map (decl : Core.top_decl) -> + match decl.value with + | Core.FDecl fn -> add_fn map fn + | Core.Foreign foreign -> List.fold_left add_fn map foreign.value.decls + | Core.VDecl _ | Core.TDecl _ | Core.Import _ | Core.CImport _ -> map) + String_map.empty program.value.decls + let duplicate_binding env name = match env with | [] -> false @@ -226,6 +246,9 @@ module Semantic = struct | Core.Expression expr -> check_expression_in_context state env loop_depth true expr; env + | Core.CompileAssert compile_assert -> + check_expression state env loop_depth compile_assert.value.cond; + env | Core.Return expr -> Option.iter (check_expression state env loop_depth) expr; (match (return_expected, expr) with @@ -523,7 +546,111 @@ module Semantic = struct | Core.Call call -> check_expression state env loop_depth call.value.target; List.iter (check_expression state env loop_depth) call.value.params; - (match expr_annotation state call.value.target with + (match (call.value.target.value, expr_annotation state call.value.target) with + | Core.Identifier id, _ -> ( + match lookup_function state id.value with + | Some fn_decl when function_has_specialization_param fn_decl -> + let actual_arity = List.length call.value.params in + let required_arity = List.length fn_decl.value.params.value.params in + if actual_arity <> required_arity then + add_diagnostic state Error call.loc + "call argument count does not match the function signature" + | _ -> ()) + | _ -> ()); + (match call.value.target.value with + | Core.Identifier id -> ( + match lookup_function state id.value with + | Some fn_decl when function_has_specialization_param fn_decl -> () + | _ -> ( + match expr_annotation state call.value.target with + | Some { inferred_type = Some ty; _ } -> ( + match ty.value with + | Core.FunctionType fn -> + let actual_arity = List.length call.value.params in + let required_arity = List.length fn.value.param_types in + if + actual_arity < required_arity + || ((not fn.value.vararg) && actual_arity > required_arity) + then + add_diagnostic state Error call.loc + "call argument count does not match the function signature"; + List.iter2 + (fun (arg : Core.expression) expected_ty -> + if arg.value = Core.Nil then + match resolve_core_type state.type_env [] [] arg.loc expected_ty with + | Some resolved when resolved_is_pointerish resolved -> () + | _ -> + add_diagnostic state Error arg.loc + "nil is only valid for pointer-like parameter types") + (List.filteri + (fun index _ -> index < List.length fn.value.param_types) + call.value.params) + (List.filteri + (fun index _ -> index < List.length call.value.params) + fn.value.param_types); + List.iter2 + (fun (arg : Core.expression) expected_ty -> + match + ( expr_annotation state arg, + resolve_core_type state.type_env [] [] arg.loc expected_ty ) + with + | Some { resolved_type = Some actual; _ }, Some expected + when not (resolved_compatible actual expected) + && arg.value <> Core.Nil -> + add_diagnostic state Error arg.loc + "call argument type does not match the function signature" + | _ -> ()) + (List.filteri + (fun index _ -> index < List.length fn.value.param_types) + call.value.params) + (List.filteri + (fun index _ -> index < List.length call.value.params) + fn.value.param_types) + | _ -> ()) + | Some { resolved_type = Some enum_ty; _ } -> ( + match call.value.target.value with + | Core.Literal literal -> ( + match literal.value with + | Core.Enum enum_lit -> ( + match + lookup_enum_variant state.type_env call.loc enum_ty + enum_lit.value.enum_variant.value + with + | Some (_, expected_payloads) -> + if List.length call.value.params <> List.length expected_payloads + then + add_diagnostic state Error call.loc + "enum constructor argument count does not match the variant"; + List.iter2 + (fun (arg : Core.expression) expected_ty -> + match + ( expr_annotation state arg, + expected_ty, + arg.value ) + with + | Some { resolved_type = Some actual; _ }, expected, _ + when not (resolved_compatible actual expected) -> + add_diagnostic state Error arg.loc + "enum constructor argument type does not match the variant" + | _, expected, Core.Nil -> + if not (resolved_is_pointerish expected) then + add_diagnostic state Error arg.loc + "nil is only valid for pointer-like enum payloads" + | _ -> ()) + (List.filteri + (fun index _ -> + index < List.length expected_payloads) + call.value.params) + (List.filteri + (fun index _ -> + index < List.length call.value.params) + expected_payloads) + | None -> ()) + | _ -> ()) + | _ -> ()) + | _ -> ())) + | _ -> + (match expr_annotation state call.value.target with | Some { inferred_type = Some ty; _ } -> ( match ty.value with | Core.FunctionType fn -> @@ -603,7 +730,7 @@ module Semantic = struct | None -> ()) | _ -> ()) | _ -> ()) - | _ -> ()) + | _ -> ())) | Core.Index index -> check_expression state env loop_depth index.value.target; check_expression state env loop_depth index.value.index @@ -706,7 +833,12 @@ module Semantic = struct let run typed : semantic_result = let state = - { typed; diagnostics_rev = []; type_env = type_env_of_program typed.program.program } + { + typed; + diagnostics_rev = []; + type_env = type_env_of_program typed.program.program; + functions = collect_functions typed.program.program; + } in let env = [ initial_scope typed ] in List.iter @@ -717,19 +849,20 @@ module Semantic = struct | None -> () | Some body -> let return_expected = - let core_ty = - Option.value ~default:(void_type fn.loc) fn.value.return_type - in - resolve_core_type state.type_env [] [] fn.loc core_ty + Option.bind fn.value.return_type + (resolve_core_type state.type_env [] [] fn.loc) in let env = push_scope env in let env = List.fold_left (fun env (param : Core.param) -> + let resolved_type = + resolve_core_type state.type_env [] [] param.loc param.value.ty + in bind_current env param.value.name.value { inferred_type = Some param.value.ty; - resolved_type = None; + resolved_type; metavar = metavar_of_type param.value.ty; is_mutable = false; }) @@ -750,19 +883,20 @@ module Semantic = struct | None -> () | Some body -> let return_expected = - let core_ty = - Option.value ~default:(void_type fn.loc) fn.value.return_type - in - resolve_core_type state.type_env [] [] fn.loc core_ty + Option.bind fn.value.return_type + (resolve_core_type state.type_env [] [] fn.loc) in let env = push_scope env in let env = List.fold_left (fun env (param : Core.param) -> + let resolved_type = + resolve_core_type state.type_env [] [] param.loc param.value.ty + in bind_current env param.value.name.value { inferred_type = Some param.value.ty; - resolved_type = None; + resolved_type; metavar = metavar_of_type param.value.ty; is_mutable = false; }) diff --git a/src/lib/ast/analysis_specialize.ml b/src/lib/ast/analysis_specialize.ml new file mode 100644 index 0000000..2456aff --- /dev/null +++ b/src/lib/ast/analysis_specialize.ml @@ -0,0 +1,551 @@ +open Analysis_types + +module Typing = Analysis_typing.Typing + +module Specialize = struct + type instance = { + key : string; + name : string; + template : Core.function_decl; + param_types : resolved_ty list; + return_type : resolved_ty; + } + + type result = { + program : Core.parsed_program; + diagnostics : diagnostic list; + } + + type state = { + typed : typing_result; + functions : Core.function_decl String_map.t; + instances : (string, instance) Hashtbl.t; + mutable queue_rev : instance list; + mutable diagnostics_rev : diagnostic list; + } + + let add_diagnostic state level loc message = + state.diagnostics_rev <- + { category = TypeCheck; level; loc; message } :: state.diagnostics_rev + + let expr_annotation annotations (expr : Core.expression) = + Hashtbl.find_opt annotations.exprs (expr_id expr) + + let resolved_expr_type annotations expr = + Option.bind (expr_annotation annotations expr) (fun ann -> ann.resolved_type) + + let exact_integer annotations expr = + Option.bind (expr_annotation annotations expr) (fun ann -> + Option.bind ann.metavar.integer (fun integer -> integer.exact_value)) + + let is_shape_property name = + String.equal name "dim" || String.equal name "rows" || String.equal name "cols" + + let sanitize_name raw = + let buf = Buffer.create (String.length raw) in + String.iter + (fun ch -> + match ch with + | 'a' .. 'z' | 'A' .. 'Z' | '0' .. '9' -> Buffer.add_char buf ch + | _ -> Buffer.add_char buf '_') + raw; + Buffer.contents buf + + let rec resolved_name = function + | ResolvedInt (Signed, bits) -> Printf.sprintf "i%d" bits + | ResolvedInt (Unsigned, bits) -> Printf.sprintf "u%d" bits + | ResolvedFloat -> "float" + | ResolvedString -> "str" + | ResolvedVoid -> "void" + | ResolvedPointer inner -> "ptr_" ^ resolved_name inner + | ResolvedBox inner -> "box_" ^ resolved_name inner + | ResolvedCell inner -> "cell_" ^ resolved_name inner + | ResolvedArray (inner, count) -> + Printf.sprintf "arr%d_%s" count (resolved_name inner) + | ResolvedVec vec -> Printf.sprintf "fvec%d" vec.dimension + | ResolvedMatrix mat -> Printf.sprintf "mat%dx%d" mat.rows mat.columns + | ResolvedFunction (params, ret, _vararg) -> + "fn_" ^ String.concat "_" (List.map resolved_name params) ^ "_to_" + ^ resolved_name ret + | ResolvedNamed (name, []) -> sanitize_name name + | ResolvedNamed (name, args) -> + sanitize_name name ^ "_" ^ String.concat "_" (List.map resolved_name args) + | ResolvedGenericParam name -> sanitize_name name + | ResolvedVecHole -> "fvec_hole" + | ResolvedMatrixHole -> "mat_hole" + + let instance_key (fn : Core.function_decl) param_types = + function_id fn ^ "::" ^ String.concat "::" (List.map resolved_name param_types) + + let instance_name (fn : Core.function_decl) param_types = + fn.value.name.value ^ "__spec__" ^ String.concat "__" (List.map resolved_name param_types) + + let collect_functions (program : Core.program) = + let add_fn map (fn : Core.function_decl) = + String_map.add fn.value.name.value fn map + in + List.fold_left + (fun map (decl : Core.top_decl) -> + match decl.value with + | Core.FDecl fn -> add_fn map fn + | Core.Foreign foreign -> List.fold_left add_fn map foreign.value.decls + | Core.VDecl _ | Core.TDecl _ | Core.Import _ | Core.CImport _ -> map) + String_map.empty program.value.decls + + let make_state typed = + { + typed; + functions = collect_functions typed.program.program; + instances = Hashtbl.create 32; + queue_rev = []; + diagnostics_rev = []; + } + + let canonical_param_type type_env (param : Core.param) arg_type = + if type_has_specialization_hole param.value.ty then arg_type + else + match resolve_core_type type_env [] [] param.loc param.value.ty with + | Some resolved -> resolved + | None -> arg_type + + let canonical_param_types type_env (fn : Core.function_decl) arg_types = + List.map2 (canonical_param_type type_env) fn.value.params.value.params arg_types + + let enqueue_instance state (fn : Core.function_decl) param_types return_type = + let key = instance_key fn param_types in + match Hashtbl.find_opt state.instances key with + | Some existing -> + if not (equal_resolved_type existing.return_type return_type) then + add_diagnostic state Error fn.loc + (Printf.sprintf + "specialization %s was inferred with incompatible return types" + fn.value.name.value); + existing.name + | None -> + let instance = + { + key; + name = instance_name fn param_types; + template = fn; + param_types; + return_type; + } + in + Hashtbl.add state.instances key instance; + state.queue_rev <- instance :: state.queue_rev; + instance.name + + let clone_identifier (id : Core.identifier) value = { id with value } + + let literal_int loc value = + mk_expr loc (Core.Literal (mk_literal loc (Core.Integer value))) + + let rec rewrite_expression state annotations (expr : Core.expression) : + Core.expression = + match expr.value with + | Core.Binary binary -> + { + expr with + value = + Core.Binary + { + binary with + value = + { + binary.value with + left = rewrite_expression state annotations binary.value.left; + right = rewrite_expression state annotations binary.value.right; + }; + }; + } + | Core.Unary unary -> + { + expr with + value = + Core.Unary + { + unary with + value = + { + unary.value with + inner = rewrite_expression state annotations unary.value.inner; + }; + }; + } + | Core.Block block -> + { expr with value = Core.Block (rewrite_block state annotations block) } + | Core.ToBool inner -> + { expr with value = Core.ToBool (rewrite_expression state annotations inner) } + | Core.Initializer init -> + { + expr with + value = + Core.Initializer + { + init with + value = + { + Core.exprs = + List.map (rewrite_expression state annotations) init.value.exprs; + }; + }; + } + | Core.As cast -> + { + expr with + value = + Core.As + { + cast with + value = + { + cast.value with + inner = rewrite_expression state annotations cast.value.inner; + }; + }; + } + | Core.SizeExpr inner -> + { expr with value = Core.SizeExpr (rewrite_expression state annotations inner) } + | Core.Match match_expr -> + { + expr with + value = + Core.Match + { + match_expr with + value = + { + Core.expr = + rewrite_expression state annotations match_expr.value.expr; + arms = + List.map + (fun (arm : Core.match_arm) -> + { + arm with + value = + { + arm.value with + expr = + rewrite_expression state annotations arm.value.expr; + }; + }) + match_expr.value.arms; + }; + }; + } + | Core.BoxExpr inner -> + { expr with value = Core.BoxExpr (rewrite_expression state annotations inner) } + | Core.Unbox inner -> + { expr with value = Core.Unbox (rewrite_expression state annotations inner) } + | Core.Ref inner -> + { expr with value = Core.Ref (rewrite_expression state annotations inner) } + | Core.Load inner -> + { expr with value = Core.Load (rewrite_expression state annotations inner) } + | Core.Call call -> + rewrite_call state annotations expr call + | Core.Index index -> + { + expr with + value = + Core.Index + { + index with + value = + { + Core.target = + rewrite_expression state annotations index.value.target; + index = rewrite_expression state annotations index.value.index; + }; + }; + } + | Core.Field field -> + let target = rewrite_expression state annotations field.value.target in + if is_shape_property field.value.field.value then + match exact_integer annotations expr with + | Some value -> literal_int expr.loc value + | None -> + add_diagnostic state Error expr.loc + (Printf.sprintf "could not resolve %s to a concrete compile-time value" + field.value.field.value); + { + expr with + value = Core.Field { field with value = { field.value with target } }; + } + else + { + expr with + value = Core.Field { field with value = { field.value with target } }; + } + | Core.Assign write -> + { + expr with + value = + Core.Assign + { + write with + value = + { + Core.target = + rewrite_expression state annotations write.value.target; + value = rewrite_expression state annotations write.value.value; + }; + }; + } + | Core.Mutate write -> + { + expr with + value = + Core.Mutate + { + write with + value = + { + Core.target = + rewrite_expression state annotations write.value.target; + value = rewrite_expression state annotations write.value.value; + }; + }; + } + | (Core.Identifier _ | Core.Literal _ | Core.SizeType _ | Core.Nil | Core.BoxType _) -> + expr + + and rewrite_call state annotations (expr : Core.expression) (call : Core.call) = + let target = rewrite_expression state annotations call.value.target in + let params = List.map (rewrite_expression state annotations) call.value.params in + match call.value.target.value with + | Core.Identifier id -> ( + match String_map.find_opt id.value state.functions with + | Some fn when function_has_specialization_param fn -> ( + let arg_types = List.map (resolved_expr_type annotations) call.value.params in + let return_type = resolved_expr_type annotations expr in + match + ( List.for_all Option.is_some arg_types, + return_type, + List.length arg_types = List.length fn.value.params.value.params ) + with + | true, Some return_type, true -> + let arg_types = List.map Option.get arg_types in + let param_types = + canonical_param_types + (type_env_of_program state.typed.program.program) + fn arg_types + in + let specialized_name = + enqueue_instance state fn param_types return_type + in + { + expr with + value = + Core.Call + { + call with + value = + { + Core.target = + mk_expr call.value.target.loc + (Core.Identifier + (clone_identifier id specialized_name)); + params; + }; + }; + } + | _ -> + add_diagnostic state Error expr.loc + (Printf.sprintf + "could not concretize specialization call to %s before lowering" + id.value); + { expr with value = Core.Call { call with value = { Core.target = target; params } } }) + | _ -> + { expr with value = Core.Call { call with value = { Core.target = target; params } } }) + | _ -> + { expr with value = Core.Call { call with value = { Core.target = target; params } } } + + and rewrite_statement state annotations (stmt : Core.statement) = + let value = + match stmt.value with + | Core.Expression expr -> + Core.Expression (rewrite_expression state annotations expr) + | Core.CompileAssert compile_assert -> + Core.CompileAssert + compile_assert + | Core.Return expr -> + Core.Return (Option.map (rewrite_expression state annotations) expr) + | Core.Defer expr -> + Core.Defer (rewrite_expression state annotations expr) + | Core.Let binding -> + Core.Let + { + binding with + value = + { + binding.value with + init_expr = + rewrite_expression state annotations binding.value.init_expr; + }; + } + | Core.Loop loop -> + Core.Loop + { + loop with + value = + { + loop.value with + init = List.map (rewrite_statement state annotations) loop.value.init; + cond = rewrite_expression state annotations loop.value.cond; + body = rewrite_block state annotations loop.value.body; + step = List.map (rewrite_statement state annotations) loop.value.step; + }; + } + | Core.Break | Core.Continue as value -> value + in + { stmt with value } + + and rewrite_block state annotations (block : Core.block) = + { + block with + value = + { + Core.statements = + List.map (rewrite_statement state annotations) block.value.statements; + result = Option.map (rewrite_expression state annotations) block.value.result; + }; + } + + let binding_of_arg_type loc arg_type = + let ty = core_type_of_resolved_ty loc arg_type in + { + inferred_type = Some ty; + resolved_type = Some arg_type; + metavar = metavar_of_type ty; + is_mutable = false; + } + + let specialize_param (param : Core.param) arg_type = + if type_has_specialization_hole param.value.ty then + { + param with + value = + { + param.value with + ty = core_type_of_resolved_ty param.value.ty.loc arg_type; + }; + } + else param + + let rewrite_function_with_annotations state annotations (fn : Core.function_decl) = + { + fn with + value = + { + fn.value with + definition = Option.map (rewrite_block state annotations) fn.value.definition; + }; + } + + let specialize_instance_decl state (inst : instance) = + let param_bindings = + List.map2 + (fun (param : Core.param) arg_type -> + binding_of_arg_type param.loc arg_type) + inst.template.value.params.value.params inst.param_types + in + let temp_typed, _body_result = + Typing.analyze_function_body state.typed.program + ~active_specializations:[ function_id inst.template ] + ~param_bindings inst.template + in + List.iter + (fun diagnostic -> state.diagnostics_rev <- diagnostic :: state.diagnostics_rev) + (List.rev temp_typed.diagnostics); + let params = + List.map2 specialize_param inst.template.value.params.value.params inst.param_types + in + { + inst.template with + value = + { + inst.template.value with + name = clone_identifier inst.template.value.name inst.name; + params = + { inst.template.value.params with value = { inst.template.value.params.value with params } }; + return_type = + Some (core_type_of_resolved_ty inst.template.loc inst.return_type); + definition = + Option.map + (rewrite_block state temp_typed.annotations) + inst.template.value.definition; + }; + } + + let rec drain_instances state acc = + match state.queue_rev with + | [] -> List.rev acc + | inst :: rest -> + state.queue_rev <- rest; + let fn = specialize_instance_decl state inst in + let decl : Core.top_decl = { Core.value = Core.FDecl fn; loc = fn.loc } in + drain_instances state (decl :: acc) + + let rewrite_decl state annotations (decl : Core.top_decl) = + match decl.value with + | Core.FDecl fn -> + if function_has_specialization_param fn then None + else + Some + { + decl with + value = Core.FDecl (rewrite_function_with_annotations state annotations fn); + } + | Core.Foreign foreign -> + let decls = + List.filter_map + (fun (fn : Core.function_decl) -> + if function_has_specialization_param fn then None + else Some (rewrite_function_with_annotations state annotations fn)) + foreign.value.decls + in + Some + { + decl with + value = Core.Foreign { foreign with value = { foreign.value with decls } }; + } + | Core.VDecl binding -> + Some + { + decl with + value = + Core.VDecl + { + binding with + value = + { + binding.value with + init_expr = + Option.map + (rewrite_expression state annotations) + binding.value.init_expr; + }; + }; + } + | Core.TDecl _ | Core.Import _ | Core.CImport _ -> Some decl + + let run (typed : typing_result) : result = + let state = make_state typed in + let base_decls = + List.filter_map + (rewrite_decl state typed.annotations) + typed.program.program.value.decls + in + let specialized_decls = drain_instances state [] in + { + program = + { + Core.program = + { + typed.program.program with + value = { Core.decls = base_decls @ specialized_decls }; + }; + }; + diagnostics = List.rev state.diagnostics_rev; + } +end diff --git a/src/lib/ast/analysis_types.ml b/src/lib/ast/analysis_types.ml index 5689b0e..0b40d68 100644 --- a/src/lib/ast/analysis_types.ml +++ b/src/lib/ast/analysis_types.ml @@ -50,6 +50,8 @@ type resolved_ty = | ResolvedArray of resolved_ty * int | ResolvedVec of vec_type | ResolvedMatrix of mat_type + | ResolvedVecHole + | ResolvedMatrixHole | ResolvedFunction of resolved_ty list * resolved_ty * bool | ResolvedNamed of string * resolved_ty list | ResolvedGenericParam of string @@ -201,6 +203,33 @@ let numeric_type loc signedness bits = let pointer_type loc inner = mk_type loc (Core.PointerType inner) let box_type loc inner = mk_type loc (Core.BoxType inner) +let rec type_has_specialization_hole (ty : Core.haven_type) = + match ty.value with + | Core.VecHoleType | Core.MatrixHoleType -> true + | Core.CellType inner + | Core.PointerType inner + | Core.BoxType inner -> + type_has_specialization_hole inner + | Core.ArrayType arr -> type_has_specialization_hole arr.value.element + | Core.FunctionType fn -> + type_has_specialization_hole fn.value.return_type + || List.exists type_has_specialization_hole fn.value.param_types + | Core.TemplatedType templ -> + List.exists type_has_specialization_hole templ.value.inner + | Core.NumericType _ + | Core.VecType _ + | Core.MatrixType _ + | Core.FloatType + | Core.VoidType + | Core.StringType + | Core.CustomType _ -> + false + +let function_has_specialization_param (fn : Core.function_decl) = + List.exists + (fun (param : Core.param) -> type_has_specialization_hole param.value.ty) + fn.value.params.value.params + let type_class_of_type (ty : Core.haven_type) = match ty.value with | Core.NumericType { signedness = Unsigned; bits = 1 } -> @@ -214,6 +243,8 @@ let type_class_of_type (ty : Core.haven_type) = | Core.ArrayType _ -> [ TypeClassArray ] | Core.VecType _ -> [ TypeClassVector ] | Core.MatrixType _ -> [ TypeClassMatrix ] + | Core.VecHoleType -> [ TypeClassVector ] + | Core.MatrixHoleType -> [ TypeClassMatrix ] | Core.FunctionType _ -> [ TypeClassFunction ] | Core.CustomType custom -> [ TypeClassCustom custom.name.value ] | Core.CellType _ -> [ TypeClassPointer ] @@ -274,6 +305,9 @@ let rec equal_type (left : Core.haven_type) (right : Core.haven_type) = a.signedness = b.signedness && a.bits = b.bits | Core.VecType a, Core.VecType b -> a = b | Core.MatrixType a, Core.MatrixType b -> a = b + | Core.VecHoleType, Core.VecHoleType + | Core.MatrixHoleType, Core.MatrixHoleType -> + true | Core.FloatType, Core.FloatType | Core.VoidType, Core.VoidType | Core.StringType, Core.StringType -> @@ -327,6 +361,9 @@ let rec equal_resolved_type left right = lc = rc && equal_resolved_type le re | ResolvedVec left, ResolvedVec right -> left = right | ResolvedMatrix left, ResolvedMatrix right -> left = right + | ResolvedVecHole, ResolvedVecHole + | ResolvedMatrixHole, ResolvedMatrixHole -> + true | ResolvedFunction (lp, lr, lv), ResolvedFunction (rp, rr, rv) -> lv = rv && equal_resolved_type lr rr && equal_list equal_resolved_type lp rp @@ -356,6 +393,8 @@ let rec core_type_of_resolved_ty loc = function }) | ResolvedVec vec -> mk_type loc (Core.VecType vec) | ResolvedMatrix mat -> mk_type loc (Core.MatrixType mat) + | ResolvedVecHole -> mk_type loc Core.VecHoleType + | ResolvedMatrixHole -> mk_type loc Core.MatrixHoleType | ResolvedFunction (params, ret, vararg) -> mk_type loc (Core.FunctionType @@ -388,8 +427,8 @@ let resolved_is_bool = function ResolvedInt (Unsigned, 1) -> true | _ -> false let resolved_is_numeric = function ResolvedInt _ | ResolvedFloat -> true | _ -> false -let resolved_is_vector = function ResolvedVec _ -> true | _ -> false -let resolved_is_matrix = function ResolvedMatrix _ -> true | _ -> false +let resolved_is_vector = function ResolvedVec _ | ResolvedVecHole -> true | _ -> false +let resolved_is_matrix = function ResolvedMatrix _ | ResolvedMatrixHole -> true | _ -> false let resolved_is_pointerish = function | ResolvedPointer _ | ResolvedBox _ | ResolvedCell _ | ResolvedString -> true @@ -412,25 +451,51 @@ let resolved_arithmetic_binary_result op left right = ResolvedVec right when left = right -> Some (ResolvedVec left) + | (Core.Add | Core.Subtract | Core.Multiply | Core.Divide | Core.Modulo), + ResolvedVecHole, + ResolvedVecHole -> + Some ResolvedVecHole + | (Core.Add | Core.Subtract), ResolvedVec concrete, ResolvedVecHole + | (Core.Add | Core.Subtract), ResolvedVecHole, ResolvedVec concrete -> + Some (ResolvedVec concrete) | (Core.Multiply | Core.Divide | Core.Modulo), ResolvedVec vec, ResolvedFloat | (Core.Multiply | Core.Divide | Core.Modulo), ResolvedFloat, ResolvedVec vec -> Some (ResolvedVec vec) + | (Core.Multiply | Core.Divide | Core.Modulo), ResolvedVecHole, ResolvedFloat + | (Core.Multiply | Core.Divide | Core.Modulo), ResolvedFloat, ResolvedVecHole -> + Some ResolvedVecHole | (Core.Add | Core.Subtract), ResolvedMatrix left, ResolvedMatrix right when left.rows = right.rows && left.columns = right.columns -> Some (ResolvedMatrix { kind = combine_matrix_kind left right; rows = left.rows; columns = left.columns }) + | (Core.Add | Core.Subtract), ResolvedMatrixHole, ResolvedMatrixHole -> + Some ResolvedMatrixHole + | (Core.Add | Core.Subtract), ResolvedMatrix concrete, ResolvedMatrixHole + | (Core.Add | Core.Subtract), ResolvedMatrixHole, ResolvedMatrix concrete -> + Some (ResolvedMatrix concrete) | Core.Multiply, ResolvedMatrix left, ResolvedMatrix right when left.columns = right.rows -> Some (ResolvedMatrix { kind = combine_matrix_kind left right; rows = left.rows; columns = right.columns }) + | Core.Multiply, ResolvedMatrixHole, ResolvedMatrixHole + | Core.Multiply, ResolvedMatrix _, ResolvedMatrixHole + | Core.Multiply, ResolvedMatrixHole, ResolvedMatrix _ -> + Some ResolvedMatrixHole | Core.Multiply, ResolvedMatrix mat, ResolvedFloat | Core.Multiply, ResolvedFloat, ResolvedMatrix mat -> Some (ResolvedMatrix mat) + | Core.Multiply, ResolvedMatrixHole, ResolvedFloat + | Core.Multiply, ResolvedFloat, ResolvedMatrixHole -> + Some ResolvedMatrixHole | Core.Multiply, ResolvedVec (vec : vec_type), ResolvedMatrix mat when vec.dimension = mat.rows -> Some (ResolvedVec { kind = vec.kind; dimension = mat.columns }) + | Core.Multiply, ResolvedVecHole, ResolvedMatrix _ + | Core.Multiply, ResolvedVec _, ResolvedMatrixHole + | Core.Multiply, ResolvedVecHole, ResolvedMatrixHole -> + Some ResolvedVecHole | _ -> None let rec resolved_compatible actual expected = @@ -546,6 +611,8 @@ and resolve_core_type type_env active subst loc (ty : Core.haven_type) = | Core.VoidType -> Some ResolvedVoid | Core.VecType vec -> Some (ResolvedVec vec) | Core.MatrixType mat -> Some (ResolvedMatrix mat) + | Core.VecHoleType -> Some ResolvedVecHole + | Core.MatrixHoleType -> Some ResolvedMatrixHole | Core.PointerType inner -> Option.map (fun inner -> ResolvedPointer inner) (resolve_core_type type_env active subst loc inner) @@ -679,6 +746,8 @@ let rec resolved_contains_box_ownership type_env active loc ty = | ResolvedVoid | ResolvedVec _ | ResolvedMatrix _ + | ResolvedVecHole + | ResolvedMatrixHole | ResolvedFunction _ | ResolvedGenericParam _ -> false diff --git a/src/lib/ast/analysis_typing.ml b/src/lib/ast/analysis_typing.ml index 70a3064..af3527b 100644 --- a/src/lib/ast/analysis_typing.ml +++ b/src/lib/ast/analysis_typing.ml @@ -8,6 +8,8 @@ module Typing = struct mutable diagnostics_rev : diagnostic list; mutable globals : binding_annotation String_map.t; type_env : type_env; + functions : Core.function_decl String_map.t; + mutable active_specializations : string list; } let add_diagnostic_with_category state category level loc message = @@ -86,9 +88,10 @@ module Typing = struct | Some enum_ty -> lookup_enum_variant state.type_env loc enum_ty variant_name | None -> None - let function_type_of_decl (fn : Core.function_decl) = - let return_type = Option.value ~default:(void_type fn.loc) fn.value.return_type in - mk_type fn.loc + let lookup_function state name = String_map.find_opt name state.functions + + let function_type_with_return loc (fn : Core.function_decl) return_type = + mk_type loc (Core.FunctionType { Core.value = @@ -99,22 +102,89 @@ module Typing = struct return_type; vararg = fn.value.vararg; }; - loc = fn.loc; + loc; }) + let function_type_of_decl (fn : Core.function_decl) = + let return_type = Option.value ~default:(void_type fn.loc) fn.value.return_type in + function_type_with_return fn.loc fn return_type + + let update_function_global state (fn : Core.function_decl) + (return_ann : expr_annotation) = + let fn_ty = + match return_ann.inferred_type with + | Some return_ty -> function_type_with_return fn.loc fn return_ty + | None -> function_type_of_decl fn + in + let param_resolved = + List.map + (fun (param : Core.param) -> + resolve_core_type state.type_env [] [] param.loc param.value.ty) + fn.value.params.value.params + in + let resolved_type = + let rec collect acc = function + | [] -> Some (List.rev acc) + | Some ty :: rest -> collect (ty :: acc) rest + | None :: _ -> None + in + match (collect [] param_resolved, return_ann.resolved_type) with + | Some params, Some ret -> Some (ResolvedFunction (params, ret, fn.value.vararg)) + | _ -> resolve_core_type state.type_env [] [] fn.loc fn_ty + in + state.globals <- + String_map.add fn.value.name.value + { + inferred_type = Some fn_ty; + resolved_type; + metavar = metavar_of_type fn_ty; + is_mutable = false; + } + state.globals + + let shape_property_annotation loc value = + let ty = numeric_type loc Unsigned 32 in + { + inferred_type = Some ty; + resolved_type = Some (ResolvedInt (Unsigned, 32)); + metavar = + metavar_of_type ~constant:(ConstantInt value) + ~integer: + { exact_value = Some value; minimum_bits = Some 32; signedness = Some Unsigned } + ty; + } + + let unknown_shape_property_annotation loc = + let ty = numeric_type loc Unsigned 32 in + { + inferred_type = Some ty; + resolved_type = Some (ResolvedInt (Unsigned, 32)); + metavar = metavar_of_type ty; + } + let collect_globals state (program : Core.program) = let add_global name binding = state.globals <- String_map.add name binding state.globals in let add_function (fn : Core.function_decl) = - let ty = function_type_of_decl fn in - add_global fn.value.name.value - { - inferred_type = Some ty; - resolved_type = resolve_core_type state.type_env [] [] fn.loc ty; - metavar = metavar_of_type ty; - is_mutable = false; - } + match fn.value.return_type with + | Some _ -> + let ty = function_type_of_decl fn in + add_global fn.value.name.value + { + inferred_type = Some ty; + resolved_type = resolve_core_type state.type_env [] [] fn.loc ty; + metavar = metavar_of_type ty; + is_mutable = false; + } + | None -> + add_global fn.value.name.value + { + inferred_type = None; + resolved_type = None; + metavar = unknown_metavar; + is_mutable = false; + } in List.iter (fun (decl : Core.top_decl) -> @@ -133,6 +203,18 @@ module Typing = struct | Core.TDecl _ | Core.Import _ | Core.CImport _ -> ()) program.value.decls + let collect_functions (program : Core.program) = + let add_fn map (fn : Core.function_decl) = + String_map.add fn.value.name.value fn map + in + List.fold_left + (fun map (decl : Core.top_decl) -> + match decl.value with + | Core.FDecl fn -> add_fn map fn + | Core.Foreign foreign -> List.fold_left add_fn map foreign.value.decls + | Core.VDecl _ | Core.TDecl _ | Core.Import _ | Core.CImport _ -> map) + String_map.empty program.value.decls + let rec infer_block state env ?(result_expected : resolved_ty option = None) ~return_expected (block : Core.block) : expr_annotation = @@ -163,6 +245,9 @@ module Typing = struct | Core.Expression expr -> ignore (infer_value_expression state env expr); env + | Core.CompileAssert compile_assert -> + ignore (infer_value_expression state env compile_assert.value.cond); + env | Core.Return expr -> Option.iter (fun (expr : Core.expression) -> @@ -404,28 +489,43 @@ module Typing = struct | None -> unknown_expr_annotation) | None -> unknown_expr_annotation) | Some (ResolvedVec vec) -> ( - match vector_field_index field.value.field.value with - | Some idx when idx < vec.dimension -> - let core_ty = float_type expr.loc in - { - inferred_type = Some core_ty; - resolved_type = Some ResolvedFloat; - metavar = metavar_of_type core_ty; - } - | _ -> unknown_expr_annotation) + match field.value.field.value with + | "dim" -> shape_property_annotation expr.loc vec.dimension + | _ -> ( + match vector_field_index field.value.field.value with + | Some idx when idx < vec.dimension -> + let core_ty = float_type expr.loc in + { + inferred_type = Some core_ty; + resolved_type = Some ResolvedFloat; + metavar = metavar_of_type core_ty; + } + | _ -> unknown_expr_annotation)) | Some (ResolvedMatrix mat) -> ( - match vector_field_index field.value.field.value with - | Some idx when idx < mat.rows -> - let vec_ty = - mk_type expr.loc - (Core.VecType { kind = FloatVec; dimension = mat.columns }) - in - { - inferred_type = Some vec_ty; - resolved_type = - Some (ResolvedVec { kind = FloatVec; dimension = mat.columns }); - metavar = metavar_of_type vec_ty; - } + match field.value.field.value with + | "rows" -> shape_property_annotation expr.loc mat.rows + | "cols" -> shape_property_annotation expr.loc mat.columns + | _ -> ( + match vector_field_index field.value.field.value with + | Some idx when idx < mat.rows -> + let vec_ty = + mk_type expr.loc + (Core.VecType { kind = FloatVec; dimension = mat.columns }) + in + { + inferred_type = Some vec_ty; + resolved_type = + Some (ResolvedVec { kind = FloatVec; dimension = mat.columns }); + metavar = metavar_of_type vec_ty; + } + | _ -> unknown_expr_annotation)) + | Some ResolvedVecHole -> ( + match field.value.field.value with + | "dim" -> unknown_shape_property_annotation expr.loc + | _ -> unknown_expr_annotation) + | Some ResolvedMatrixHole -> ( + match field.value.field.value with + | "rows" | "cols" -> unknown_shape_property_annotation expr.loc | _ -> unknown_expr_annotation) | _ -> unknown_expr_annotation) | Core.Assign write -> @@ -974,6 +1074,63 @@ module Typing = struct } | _ -> unknown_expr_annotation in + let infer_specialized_function_call (fn_decl : Core.function_decl) = + let expected_args = + List.map + (fun (param : Core.param) -> + if type_has_specialization_hole param.value.ty then None + else resolve_core_type state.type_env [] [] param.loc param.value.ty) + fn_decl.value.params.value.params + in + infer_args_with_expected expected_args; + if List.length call.value.params <> List.length fn_decl.value.params.value.params then + unknown_expr_annotation + else + let arg_anns = + List.map (infer_value_expression state env) call.value.params + in + if + List.exists + (fun (ann : expr_annotation) -> ann.resolved_type = None) + arg_anns + then + unknown_expr_annotation + else + let specialization_key = function_id fn_decl in + if List.mem specialization_key state.active_specializations then + unknown_expr_annotation + else + let temp_state = + { + state with + annotations = make_annotations (); + diagnostics_rev = []; + active_specializations = specialization_key :: state.active_specializations; + } + in + let local_env = push_scope [ state.globals ] in + let local_env = + List.fold_left2 + (fun local_env (param : Core.param) (arg_ann : expr_annotation) -> + bind_current local_env param.value.name.value + { + inferred_type = arg_ann.inferred_type; + resolved_type = arg_ann.resolved_type; + metavar = arg_ann.metavar; + is_mutable = false; + }) + local_env fn_decl.value.params.value.params arg_anns + in + let return_expected = + Option.bind fn_decl.value.return_type + (resolve_core_type state.type_env [] [] fn_decl.loc) + in + match fn_decl.value.definition with + | Some body -> + infer_block temp_state local_env ~result_expected:return_expected + ~return_expected body + | None -> unknown_expr_annotation + in match call.value.target.value with | Core.Identifier id -> ( match expected_enum_variant state call.loc expected_type id.value with @@ -982,24 +1139,28 @@ module Typing = struct | Some enum_ty -> infer_enum_constructor enum_ty payload_tys | None -> unknown_expr_annotation) | None -> - let target_ann = infer_expression state env call.value.target in - match target_ann.inferred_type with - | Some ty -> ( - match ty.value with - | Core.FunctionType fn -> - infer_args_with_expected - (List.map - (fun expected_ty -> - resolve_core_type state.type_env [] [] call.loc expected_ty) - fn.value.param_types); - { - inferred_type = Some fn.value.return_type; - resolved_type = - resolve_core_type state.type_env [] [] call.loc fn.value.return_type; - metavar = metavar_of_type fn.value.return_type; - } - | _ -> unknown_expr_annotation) - | None -> unknown_expr_annotation) + (match lookup_function state id.value with + | Some fn_decl when function_has_specialization_param fn_decl -> + infer_specialized_function_call fn_decl + | _ -> + let target_ann = infer_expression state env call.value.target in + match target_ann.inferred_type with + | Some ty -> ( + match ty.value with + | Core.FunctionType fn -> + infer_args_with_expected + (List.map + (fun expected_ty -> + resolve_core_type state.type_env [] [] call.loc expected_ty) + fn.value.param_types); + { + inferred_type = Some fn.value.return_type; + resolved_type = + resolve_core_type state.type_env [] [] call.loc fn.value.return_type; + metavar = metavar_of_type fn.value.return_type; + } + | _ -> unknown_expr_annotation) + | None -> unknown_expr_annotation)) | _ -> let target_ann = infer_expression state env call.value.target in match (call.value.target.value, target_ann.resolved_type) with @@ -1078,6 +1239,13 @@ module Typing = struct resolved_type = Some ResolvedFloat; metavar = metavar_of_type ty; } + | Core.VecHoleType -> + let ty = float_type index.loc in + { + inferred_type = Some ty; + resolved_type = Some ResolvedFloat; + metavar = metavar_of_type ty; + } | Core.MatrixType mat -> let ty = mk_type index.loc (Core.VecType { kind = FloatVec; dimension = mat.columns }) in { @@ -1085,6 +1253,9 @@ module Typing = struct resolved_type = Some (ResolvedVec { kind = FloatVec; dimension = mat.columns }); metavar = metavar_of_type ty; } + | Core.MatrixHoleType -> + let ty = mk_type index.loc Core.VecHoleType in + { inferred_type = Some ty; resolved_type = Some ResolvedVecHole; metavar = metavar_of_type ty } | _ -> unknown_expr_annotation) | None, Some (ResolvedArray (inner, _) | ResolvedPointer inner | ResolvedBox inner | ResolvedCell inner) -> let ty = core_type_of_resolved_ty index.loc inner in @@ -1099,6 +1270,12 @@ module Typing = struct resolved_type = Some (ResolvedVec { kind = FloatVec; dimension = mat.columns }); metavar = metavar_of_type ty; } + | None, Some ResolvedVecHole -> + let ty = float_type index.loc in + { inferred_type = Some ty; resolved_type = Some ResolvedFloat; metavar = metavar_of_type ty } + | None, Some ResolvedMatrixHole -> + let ty = mk_type index.loc Core.VecHoleType in + { inferred_type = Some ty; resolved_type = Some ResolvedVecHole; metavar = metavar_of_type ty } | None, _ -> unknown_expr_annotation and infer_write_like state env ~pointee_target (write : Core.write) : expr_annotation = @@ -1126,18 +1303,66 @@ module Typing = struct { inferred_type = Some ty; resolved_type = Some resolved_type; metavar = metavar_of_type ty } | None, None, None, None -> unknown_expr_annotation - let run (program : Core.parsed_program) = + let make_state (program : Core.parsed_program) = let type_env = type_env_of_program program.program in - let state = - { - annotations = make_annotations (); - diagnostics_rev = []; - globals = String_map.empty; - type_env; - } + { + annotations = make_annotations (); + diagnostics_rev = []; + globals = String_map.empty; + type_env; + functions = collect_functions program.program; + active_specializations = []; + } + + let globals_env state = [ state.globals ] + + let analyze_function_body_with_state state (program : Core.parsed_program) + ?(active_specializations = []) ?param_bindings (fn : Core.function_decl) = + state.active_specializations <- active_specializations; + collect_globals state program.program; + let env = push_scope (globals_env state) in + let return_expected = + Option.bind fn.value.return_type (resolve_core_type state.type_env [] [] fn.loc) in + let env = + match param_bindings with + | Some bindings -> + List.fold_left2 + (fun env (param : Core.param) binding -> + bind_current env param.value.name.value binding) + env fn.value.params.value.params bindings + | None -> + List.fold_left + (fun env (param : Core.param) -> + let binding = + binding_from_type ~is_mutable:false state.type_env param.value.ty + in + bind_current env param.value.name.value binding) + env fn.value.params.value.params + in + match fn.value.definition with + | Some body -> + infer_block state env ~result_expected:return_expected ~return_expected body + | None -> unknown_expr_annotation + + let analyze_function_body (program : Core.parsed_program) + ?(active_specializations = []) ?param_bindings (fn : Core.function_decl) = + let state = make_state program in + let body_result = + analyze_function_body_with_state state program ~active_specializations + ?param_bindings fn + in + ( { + program; + annotations = state.annotations; + diagnostics = List.rev state.diagnostics_rev; + }, + body_result ) + + let run (program : Core.parsed_program) = + let state = make_state program in collect_globals state program.program; - let env = [ state.globals ] in + let globals_env () = globals_env state in List.iter (fun (decl : Core.top_decl) -> match decl.value with @@ -1145,22 +1370,25 @@ module Typing = struct match fn.value.definition with | None -> () | Some body -> - let env = push_scope env in - let return_expected = - Option.bind fn.value.return_type (resolve_core_type state.type_env [] [] fn.loc) - in - let env = - List.fold_left - (fun env (param : Core.param) -> - let binding = + let env = push_scope (globals_env ()) in + let return_expected = + Option.bind fn.value.return_type + (resolve_core_type state.type_env [] [] fn.loc) + in + let env = + List.fold_left + (fun env (param : Core.param) -> + let binding = binding_from_type ~is_mutable:false state.type_env param.value.ty in bind_current env param.value.name.value binding) env fn.value.params.value.params in - ignore - (infer_block state env ~result_expected:return_expected ~return_expected - body)) + let result_ann = + infer_block state env ~result_expected:return_expected ~return_expected body + in + if function_has_specialization_param fn then + update_function_global state fn result_ann) | Core.Foreign foreign -> List.iter (fun (fn : Core.function_decl) -> @@ -1171,7 +1399,7 @@ module Typing = struct Option.bind fn.value.return_type (resolve_core_type state.type_env [] [] fn.loc) in - let env = push_scope env in + let env = push_scope (globals_env ()) in let env = List.fold_left (fun env (param : Core.param) -> @@ -1181,9 +1409,12 @@ module Typing = struct bind_current env param.value.name.value binding) env fn.value.params.value.params in - ignore - (infer_block state env ~result_expected:return_expected - ~return_expected body)) + let result_ann = + infer_block state env ~result_expected:return_expected + ~return_expected body + in + if function_has_specialization_param fn then + update_function_global state fn result_ann) foreign.value.decls | Core.VDecl binding -> Option.iter @@ -1191,7 +1422,7 @@ module Typing = struct let expected_type = resolve_core_type state.type_env [] [] binding.loc binding.value.ty in - ignore (infer_value_expression state env ~expected_type init); + ignore (infer_value_expression state (globals_env ()) ~expected_type init); ignore (record_binding state { diff --git a/src/lib/ast/analysis_verify.ml b/src/lib/ast/analysis_verify.ml index f7713be..f360cd4 100644 --- a/src/lib/ast/analysis_verify.ml +++ b/src/lib/ast/analysis_verify.ml @@ -57,6 +57,7 @@ module Verify = struct and verify_statement state (stmt : Core.statement) = match stmt.value with | Core.Expression expr -> verify_expression state expr + | Core.CompileAssert compile_assert -> verify_expression state compile_assert.value.cond | Core.Return expr -> Option.iter (verify_expression state) expr | Core.Defer expr -> verify_expression state expr | Core.Let binding -> @@ -125,12 +126,12 @@ module Verify = struct verify_expr_annotation state "expression" expr let verify_function state (fn : Core.function_decl) = - let declared_return = - Option.value ~default:(void_type fn.loc) fn.value.return_type - in - verify_declared_type state fn.loc - (Printf.sprintf "function %s return type" fn.value.name.value) - declared_return; + Option.iter + (fun declared_return -> + verify_declared_type state fn.loc + (Printf.sprintf "function %s return type" fn.value.name.value) + declared_return) + fn.value.return_type; List.iter (fun (param : Core.param) -> verify_declared_type state param.loc diff --git a/src/lib/ast/convert.ml b/src/lib/ast/convert.ml index b6710c3..ace75e8 100644 --- a/src/lib/ast/convert.ml +++ b/src/lib/ast/convert.ml @@ -1,6 +1,7 @@ module Cst = Haven_cst.Cst module Surface = Surface_ast module Core = Core_ast +open Haven_core type block_context = [ `Statement | `Value ] @@ -48,6 +49,47 @@ let default_iter_type loc = mk_core_type loc (Core.NumericType { Haven_token.Token.signedness = Haven_token.Token.Signed; bits = 32 }) +let string_of_loc (loc : Loc.t) = + let pos = loc.start_pos in + let col = pos.pos_cnum - pos.pos_bol + 1 in + if pos.pos_fname = "" then Printf.sprintf "%d:%d" pos.pos_lnum col + else Printf.sprintf "%s:%d:%d" pos.pos_fname pos.pos_lnum col + +let rec surface_type_has_specialization_hole (ty : Surface.haven_type) = + match ty.value with + | Surface.VecHoleType | Surface.MatrixHoleType -> true + | Surface.CellType inner + | Surface.PointerType inner + | Surface.BoxType inner -> + surface_type_has_specialization_hole inner + | Surface.ArrayType arr -> surface_type_has_specialization_hole arr.value.element + | Surface.FunctionType fn -> + surface_type_has_specialization_hole fn.value.return_type + || List.exists surface_type_has_specialization_hole fn.value.param_types + | Surface.TemplatedType templ -> + List.exists surface_type_has_specialization_hole templ.value.inner + | Surface.NumericType _ + | Surface.VecType _ + | Surface.MatrixType _ + | Surface.FloatType + | Surface.VoidType + | Surface.StringType + | Surface.CustomType _ -> + false + +let function_has_specialization_param (fn : Surface.function_decl) = + List.exists + (fun (param : Surface.param) -> + surface_type_has_specialization_hole param.value.ty) + fn.value.params.value.params + +let validate_surface_function_decl (fn : Surface.function_decl) = + if fn.value.return_type = None && not (function_has_specialization_param fn) then + failwith + (Printf.sprintf + "function %s omits its return type, but only specialization functions may infer returns (%s)" + fn.value.name.value (string_of_loc fn.loc)) + let rec cst_program_to_surface (program : Cst.program) : Surface.program = let decls = List.map cst_top_decl_to_surface program.value.decls in mk_surface program.loc { Surface.decls } @@ -77,7 +119,9 @@ and cst_function_decl_to_surface (fn : Cst.function_decl) : Surface.function_dec vararg = fn.value.vararg; } in - mk_surface fn.loc value + let decl = mk_surface fn.loc value in + validate_surface_function_decl decl; + decl and cst_intrinsic_to_surface (intr : Cst.intrinsic) : Surface.intrinsic = let value = @@ -215,6 +259,18 @@ and cst_statement_to_surface (stmt : Cst.statement) : Surface.statement option = name = cst_identifier_to_surface binding.value.name; init_expr = cst_expr_to_surface binding.value.init_expr; })) + | Cst.CompileAssert compile_assert -> + Some + (Surface.CompileAssert + (mk_surface compile_assert.loc + { + Surface.cond = cst_expr_to_surface compile_assert.value.cond; + message = + { + value = compile_assert.value.message.value; + loc = compile_assert.value.message.loc; + }; + })) | Cst.Return expr -> Some (Surface.Return (Option.map cst_expr_to_surface expr)) | Cst.Defer expr -> Some (Surface.Defer (cst_expr_to_surface expr)) | Cst.Iter iter -> @@ -362,6 +418,8 @@ and cst_type_to_surface (ty : Cst.haven_type) : Surface.haven_type = | Cst.NumericType n -> Surface.NumericType n | Cst.VecType v -> Surface.VecType v | Cst.MatrixType m -> Surface.MatrixType m + | Cst.VecHoleType -> Surface.VecHoleType + | Cst.MatrixHoleType -> Surface.MatrixHoleType | Cst.FloatType -> Surface.FloatType | Cst.VoidType -> Surface.VoidType | Cst.StringType -> Surface.StringType @@ -522,6 +580,8 @@ let rec surface_type_to_core (ty : Surface.haven_type) : Core.haven_type = | Surface.NumericType n -> Core.NumericType n | Surface.VecType v -> Core.VecType v | Surface.MatrixType m -> Core.MatrixType m + | Surface.VecHoleType -> Core.VecHoleType + | Surface.MatrixHoleType -> Core.MatrixHoleType | Surface.FloatType -> Core.FloatType | Surface.VoidType -> Core.VoidType | Surface.StringType -> Core.StringType @@ -822,6 +882,20 @@ and surface_statement_to_core st (stmt : Surface.statement) : Core.statement lis init_expr = surface_expr_to_core st binding.value.init_expr; })); ] + | Surface.CompileAssert compile_assert -> + [ + mk_core_stmt stmt.loc + (Core.CompileAssert + (mk_core compile_assert.loc + { + Core.cond = surface_expr_to_core st compile_assert.value.cond; + message = + { + value = compile_assert.value.message.value; + loc = compile_assert.value.message.loc; + }; + })); + ] | Surface.Return expr -> [ mk_core_stmt stmt.loc (Core.Return (Option.map (surface_expr_to_core st) expr)) ] | Surface.Defer expr -> diff --git a/src/lib/ast/core_ast.ml b/src/lib/ast/core_ast.ml index b5ca4f9..54cb581 100644 --- a/src/lib/ast/core_ast.ml +++ b/src/lib/ast/core_ast.ml @@ -43,6 +43,8 @@ and haven_type_desc = | NumericType of numeric_type | VecType of vec_type | MatrixType of mat_type + | VecHoleType + | MatrixHoleType | FloatType | VoidType | StringType @@ -119,9 +121,13 @@ and foreign = foreign_desc node and block_desc = { statements : statement list; result : expression option } and block = block_desc node +and compile_assert_desc = { cond : expression; message : string node } +and compile_assert = compile_assert_desc node + and statement_desc = | Expression of expression | Let of let_stmt + | CompileAssert of compile_assert | Return of expression option | Defer of expression | Loop of loop_stmt diff --git a/src/lib/ast/dune b/src/lib/ast/dune index 183ad5c..3c6594f 100644 --- a/src/lib/ast/dune +++ b/src/lib/ast/dune @@ -1,8 +1,8 @@ (library (name haven_ast) (public_name haven.ast) - (modules analysis_cfold analysis analysis_cleanup analysis_ownership - analysis_purity analysis_semantic analysis_types analysis_typing + (modules analysis_asserts analysis_cfold analysis analysis_cleanup analysis_ownership + analysis_purity analysis_semantic analysis_specialize analysis_types analysis_typing analysis_verify ast cimport convert core_ast imports llvm_ir platform_defaults platform_defaults_common platform_defaults_darwin platform_defaults_unix pretty surface_ast) diff --git a/src/lib/ast/imports.ml b/src/lib/ast/imports.ml index 8cdfc7d..6685517 100644 --- a/src/lib/ast/imports.ml +++ b/src/lib/ast/imports.ml @@ -8,15 +8,17 @@ type state = { active : (string, unit) Hashtbl.t; search_dirs : string list; sysroot : string option; + import_text_resolver : (string -> string option) option; mutable diagnostics_rev : diagnostic list; } -let create_state ?(search_dirs = []) ?sysroot () = +let create_state ?(search_dirs = []) ?sysroot ?import_text_resolver () = { seen = Hashtbl.create 32; active = Hashtbl.create 32; search_dirs; sysroot; + import_text_resolver; diagnostics_rev = []; } @@ -110,7 +112,14 @@ and expand_import state ~current_file import_path loc = ~finally:(fun () -> Hashtbl.remove state.active key) (fun () -> try - let imported = Parser.parse_file resolved in + let imported = + match state.import_text_resolver with + | Some resolve_text -> ( + match resolve_text resolved with + | Some text -> Parser.parse_string ~filename:resolved text + | None -> Parser.parse_file resolved) + | None -> Parser.parse_file resolved + in let expanded = expand_program state imported in Hashtbl.add state.seen key (); expanded.program.value.decls @@ -129,8 +138,11 @@ and expand_cimport state ~current_file import_path loc = List.rev_append (List.rev expanded.diagnostics) state.diagnostics_rev; expanded.decls -let expand_cst ?(search_dirs = []) ?sysroot parsed = +let expand_cst ?(search_dirs = []) ?sysroot ?import_text_resolver parsed = let defaults = Platform_defaults.resolve ~search_dirs ?sysroot () in - let state = create_state ~search_dirs:defaults.search_dirs ?sysroot:defaults.sysroot () in + let state = + create_state ~search_dirs:defaults.search_dirs ?sysroot:defaults.sysroot + ?import_text_resolver () + in let parsed = expand_program state parsed in { parsed; diagnostics = List.rev state.diagnostics_rev } diff --git a/src/lib/ast/llvm_ir.ml b/src/lib/ast/llvm_ir.ml index 6f5086d..d2d7c1b 100644 --- a/src/lib/ast/llvm_ir.ml +++ b/src/lib/ast/llvm_ir.ml @@ -133,6 +133,8 @@ let rec mangle_resolved_ty = function vec.Haven_token.Token.dimension | ResolvedMatrix mat -> Printf.sprintf "mat.%d.%d" mat.rows mat.columns + | ResolvedVecHole -> "vec.hole" + | ResolvedMatrixHole -> "mat.hole" | ResolvedFunction (params, ret, vararg) -> String.concat "." ([ "fn" ] @@ -278,6 +280,10 @@ let rec llvm_type_of_resolved t ?loc = function Llvm.array_type (llvm_type_of_resolved t ?loc inner) count | ResolvedVec vec -> llvm_vector_type t vec.Haven_token.Token.dimension | ResolvedMatrix mat -> llvm_matrix_flat_type t mat + | ResolvedVecHole -> + fail ?loc "specialization vector hole reached LLVM lowering" + | ResolvedMatrixHole -> + fail ?loc "specialization matrix hole reached LLVM lowering" | (ResolvedNamed _ as resolved) -> llvm_named_type t ?loc resolved | ResolvedGenericParam name -> fail ?loc "unresolved generic parameter %s reached LLVM lowering" name @@ -819,7 +825,10 @@ let rec emit_ownership_on_storage t kind resolved storage = | Analysis.ResolvedFunction _ | Analysis.ResolvedGenericParam _ -> () - | Analysis.ResolvedVec _ | Analysis.ResolvedMatrix _ -> + | Analysis.ResolvedVec _ + | Analysis.ResolvedMatrix _ + | Analysis.ResolvedVecHole + | Analysis.ResolvedMatrixHole -> () and emit_ownership_on_named t kind resolved storage = @@ -1869,6 +1878,8 @@ and emit_statement t (stmt : Core.statement) = match stmt.value with | Core.Expression expr -> ignore (emit_expr t expr) + | Core.CompileAssert _ -> + fail ~loc:stmt.loc "compile-time assert reached LLVM lowering" | Core.Let binding -> let resolved = binding_resolved_type t binding in let slot = diff --git a/src/lib/ast/pretty.ml b/src/lib/ast/pretty.ml index abbe422..7013456 100644 --- a/src/lib/ast/pretty.ml +++ b/src/lib/ast/pretty.ml @@ -66,6 +66,8 @@ let rec pp_surface_type fmt (ty : Surface.haven_type) = | Surface.NumericType n -> fprintf fmt "%s" (numeric_type_to_string n) | VecType v -> fprintf fmt "%s" (vec_type_to_string v) | MatrixType m -> fprintf fmt "%s" (mat_type_to_string m) + | VecHoleType -> fprintf fmt "fvec?" + | MatrixHoleType -> fprintf fmt "mat?" | FloatType -> fprintf fmt "float" | VoidType -> fprintf fmt "void" | StringType -> fprintf fmt "str" @@ -189,6 +191,9 @@ and pp_surface_statement fmt (stmt : Surface.statement) = binding.value.mut pp_surface_identifier binding.value.name (pp_print_option pp_surface_type) binding.value.ty pp_surface_expression binding.value.init_expr + | CompileAssert compile_assert -> + fprintf fmt "CompileAssert(cond=%a, message=%S)" pp_surface_expression + compile_assert.value.cond compile_assert.value.message.value | Return expr -> fprintf fmt "Return(%a)" (pp_print_option pp_surface_expression) expr | Defer expr -> fprintf fmt "Defer(%a)" pp_surface_expression expr @@ -294,6 +299,8 @@ let rec pp_core_type fmt (ty : Core.haven_type) = | Core.NumericType n -> fprintf fmt "%s" (numeric_type_to_string n) | VecType v -> fprintf fmt "%s" (vec_type_to_string v) | MatrixType m -> fprintf fmt "%s" (mat_type_to_string m) + | VecHoleType -> fprintf fmt "fvec?" + | MatrixHoleType -> fprintf fmt "mat?" | FloatType -> fprintf fmt "float" | VoidType -> fprintf fmt "void" | StringType -> fprintf fmt "str" @@ -419,6 +426,9 @@ and pp_core_statement fmt (stmt : Core.statement) = binding.value.mut pp_core_identifier binding.value.name (pp_print_option pp_core_type) binding.value.ty pp_core_expression binding.value.init_expr + | CompileAssert compile_assert -> + fprintf fmt "CompileAssert(cond=%a, message=%S)" pp_core_expression + compile_assert.value.cond compile_assert.value.message.value | Return expr -> fprintf fmt "Return(%a)" (pp_print_option pp_core_expression) expr | Defer expr -> fprintf fmt "Defer(%a)" pp_core_expression expr | Loop loop -> diff --git a/src/lib/ast/surface_ast.ml b/src/lib/ast/surface_ast.ml index 35002d1..7a70631 100644 --- a/src/lib/ast/surface_ast.ml +++ b/src/lib/ast/surface_ast.ml @@ -45,6 +45,8 @@ and haven_type_desc = | NumericType of numeric_type | VecType of vec_type | MatrixType of mat_type + | VecHoleType + | MatrixHoleType | FloatType | VoidType | StringType @@ -121,9 +123,13 @@ and foreign = foreign_desc node and block_desc = { statements : statement list; result : expression option } and block = block_desc node +and compile_assert_desc = { cond : expression; message : string node } +and compile_assert = compile_assert_desc node + and statement_desc = | Expression of expression | Let of let_stmt + | CompileAssert of compile_assert | Return of expression option | Defer of expression | Iter of iter_stmt diff --git a/src/lib/cst/cst.ml b/src/lib/cst/cst.ml index 34c16c2..8947085 100644 --- a/src/lib/cst/cst.ml +++ b/src/lib/cst/cst.ml @@ -99,6 +99,8 @@ and haven_type_desc = | NumericType of numeric_type | VecType of vec_type | MatrixType of mat_type + | VecHoleType + | MatrixHoleType | FloatType | VoidType | StringType @@ -181,9 +183,13 @@ and block_item_desc = and block_item = block_item_desc node +and compile_assert_desc = { cond : expression; message : string node } +and compile_assert = compile_assert_desc node + and statement_desc = | Expression of expression | Let of let_stmt + | CompileAssert of compile_assert | Return of expression option | Defer of expression | Iter of iter_stmt diff --git a/src/lib/cst/emit.ml b/src/lib/cst/emit.ml index 923d090..a43b9c8 100644 --- a/src/lib/cst/emit.ml +++ b/src/lib/cst/emit.ml @@ -264,6 +264,8 @@ and emit_type fmt ty = | NumericType n -> fprintf fmt "%s" (numeric_type_to_string n) | VecType v -> fprintf fmt "%s" (vec_type_to_string v) | MatrixType m -> fprintf fmt "%s" (mat_type_to_string m) + | VecHoleType -> fprintf fmt "fvec?" + | MatrixHoleType -> fprintf fmt "mat?" | FloatType -> fprintf fmt "float" | VoidType -> fprintf fmt "void" | StringType -> fprintf fmt "str" @@ -409,6 +411,11 @@ and emit_statement ~indent ~comments fmt stmt = fprintf fmt " = %a;" (emit_expression ~ctx_prec:0 ~indent ~comments) s.init_expr + | CompileAssert a -> + fprintf fmt "@assert %a, %S;" + (emit_expression ~ctx_prec:0 ~indent ~comments) + a.value.cond a.value.message.value; + flush_inline_on_line ~line:stmt.loc.start_pos.pos_lnum comments fmt | Return (Some e) -> fprintf fmt "ret %a;" (emit_expression ~ctx_prec:0 ~indent ~comments) e; flush_inline_on_line ~line:stmt.loc.start_pos.pos_lnum comments fmt diff --git a/src/lib/cst/locate.ml b/src/lib/cst/locate.ml index c631362..67253fe 100644 --- a/src/lib/cst/locate.ml +++ b/src/lib/cst/locate.ml @@ -94,7 +94,14 @@ let add_if predicate node acc = if predicate node then node :: acc else acc let rec walk_haven_type predicate acc (ty : haven_type) = let acc = add_if predicate (HavenType ty) acc in match ty.value with - | NumericType _ | VecType _ | MatrixType _ | FloatType | VoidType | StringType + | NumericType _ + | VecType _ + | MatrixType _ + | VecHoleType + | MatrixHoleType + | FloatType + | VoidType + | StringType -> acc | CustomType _ -> acc @@ -219,6 +226,7 @@ and walk_statement predicate acc stmt = | Some t -> walk_haven_type predicate acc t in walk_expression predicate acc s.value.init_expr + | CompileAssert a -> walk_expression predicate acc a.value.cond | Return (Some e) -> walk_expression predicate acc e | Return None -> acc | Defer e -> walk_expression predicate acc e diff --git a/src/lib/cst/pretty.ml b/src/lib/cst/pretty.ml index 3b96504..afef00f 100644 --- a/src/lib/cst/pretty.ml +++ b/src/lib/cst/pretty.ml @@ -68,6 +68,8 @@ and pp_type fmt ty = | NumericType n -> fprintf fmt "%s" (numeric_type_to_string n) | VecType v -> fprintf fmt "%s" (vec_type_to_string v) | MatrixType m -> fprintf fmt "%s" (mat_type_to_string m) + | VecHoleType -> fprintf fmt "fvec?" + | MatrixHoleType -> fprintf fmt "mat?" | FloatType -> fprintf fmt "float" | VoidType -> fprintf fmt "void" | StringType -> fprintf fmt "str" @@ -192,6 +194,9 @@ and pp_statement fmt stmt = let s = unwrap s in fprintf fmt "@[Let(@,mut=%a,@ name=%a,@ init_expr=%a@,)@]" pp_print_bool s.mut pp_identifier s.name pp_expression s.init_expr + | CompileAssert a -> + fprintf fmt "@[CompileAssert(@,cond=%a,@ message=%S@,)@]" pp_expression + a.value.cond a.value.message.value | Return (Some e) -> fprintf fmt "@[Return(@,%a@,)@]" pp_expression e | Return None -> fprintf fmt "Return" | Defer e -> fprintf fmt "@[Defer(@,%a@,)@]" pp_expression e diff --git a/src/lib/lexer/lexer.ml b/src/lib/lexer/lexer.ml index 4d2cd23..fc723a1 100644 --- a/src/lib/lexer/lexer.ml +++ b/src/lib/lexer/lexer.ml @@ -51,13 +51,19 @@ type symbol = | Tilde | Underscore +type directive = + | Assert + module Raw = struct type t = | Trivia of trivia | Ident of string + | Directive of directive | Numeric_type of numeric_type | Vec_type of vec_type | Mat_type of mat_type + | Vec_hole_type + | Mat_hole_type | Float_type | Void_type | Str_type @@ -83,11 +89,15 @@ let ident_inner = [%sedlex.regexp? letter | digit | '_'] let ident_segment = [%sedlex.regexp? Plus ident_inner] let numeric_type = [%sedlex.regexp? ('i' | 'u'), nonzero, Star digit] let vec_type = [%sedlex.regexp? "fvec", nonzero, Star digit] +let vec_hole_type = [%sedlex.regexp? "fvec?"] +let assert_directive = [%sedlex.regexp? "@assert"] let mat_type = [%sedlex.regexp? ("fmat" | "mat"), nonzero, Star digit, 'x', nonzero, Star digit] +let mat_hole_type = [%sedlex.regexp? "mat?"] + let float_type = [%sedlex.regexp? "float"] let void_type = [%sedlex.regexp? "void"] let str_type = [%sedlex.regexp? "str"] @@ -359,9 +369,12 @@ let rec lex buf acc = | numeric_type -> let text = Sedlexing.Utf8.lexeme buf in lex buf (push_token buf (Numeric_type (numeric_type_of_string text)) acc) + | vec_hole_type -> lex buf (push_token buf Vec_hole_type acc) + | assert_directive -> lex buf (push_token buf (Directive Assert) acc) | vec_type -> let text = Sedlexing.Utf8.lexeme buf in lex buf (push_token buf (Vec_type (vec_type_of_string text)) acc) + | mat_hole_type -> lex buf (push_token buf Mat_hole_type acc) | mat_type -> let text = Sedlexing.Utf8.lexeme buf in lex buf (push_token buf (Mat_type (mat_type_of_string text)) acc) diff --git a/src/lib/lexer/pretty.ml b/src/lib/lexer/pretty.ml index 1e1e544..19fa2b5 100644 --- a/src/lib/lexer/pretty.ml +++ b/src/lib/lexer/pretty.ml @@ -98,10 +98,13 @@ let pp_token (token : Raw.tok) = match token.tok with | Trivia trivia -> pp_trivia trivia | Ident text -> Printf.printf "IDENT %s\n" text + | Directive Assert -> Printf.printf "DIRECTIVE @assert\n" | Numeric_type desc -> Printf.printf "NUMERIC_TYPE %s\n" (numeric_type_to_string desc) | Vec_type desc -> Printf.printf "VEC_TYPE %s\n" (vec_type_to_string desc) | Mat_type desc -> Printf.printf "MAT_TYPE %s\n" (mat_type_to_string desc) + | Vec_hole_type -> Printf.printf "VEC_HOLE_TYPE fvec?\n" + | Mat_hole_type -> Printf.printf "MAT_HOLE_TYPE mat?\n" | Float_type -> Printf.printf "FLOAT_TYPE float\n" | Void_type -> Printf.printf "VOID_TYPE void\n" | Str_type -> Printf.printf "STR_TYPE str\n" diff --git a/src/lib/parser/grammar.mly b/src/lib/parser/grammar.mly index 5bb0ba5..e2bc7db 100644 --- a/src/lib/parser/grammar.mly +++ b/src/lib/parser/grammar.mly @@ -15,7 +15,9 @@ %token NUMERIC_TYPE %token VEC_TYPE %token MAT_TYPE +%token VEC_HOLE_TYPE MAT_HOLE_TYPE %token FLOAT_TYPE VOID_TYPE STR_TYPE +%token ASSERT_DIRECTIVE %token INT_LIT %token FLOAT_LIT %token HEX_LIT OCT_LIT BIN_LIT @@ -163,6 +165,9 @@ stmt_inner: Let (mk_loc $startpos $endpos { mut = m; name = n; ty = Some t; init_expr = mk_expr $startpos(i) $endpos(i) (Initializer i); }) } | LET m=boption(MUT) t=haven_type n=identifier EQUAL e=expr { Let (mk_loc $startpos $endpos { mut = m; name = n; ty = Some t; init_expr = e; }) } + | ASSERT_DIRECTIVE c=expr COMMA m=STRING_LIT { + CompileAssert (mk_loc $startpos $endpos { cond = c; message = mk_id m $startpos(m) $endpos(m) }) + } | RET e=option(expr) { Return e } | DEFER e=expr { Defer e } | ITER r=iter_range v=identifier b=block { Iter (mk_loc $startpos $endpos { range = r; var = v; body = b }) } @@ -332,6 +337,8 @@ builtin_type: | t=NUMERIC_TYPE { mk_loc $startpos $endpos (NumericType t) } | t=VEC_TYPE { mk_loc $startpos $endpos (VecType t) } | t=MAT_TYPE { mk_loc $startpos $endpos (MatrixType t) } + | VEC_HOLE_TYPE { mk_loc $startpos $endpos VecHoleType } + | MAT_HOLE_TYPE { mk_loc $startpos $endpos MatrixHoleType } | FLOAT_TYPE { mk_loc $startpos $endpos FloatType } | VOID_TYPE { mk_loc $startpos $endpos VoidType } | STR_TYPE { mk_loc $startpos $endpos StringType } diff --git a/src/lib/parser/parser.ml b/src/lib/parser/parser.ml index 3d155db..d3d543f 100644 --- a/src/lib/parser/parser.ml +++ b/src/lib/parser/parser.ml @@ -83,9 +83,12 @@ let rec next_token st : Grammar.token * Lexing.position * Lexing.position = (* Skip trivia by recursion or a loop *) next_token st | Ident s -> store_token st (keyword_or_ident s, startp, endp) + | Directive Assert -> store_token st (Grammar.ASSERT_DIRECTIVE, startp, endp) | Numeric_type s -> store_token st (Grammar.NUMERIC_TYPE s, startp, endp) | Vec_type s -> store_token st (Grammar.VEC_TYPE s, startp, endp) | Mat_type s -> store_token st (Grammar.MAT_TYPE s, startp, endp) + | Vec_hole_type -> store_token st (Grammar.VEC_HOLE_TYPE, startp, endp) + | Mat_hole_type -> store_token st (Grammar.MAT_HOLE_TYPE, startp, endp) | Float_type -> store_token st (Grammar.FLOAT_TYPE, startp, endp) | Void_type -> store_token st (Grammar.VOID_TYPE, startp, endp) | Str_type -> store_token st (Grammar.STR_TYPE, startp, endp) @@ -165,6 +168,9 @@ let token_to_string = function Printf.sprintf "vector type %s" (vec_type_to_string desc) | Grammar.MAT_TYPE desc -> Printf.sprintf "matrix type %s" (mat_type_to_string desc) + | Grammar.VEC_HOLE_TYPE -> "vector specialization type fvec?" + | Grammar.MAT_HOLE_TYPE -> "matrix specialization type mat?" + | Grammar.ASSERT_DIRECTIVE -> "@assert" | Grammar.VOID_TYPE -> "void" | Grammar.FLOAT_TYPE -> "float" | Grammar.STR_TYPE -> "str" diff --git a/src/test/test_llvm_ir.ml b/src/test/test_llvm_ir.ml index 151f19c..13c780d 100644 --- a/src/test/test_llvm_ir.ml +++ b/src/test/test_llvm_ir.ml @@ -5,6 +5,7 @@ let emit_ir source = assert_no_diagnostics "llvm ir typing" pipeline.typing.diagnostics; assert_no_diagnostics "llvm ir verify" pipeline.verify.diagnostics; assert_no_diagnostics "llvm ir semantic" pipeline.semantic.diagnostics; + assert_no_diagnostics "llvm ir asserts" pipeline.asserts.diagnostics; assert_no_diagnostics "llvm ir purity" pipeline.purity.diagnostics; assert_no_diagnostics "llvm ir ownership" pipeline.ownership.diagnostics; Haven.Ast.Llvm_ir.emit_ir_string pipeline @@ -106,4 +107,60 @@ pub fn main() -> void {} assert_true "non-constant vector literals should be assembled with insertelement" (string_contains literal_ir "define <3 x float> @make_vec"); assert_true "non-constant matrix literals should lower to their flat vector form" - (string_contains literal_ir "define <4 x float> @make_mat") + (string_contains literal_ir "define <4 x float> @make_mat"); + + let specialization_ir = + emit_ir + {| +fn vadd(fvec? a, fvec? b) { a + b } +pub fn main() -> fvec3 { vadd(Vec<1.0, 2.0, 3.0>, Vec<4.0, 5.0, 6.0>) } +|} + in + assert_true "specialization should clone concrete vector variants before LLVM" + (string_contains specialization_ir "@vadd__spec__fvec3__fvec3"); + assert_true "specialized vector addition should lower with concrete vector ops" + (string_contains specialization_ir "fadd <3 x float>"); + + let shape_property_ir = + emit_ir + {| +fn width(mat? m) { m.cols } +pub fn main() -> u32 { width(Mat, Vec<3.0, 4.0>>) } +|} + in + assert_true "shape property specialization should lower concrete matrix helpers" + (string_contains shape_property_ir "@width__spec__mat2x2"); + assert_true "shape properties should lower as plain integer constants" + (string_contains shape_property_ir "store i32 2"); + + let get_mat_row_ir = + emit_ir + {| +fn get_mat_row(mat? m, u32 row) { m[row] } +pub fn main() -> fvec3 { + get_mat_row(Mat, Vec<4.0, 5.0, 6.0>>, 1) +} +|} + in + assert_true "get_mat_row should clone a concrete helper before LLVM" + (string_contains get_mat_row_ir "@get_mat_row__spec__mat2x3"); + assert_true "specialized matrix row access should still lower through row addressing" + (string_contains get_mat_row_ir "getelementptr inbounds float"); + + let compile_assert_ir = + emit_ir + {| +fn mat_width_eq(mat? a, mat? b) { + @assert a.cols == b.cols, "matrix widths must match"; + a.cols +} + +pub fn main() -> u32 { + mat_width_eq(Mat, Vec<3.0, 4.0>>, Mat, Vec<7.0, 8.0>>) +} +|} + in + assert_true "successful compile asserts should not reach LLVM" + (not (string_contains compile_assert_ir "compile-time assert")); + assert_true "compile assert specializations should still lower normally" + (string_contains compile_assert_ir "@mat_width_eq__spec__mat2x2__mat2x2") diff --git a/src/test/test_parser.ml b/src/test/test_parser.ml index a8550ac..bcc06b2 100644 --- a/src/test/test_parser.ml +++ b/src/test/test_parser.ml @@ -13,6 +13,12 @@ let run () = assert_parse_ok "matrix literal accepts vector expressions" "pub fn main() -> i32 { let v = Vec<3.0, 4.0>; let x = Mat, v>; 0 }"; + assert_parse_ok "specialization hole parameter types" + "fn vadd(fvec? a, mat? b) { a }"; + + assert_parse_ok "compile-time assert statement" + "fn vadd(fvec? a, fvec? b) { @assert a.dim == b.dim, \"dims must match\"; a + b }"; + assert_parse_ok "multi-payload enum variants" "type Pair = enum { Both(i32, i32), Empty }; pub fn main() -> i32 { 0 }"; @@ -56,4 +62,37 @@ let run () = assert_true "parse errors should report remapped preprocessor filename" (string_contains parse_error "preprocessed/input.hv:7:"); assert_true "parse errors should report remapped preprocessor line" - (string_contains parse_error "at preprocessed/input.hv:7:") + (string_contains parse_error "at preprocessed/input.hv:7:"); + + let specialized_core = parse_to_core "fn vadd(fvec? a, mat? b) { a }" in + let specialized_fn = find_named_function "vadd" specialized_core in + (match specialized_fn.value.params.value.params with + | [ vec_param; mat_param ] -> ( + match (vec_param.value.ty.value, mat_param.value.ty.value) with + | Core.VecHoleType, Core.MatrixHoleType -> () + | _ -> failwith "expected specialization hole parameter types to survive into core AST") + | _ -> failwith "expected vadd to have two parameters"); + assert_true "specialization function should keep omitted return type through core conversion" + (specialized_fn.value.return_type = None); + + let assert_core = + parse_to_core + "fn vadd(fvec? a, fvec? b) { @assert a.dim == b.dim, \"dims must match\"; a + b }" + in + let assert_fn = find_named_function "vadd" assert_core in + (match + Option.bind assert_fn.value.definition (fun body -> + match body.value.statements with stmt :: _ -> Some stmt | [] -> None) + with + | Some { value = Core.CompileAssert _; _ } -> () + | _ -> failwith "expected compile assert to survive into core AST"); + + let missing_return_error = + try + ignore (parse_to_core "fn add(i32 a, i32 b) { a + b }"); + failwith "expected omitted non-specialization return type to fail" + with Failure msg -> msg + in + assert_true "omitted return type should require specialization parameters" + (string_contains missing_return_error + "only specialization functions may infer returns") diff --git a/src/test/test_pipeline.ml b/src/test/test_pipeline.ml index 725b315..b87c30e 100644 --- a/src/test/test_pipeline.ml +++ b/src/test/test_pipeline.ml @@ -39,8 +39,11 @@ let run () = | _ -> failwith "expected constant folding to reduce the if condition to true") | _ -> failwith "expected constant folding to reduce the if condition to a literal"); (match scrutinee_after.value with - | Core.Unary _ -> () - | _ -> failwith "expected cleanup to remove redundant ToBool around bool-valued condition"); + | Core.Literal lit -> ( + match lit.value with + | Core.Bool true -> () + | _ -> failwith "expected cleanup to preserve the folded boolean condition") + | _ -> failwith "expected cleanup to preserve the folded boolean condition"); let folded_pipeline = parse_to_core "pub fn main() -> i32 { let x = 1 + 2 * 3; x }" @@ -54,4 +57,171 @@ let run () = match lit.value with | Core.Integer 7 -> () | _ -> failwith "expected arithmetic constant folding to produce 7") - | _ -> failwith "expected arithmetic constant folding to produce a literal") + | _ -> failwith "expected arithmetic constant folding to produce a literal"); + + let specialization_pipeline = + parse_to_core + "fn vadd(fvec? a, fvec? b) { a + b }\n\ + fn width(mat? m) { m.cols }\n\ + fn main() -> i32 { 0 }" + |> Analysis.Pipeline.run_core + in + assert_no_diagnostics "specialization typing" specialization_pipeline.typing.diagnostics; + assert_no_diagnostics "specialization verify" specialization_pipeline.verify.diagnostics; + assert_no_diagnostics "specialization semantic" specialization_pipeline.semantic.diagnostics; + + let specialization_call_pipeline = + parse_to_core + "fn vadd(fvec? a, fvec? b) { a + b }\n\ + fn main() -> fvec3 {\n\ + \ vadd(Vec<1.0, 2.0, 3.0>, Vec<4.0, 5.0, 6.0>)\n\ + }" + |> Analysis.Pipeline.run_core + in + assert_no_diagnostics "specialization call typing" + specialization_call_pipeline.typing.diagnostics; + assert_no_diagnostics "specialization call verify" + specialization_call_pipeline.verify.diagnostics; + assert_no_diagnostics "specialization call semantic" + specialization_call_pipeline.semantic.diagnostics; + let specialization_core = + Haven.Ast.Pretty.core_program_to_string specialization_call_pipeline.cleaned + in + assert_true "specialized pipeline should emit a concrete clone" + (string_contains specialization_core "vadd__spec__fvec3__fvec3"); + assert_true "specialized pipeline should erase hole types from the lowered program" + (not (string_contains specialization_core "fvec?")); + + let shape_property_pipeline = + parse_to_core + "fn width(mat? m) { m.cols }\n\ + pub fn main() -> u32 {\n\ + \ width(Mat, Vec<3.0, 4.0>>)\n\ + }" + |> Analysis.Pipeline.run_core + in + assert_no_diagnostics "shape property typing" shape_property_pipeline.typing.diagnostics; + assert_no_diagnostics "shape property verify" shape_property_pipeline.verify.diagnostics; + assert_no_diagnostics "shape property semantic" shape_property_pipeline.semantic.diagnostics; + let shape_property_core = + Haven.Ast.Pretty.core_program_to_string shape_property_pipeline.cleaned + in + assert_true "shape property specialization should clone the function" + (string_contains shape_property_core "width__spec__mat2x2"); + assert_true "shape properties should lower to integer literals before LLVM" + (string_contains shape_property_core "Literal(2)"); + + let get_mat_row_pipeline = + parse_to_core + "fn get_mat_row(mat? m, u32 row) { m[row] }\n\ + pub fn main() -> fvec3 {\n\ + \ get_mat_row(Mat, Vec<4.0, 5.0, 6.0>>, 1)\n\ + }" + |> Analysis.Pipeline.run_core + in + assert_no_diagnostics "get_mat_row typing" get_mat_row_pipeline.typing.diagnostics; + assert_no_diagnostics "get_mat_row verify" get_mat_row_pipeline.verify.diagnostics; + assert_no_diagnostics "get_mat_row semantic" get_mat_row_pipeline.semantic.diagnostics; + let get_mat_row_core = + Haven.Ast.Pretty.core_program_to_string get_mat_row_pipeline.cleaned + in + assert_true "get_mat_row should specialize to a concrete matrix helper" + (string_contains get_mat_row_core "get_mat_row__spec__mat2x3"); + assert_true "get_mat_row specialization should infer a concrete vector return" + (string_contains get_mat_row_core "return=fvec3"); + + let specialization_dedup_pipeline = + parse_to_core + "fn get_mat_row(mat? m, u32 row) { m[row] }\n\ + pub fn main() -> fvec3 {\n\ + \ let u32 row = 1;\n\ + \ get_mat_row(Mat, Vec<4.0, 5.0, 6.0>>, row)\n\ + }\n\ + pub fn other() -> fvec3 {\n\ + \ get_mat_row(Mat, Vec<10.0, 11.0, 12.0>>, 1)\n\ + }" + |> Analysis.Pipeline.run_core + in + assert_no_diagnostics "specialization dedup typing" + specialization_dedup_pipeline.typing.diagnostics; + let specialization_dedup_core = + Haven.Ast.Pretty.core_program_to_string specialization_dedup_pipeline.cleaned + in + assert_true "equivalent concrete signatures should reuse one specialization" + (count_occurrences specialization_dedup_core + "name=get_mat_row__spec__mat2x3__u32" + = 1); + + let compile_assert_pipeline = + parse_to_core + "fn mat_width_eq(mat? a, mat? b) {\n\ + \ @assert a.cols == b.cols, \"matrix widths must match\";\n\ + \ a.cols\n\ + }\n\ + pub fn main() -> u32 {\n\ + \ mat_width_eq(Mat, Vec<3.0, 4.0>>, Mat, Vec<7.0, 8.0>>)\n\ + }" + |> Analysis.Pipeline.run_core + in + assert_no_diagnostics "compile assert typing" compile_assert_pipeline.typing.diagnostics; + assert_no_diagnostics "compile assert verify" compile_assert_pipeline.verify.diagnostics; + assert_no_diagnostics "compile assert semantic" compile_assert_pipeline.semantic.diagnostics; + assert_no_diagnostics "compile assert pass" compile_assert_pipeline.asserts.diagnostics; + let compile_assert_core = + Haven.Ast.Pretty.core_program_to_string compile_assert_pipeline.cleaned + in + assert_true "successful compile asserts should be erased before the cleaned AST" + (not (string_contains compile_assert_core "CompileAssert")); + + let compile_assert_fail_pipeline = + parse_to_core + "fn mat_width_eq(mat? a, mat? b) {\n\ + \ @assert a.cols == b.cols, \"matrix widths must match\";\n\ + \ a.cols\n\ + }\n\ + pub fn main() -> u32 {\n\ + \ mat_width_eq(Mat, Vec<3.0, 4.0>>, Mat, Vec<6.0>>)\n\ + }" + |> Analysis.Pipeline.run_core + in + assert_has_diagnostics "failing compile assert should produce diagnostics" + compile_assert_fail_pipeline.asserts.diagnostics; + assert_any_diagnostic_message_contains "failing compile assert should preserve the user message" + "matrix widths must match" compile_assert_fail_pipeline.asserts.diagnostics; + assert_any_diagnostic_message_contains + "failing compile assert should include the rendered condition" + "compile-time assertion failed: a.cols == b.cols" + compile_assert_fail_pipeline.asserts.diagnostics; + assert_any_diagnostic_message_contains + "failing compile assert should include the specialized condition" + "specialized as: 2 == 1" + compile_assert_fail_pipeline.asserts.diagnostics; + + let compile_assert_short_circuit_pipeline = + parse_to_core + "fn vadd(fvec? a, fvec? b) {\n\ + \ @assert a.dim == b.dim, \"vector dimensions must match\";\n\ + \ a + b\n\ + }\n\ + pub fn main() -> fvec3 {\n\ + \ vadd(Vec<1.0, 2.0, 3.0>, Vec<4.0, 5.0>)\n\ + }" + |> Analysis.Pipeline.run_core + in + assert_has_diagnostics "failing compile assert should still report the assert" + compile_assert_short_circuit_pipeline.asserts.diagnostics; + assert_any_diagnostic_message_contains + "failing compile assert should keep the user-facing message" + "vector dimensions must match" + compile_assert_short_circuit_pipeline.asserts.diagnostics; + assert_any_diagnostic_message_contains + "failing compile assert should include the rendered vector condition" + "compile-time assertion failed: a.dim == b.dim" + compile_assert_short_circuit_pipeline.asserts.diagnostics; + assert_any_diagnostic_message_contains + "failing compile assert should include the specialized vector condition" + "specialized as: 3 == 2" + compile_assert_short_circuit_pipeline.asserts.diagnostics; + assert_no_diagnostics + "failing compile assert should short-circuit later semantic analysis" + compile_assert_short_circuit_pipeline.semantic.diagnostics diff --git a/src/test/test_rc.ml b/src/test/test_rc.ml index 4cfc693..fb7e943 100644 --- a/src/test/test_rc.ml +++ b/src/test/test_rc.ml @@ -184,6 +184,7 @@ let pipeline_errors (pipeline : Analysis.Pipeline.result) = pipeline.typing.diagnostics @ pipeline.verify.diagnostics @ pipeline.semantic.diagnostics + @ pipeline.asserts.diagnostics @ pipeline.purity.diagnostics @ pipeline.ownership.diagnostics in @@ -198,6 +199,29 @@ let format_diagnostic (diagnostic : Analysis.diagnostic) = let col = loc.start_pos.Lexing.pos_cnum - loc.start_pos.Lexing.pos_bol + 1 in Printf.sprintf "%s:%d:%d: %s" file line col diagnostic.message +let require_compile_error ~source ~expected_substrings = + let parsed = Haven.Parser.parse_file source in + let pipeline = Analysis.Pipeline.run_cst parsed in + let diagnostics = pipeline_errors pipeline in + match diagnostics with + | [] -> + failwith + (Printf.sprintf "expected compiler diagnostics while compiling %s" source) + | diagnostic :: _ -> + List.iter + (fun expected -> + if not (Test_support.string_contains diagnostic.message expected) then + failwith + (Printf.sprintf + "expected diagnostic for %s to include %S, but got:\n%s" + source expected (format_diagnostic diagnostic))) + expected_substrings; + if pipeline.semantic.diagnostics <> [] then + failwith + (Printf.sprintf + "expected failing compile-time assert to short-circuit later semantic diagnostics for %s" + source) + let compile_case_to_object ~source ~output_path opt_level = let parsed = Haven.Parser.parse_file source in let pipeline = Analysis.Pipeline.run_cst parsed in @@ -259,6 +283,16 @@ let run_case temp_dir harness_obj root (case : rc_case) (opt : opt_case) = let run () = let root = resolve_repo_root () in + let specialization_assert_fail = + Filename.concat root "tests/inputs/specialization_assert_fail.hv" + in + require_compile_error ~source:specialization_assert_fail + ~expected_substrings: + [ + "vector dimensions must match"; + "compile-time assertion failed: a.dim == b.dim"; + "specialized as: 3 == 2"; + ]; Test_support.with_temp_dir "haven-rc" (fun temp_dir -> let harness_c = Filename.concat temp_dir "rc_harness.c" in let harness_obj = Filename.concat temp_dir "rc_harness.o" in diff --git a/src/test/test_support.ml b/src/test/test_support.ml index 5324705..473df97 100644 --- a/src/test/test_support.ml +++ b/src/test/test_support.ml @@ -20,6 +20,18 @@ let string_contains haystack needle = in loop 0 +let count_occurrences haystack needle = + let haystack_len = String.length haystack in + let needle_len = String.length needle in + let rec loop index count = + if needle_len = 0 then count + else if index + needle_len > haystack_len then count + else if String.sub haystack index needle_len = needle then + loop (index + needle_len) (count + 1) + else loop (index + 1) count + in + loop 0 0 + let assert_diagnostic_message_contains label needle diagnostics = match diagnostics with | [] -> failwith (label ^ " expected at least one diagnostic") @@ -87,6 +99,7 @@ let find_first_let_binding (program : Core.parsed_program) = match stmt.value with | Core.Let binding -> collect_in_statements (binding :: acc) rest | Core.Expression _ + | Core.CompileAssert _ | Core.Return _ | Core.Defer _ | Core.Break @@ -132,6 +145,7 @@ let find_let_binding_at index (program : Core.parsed_program) = let inner = collect_stmt [] loop.value.body.value.statements in collect_stmt (List.rev_append inner bindings) stmt_rest | Core.Expression _ + | Core.CompileAssert _ | Core.Return _ | Core.Defer _ | Core.Break diff --git a/tests/inputs/specialization_assert_fail.hv b/tests/inputs/specialization_assert_fail.hv new file mode 100644 index 0000000..bd18783 --- /dev/null +++ b/tests/inputs/specialization_assert_fail.hv @@ -0,0 +1,9 @@ +fn vadd(fvec? a, fvec? b) { + @assert a.dim == b.dim, "vector dimensions must match"; + a + b +} + +pub fn sut() -> i32 { + let sum = vadd(Vec<1.0, 2.0, 3.0>, Vec<4.0, 5.0>); + as(sum.x) +}