diff --git a/lib/irgen.ml b/lib/irgen.ml index d02f06d..bf5d1d9 100644 --- a/lib/irgen.ml +++ b/lib/irgen.ml @@ -18,6 +18,29 @@ type variable = let udt_structs : (string, L.lltype) Hashtbl.t = Hashtbl.create 10 let udt_field_indices : (string, (string * int) list) Hashtbl.t = Hashtbl.create 10 +let build_entry_alloca the_function var_name var_type = + let builder = + L.builder_at context (L.instr_begin (L.entry_block the_function)) + in + L.build_alloca var_type var_name builder + +let build_itr_meta_data (_, itr_object) = + match itr_object with + | SList elements -> (List.length elements, elements) + | STuple elements -> (List.length elements, elements) + | _ -> failwith "Not a list or tuple" + +let rec sast_type_of_resolved_type (rty ) = + match rty with + | RInt -> Sast.RInt + | RBool -> Sast.RBool + | RChar -> Sast.RChar + | RFloat -> Sast.RFloat + | RString -> Sast.RString + | RUnit -> Sast.RUnit + | RList t -> Sast.RList (sast_type_of_resolved_type t) + | _ -> raise (Failure ("IR ERROR: TBI ")) + let l_int = L.i32_type context and l_bool = L.i1_type context and l_char = L.i8_type context @@ -505,6 +528,7 @@ let add_terminal builder instr = let translate blocks = let the_module = L.create_module context "Fly" in let local_vars = StringMap.empty in + let block_map = StringMap.empty in let var_types = StringMap.empty in List.iter (function @@ -538,7 +562,7 @@ let translate blocks = let lfunc, _, blocks = func_block in let curr_func = Some lfunc in let builder = L.builder_at_end context (L.entry_block lfunc) in - process_blocks blocks vars var_types curr_func [] (Some builder) + process_blocks blocks vars var_types curr_func [] (Some builder) block_map and process_blocks blocks vars @@ -546,13 +570,14 @@ let translate blocks = (curr_func : L.llvalue option) func_blocks (builder : L.llbuilder option) + block_map = match blocks with (* We've declared all objects, lets fill in all function bodies *) | [] -> process_func_blocks func_blocks vars var_types | block :: rest -> - let updated_vars, updated_var_types, updated_curr_func, u_func_blocks, u_builder = - process_block block vars var_types curr_func func_blocks builder + let updated_vars, updated_var_types, updated_curr_func, u_func_blocks, u_builder, block_map = + process_block block vars var_types curr_func func_blocks builder block_map in process_blocks rest @@ -561,6 +586,7 @@ let translate blocks = updated_curr_func u_func_blocks u_builder + block_map and process_block block vars @@ -568,49 +594,50 @@ let translate blocks = (curr_func : L.llvalue option) func_blocks (builder : L.llbuilder option) + block_map = match block with | SUDTDef (name, members) -> define_udt_type name members; - vars, var_types, curr_func, func_blocks, builder + vars, var_types, curr_func, func_blocks, builder, block_map | SDeclTyped (id, typ, expr) -> - if Option.is_some curr_func - then ( + if Option.is_some curr_func then ( let new_vars = add_local_val typ id vars var_types expr the_module (Option.get builder) in let new_var_types = StringMap.add id typ var_types in - new_vars, new_var_types, curr_func, func_blocks, builder) + new_vars, new_var_types, curr_func, func_blocks, builder, block_map + ) else ( let new_vars = add_global_val typ id vars var_types expr the_module in let new_var_types = StringMap.add id typ var_types in - new_vars, new_var_types, curr_func, func_blocks, builder) + new_vars, new_var_types, curr_func, func_blocks, builder, block_map) | SFunctionDefinition (typ, id, formals, body) -> let u_func_blocks = declare_function typ id formals body func_blocks in - vars, var_types, curr_func, u_func_blocks, builder + vars, var_types, curr_func, u_func_blocks, builder, block_map | SReturnUnit -> ignore (L.build_ret_void (Option.get builder)); - vars, var_types, curr_func, func_blocks, builder + vars, var_types, curr_func, func_blocks, builder, block_map | SReturnVal expr -> let ret = build_expr expr vars var_types the_module (Option.get builder) in ignore (L.build_ret ret (Option.get builder)); - vars, var_types, curr_func, func_blocks, builder + vars, var_types, curr_func, func_blocks, builder, block_map | SExpr expr -> ignore (build_expr expr vars var_types the_module (Option.get builder)); - vars, var_types, curr_func, func_blocks, builder + vars, var_types, curr_func, func_blocks, builder, block_map | SIfEnd (expr, blks) -> let bool_val = build_expr expr vars var_types the_module (Option.get builder) in (* We require curr_func to be Some - no if-else in global scope *) let then_bb = L.append_block context "then" (Option.get curr_func) in let then_builder = Some (L.builder_at_end context then_bb) in - ignore (process_blocks blks vars var_types curr_func func_blocks then_builder); + ignore (process_blocks blks vars var_types curr_func func_blocks then_builder block_map); let end_bb = L.append_block context "if_end" (Option.get curr_func) in let build_br_end = L.build_br end_bb in add_terminal (L.builder_at_end context then_bb) build_br_end; ignore (L.build_cond_br bool_val then_bb end_bb (Option.get builder)); let u_builder = Some (L.builder_at_end context end_bb) in - vars, var_types, curr_func, func_blocks, u_builder + vars, var_types, curr_func, func_blocks, u_builder, block_map | SIfNonEnd (expr, blks, else_blk) -> (* expression should be bool *) assert_types (fst expr) RBool; @@ -619,7 +646,7 @@ let translate blocks = let then_bb = L.append_block context "then" (Option.get curr_func) in let then_builder = Some (L.builder_at_end context then_bb) in - ignore (process_blocks blks vars var_types curr_func func_blocks then_builder); + ignore (process_blocks blks vars var_types curr_func func_blocks then_builder block_map); let end_bb = L.append_block context "if_end" (Option.get curr_func) in let else_bb = L.append_block context "else" (Option.get curr_func) in let else_builder = Some (L.builder_at_end context else_bb) in @@ -630,7 +657,7 @@ let translate blocks = let build_br_end = L.build_br end_bb in add_terminal (L.builder_at_end context then_bb) build_br_end; add_terminal (L.builder_at_end context else_bb) build_br_end; - vars, var_types, curr_func, func_blocks, u_builder + vars, var_types, curr_func, func_blocks, u_builder, block_map | SEnumDeclaration (enum_name_str, sast_variants) -> let rec process_variants_to_update_vars current_vars_map @@ -642,7 +669,6 @@ let translate blocks = | SEnumVariantDefault variant_n :: rest -> let assigned_int_val = current_int_val in let llvm_const_i32 = L.const_int l_int assigned_int_val in - let global_llvm_var_name = enum_name_str ^ "::" ^ variant_n in let global_llvm_var_ptr = L.define_global global_llvm_var_name llvm_const_i32 the_module @@ -683,9 +709,99 @@ let translate blocks = in process_variants_to_update_vars updated_vars_map rest (assigned_int_val + 1) in - let updated_vars = process_variants_to_update_vars vars sast_variants 0 in - updated_vars, var_types, curr_func, func_blocks, builder + updated_vars, var_types, curr_func, func_blocks, builder, block_map + + | SFor (loop_var, checked_iterable, checked_body) -> + + (* let _prt = List.hd checked_body in + let () = match _prt with + | SExpr _sp -> + let () = match snd _sp with + | SFunctionCall (_id, sexprs) -> + let arg = List.hd sexprs in + Printf.printf "ARG: %s, %s\n" (Utils.string_of_resolved_type (fst arg)) (Utils.string_of_sexpr (snd arg)); + + | _ -> failwith "bad" in + failwith "bad" + | _ -> failwith "really bad" in *) + + + (* checked_itrable is a tuple of the following for: + (list of types of all elements in t , list of elements all of same type )*) + let curr_func = Option.get curr_func in + let builder = Option.get builder in + let elem_type = fst checked_iterable in + let v_type = sast_type_of_resolved_type elem_type in + + let v_type = match v_type with + | RList t -> t + | _ -> failwith "this shouldn't ever happen" in + + Printf.printf "VTYPE: %s\n" (Utils.string_of_resolved_type v_type); + (* so this should be constructing a tuple of the form {list length, ptr to list head } + TODO: ensure this is the case! *) + let list_data_ptr = build_expr checked_iterable vars var_types the_module builder in + let data_obj = + match snd checked_iterable with + | SList lst -> lst + | lst -> raise (Failure ("IR ERROR: Expression '" ^ Utils.string_of_sexpr lst ^ "' has type " ^ Utils.string_of_resolved_type (fst checked_iterable) ^ " and is not iterable. ")) + in + let llvm_list_length = L.const_int l_int (List.length data_obj) in + + + (* LOOP ITERABLE VARIABLE COUNTER *) + (* ATM this instantiates a new builder and inserts it into a new block within the current function *) + (* this i_val is the idx of the curr pos in the iterable*) + let index_alloca = build_entry_alloca curr_func "i" l_int in + ignore (L.build_store (L.const_int l_int 0) index_alloca builder); + + (* make loop variable visibile to the for loop body *) + let loop_var_alloca = build_entry_alloca curr_func loop_var (ltype_of_typ v_type) in + Printf.printf "LOOP_VAR: %s\n" (L.string_of_llvalue loop_var_alloca); + let vars = StringMap.add loop_var { v_value = loop_var_alloca; v_type = v_type; v_scope = Local } vars in + + let loop_cond_bb = L.append_block context "loop_cond" curr_func in + let loop_body_bb = L.append_block context "loop_body" curr_func in + let loop_after_bb = L.append_block context "loop_after" curr_func in + ignore (L.build_br loop_cond_bb builder); + + (* COMPARE CNTR (itr idx) to list length *) + (* load the curr ival into the body and determine if it exeeds the length of the iterable + jump to the end, or continue in the body as requried *) + let cond_builder = L.builder_at_end context loop_cond_bb in + let curr_i = L.build_load index_alloca "i_val" cond_builder in + let cond = L.build_icmp L.Icmp.Slt curr_i llvm_list_length "loop_cond" cond_builder in + ignore (L.build_cond_br cond loop_body_bb loop_after_bb cond_builder); + + (* IF NOT at EOL *) + let body_builder = L.builder_at_end context loop_body_bb in + (* Get current element from list: list_data_ptr[i] *) + let elem_ptr = L.build_gep list_data_ptr [| curr_i |] "elem_ptr" body_builder in + let elem_val = L.build_load elem_ptr "elem_val" body_builder in + (* Store current element into loop variable *) + ignore (L.build_store elem_val loop_var_alloca body_builder); + (* Run loop body *) + let block_map = StringMap.add "break" loop_after_bb block_map in + ignore (process_blocks checked_body vars var_types (Some curr_func) func_blocks (Some body_builder) block_map); + (* Increment i and jump back to condition *) + let updated_body_builder = L.builder_at_end context (L.insertion_block body_builder) in + let next_i = L.build_add curr_i (L.const_int l_int 1) "i_plus_1" updated_body_builder in + ignore (L.build_store next_i index_alloca updated_body_builder); + ignore (L.build_br loop_cond_bb updated_body_builder); + + (* IF PASS/ REACHED EOL *) + let after_builder = L.builder_at_end context loop_after_bb in + vars, var_types, Some curr_func, func_blocks, Some after_builder, block_map + | SBreak -> + let exit_bb = (try + StringMap.find "break" block_map + with Not_found -> + (* this should never actually run becuase sast ensures breaks are only in loops *) + raise (Failure "Break cannot be placed outside of a loop")) in + (* go to the while loop exit branch *) + ignore (L.build_br exit_bb (Option.get builder)); + vars, var_types, curr_func, func_blocks, builder, block_map | b -> raise (Failure @@ -701,7 +817,7 @@ let translate blocks = = match block with | SElseEnd blks -> - ignore (process_blocks blks vars var_types curr_func func_blocks builder); + ignore (process_blocks blks vars var_types curr_func func_blocks builder block_map); let u_builder = Some (L.builder_at_end context end_bb) in u_builder | SElifEnd (expr, blks) -> @@ -711,7 +827,7 @@ let translate blocks = let then_bb = L.append_block context "then" (Option.get curr_func) in let then_builder = Some (L.builder_at_end context then_bb) in - ignore (process_blocks blks vars var_types curr_func func_blocks then_builder); + ignore (process_blocks blks vars var_types curr_func func_blocks then_builder block_map); let build_br_end = L.build_br end_bb in add_terminal (L.builder_at_end context then_bb) build_br_end; ignore (L.build_cond_br bool_val then_bb end_bb (Option.get builder)); @@ -724,7 +840,7 @@ let translate blocks = let then_bb = L.append_block context "then" (Option.get curr_func) in let then_builder = Some (L.builder_at_end context then_bb) in - ignore (process_blocks blks vars var_types curr_func func_blocks then_builder); + ignore (process_blocks blks vars var_types curr_func func_blocks then_builder block_map); let build_br_end = L.build_br end_bb in add_terminal (L.builder_at_end context then_bb) build_br_end; let else_bb = L.append_block context "else" (Option.get curr_func) in @@ -743,6 +859,7 @@ let translate blocks = let func_blocks = [] in (* ..and start off with no builder.. *) let builder = None in - process_blocks blocks local_vars var_types curr_func func_blocks builder; + process_blocks blocks local_vars var_types curr_func func_blocks builder block_map; the_module + ;; diff --git a/lib/semant.ml b/lib/semant.ml index 30ef5b6..94334e8 100644 --- a/lib/semant.ml +++ b/lib/semant.ml @@ -670,8 +670,9 @@ and check_block block envs special_blocks func_ret_type = let checked_iterable = check_expr iterable envs special_blocks in let t, _ = checked_iterable in (match t with - | RList _ | RTuple _ -> - let new_var_env = var_dec_helper loop_var RInt envs in + | RList list_typ -> + Printf.printf "SEMANT: %s\n" (Utils.string_of_resolved_type list_typ); + let new_var_env = var_dec_helper loop_var list_typ envs in let updated_envs = { envs with var_env = new_var_env } in let updated_special_blocks = StringSet.add "break" (StringSet.add "continue" special_blocks) diff --git a/test/ir/dune b/test/ir/dune index 22d382e..9769a74 100644 --- a/test/ir/dune +++ b/test/ir/dune @@ -26,6 +26,10 @@ (name test_enums) (libraries fly_lib ounit2 str)) +(test + (name test_for) + (libraries fly_lib ounit2)) + (test (name test_list) (libraries fly_lib ounit2 str)) diff --git a/test/ir/test_for.ml b/test/ir/test_for.ml new file mode 100644 index 0000000..2339889 --- /dev/null +++ b/test/ir/test_for.ml @@ -0,0 +1,132 @@ +open OUnit2 +open Fly_lib +module L = Llvm + +let get_sast input = + try + let lexbuf = Lexing.from_string input in + let ast = Fly_lib.Parser.program_rule Fly_lib.Scanner.tokenize lexbuf in + let sast = Fly_lib.Semant.check ast.body in + sast + with + | err -> + raise + (Failure + (Printf.sprintf + "Error generating sast, is your program correct?: error=%s" + (Printexc.to_string err))) +;; + +let printer = fun s -> "\n---\n" ^ s ^ "\n---\n" + +let _write_to_file text filename = + let channel = open_out filename in + Printf.fprintf channel "%s" text; + close_out channel +;; + +let tests = + "testing_ir" + >::: [ ("test1" + >:: fun _ -> + let sast = + get_sast "fun function() -> () { for y := [10, 20, 30] { break; } }" + in + let mdl = Irgen.translate sast in + let actual = L.string_of_llmodule mdl in + let expected = + "; ModuleID = 'Fly'\n\ + source_filename = \"Fly\"\n\n\ + define void @function() {\n\ + entry:\n\ + \ %y = alloca i32, align 4\n\ + \ %i = alloca i32, align 4\n\ + \ %list = alloca i32, i32 3, align 4\n\ + \ %index = getelementptr inbounds i32, i32* %list, i32 0\n\ + \ store i32 10, i32* %index, align 4\n\ + \ %index1 = getelementptr inbounds i32, i32* %list, i32 1\n\ + \ store i32 20, i32* %index1, align 4\n\ + \ %index2 = getelementptr inbounds i32, i32* %list, i32 2\n\ + \ store i32 30, i32* %index2, align 4\n\ + \ store i32 0, i32* %i, align 4\n\ + \ br label %loop_cond\n\n\ + loop_cond: ; preds = %loop_body, \ + %entry\n\ + \ %i_val = load i32, i32* %i, align 4\n\ + \ %loop_cond3 = icmp slt i32 %i_val, 3\n\ + \ br i1 %loop_cond3, label %loop_body, label %loop_after\n\n\ + loop_body: ; preds = %loop_cond\n\ + \ %elem_ptr = getelementptr i32, i32* %list, i32 %i_val\n\ + \ %elem_val = load i32, i32* %elem_ptr, align 4\n\ + \ store i32 %elem_val, i32* %y, align 4\n\ + \ br label %loop_after\n\ + \ %i_plus_1 = add i32 %i_val, 1\n\ + \ store i32 %i_plus_1, i32* %i, align 4\n\ + \ br label %loop_cond\n\n\ + loop_after: ; preds = %loop_body, \ + %loop_cond\n\ + \ ret void\n\ + }\n" + in + assert_equal expected actual ~printer) + ; ("test2" + >:: fun _ -> + let sast = + get_sast "fun function() -> () { for y := [10, 20, 30] { print(y); } }" + in + let mdl = Irgen.translate sast in + let actual = L.string_of_llmodule mdl in + let expected = + "; ModuleID = 'Fly'\n\ + source_filename = \"Fly\"\n\n\ + @int_fmt = private unnamed_addr constant [4 x i8] c\"%d\\0A\\00\", align 1\n\n\ + define void @function() {\n\ + entry:\n\ + \ %y = alloca i32, align 4\n\ + \ %i = alloca i32, align 4\n\ + \ %list = alloca i32, i32 3, align 4\n\ + \ %index = getelementptr inbounds i32, i32* %list, i32 0\n\ + \ store i32 10, i32* %index, align 4\n\ + \ %index1 = getelementptr inbounds i32, i32* %list, i32 1\n\ + \ store i32 20, i32* %index1, align 4\n\ + \ %index2 = getelementptr inbounds i32, i32* %list, i32 2\n\ + \ store i32 30, i32* %index2, align 4\n\ + \ store i32 0, i32* %i, align 4\n\ + \ br label %loop_cond\n\n\ + loop_cond: ; preds = %loop_body, \ + %entry\n\ + \ %i_val = load i32, i32* %i, align 4\n\ + \ %loop_cond3 = icmp slt i32 %i_val, 3\n\ + \ br i1 %loop_cond3, label %loop_body, label %loop_after\n\n\ + loop_body: ; preds = %loop_cond\n\ + \ %elem_ptr = getelementptr i32, i32* %list, i32 %i_val\n\ + \ %elem_val = load i32, i32* %elem_ptr, align 4\n\ + \ store i32 %elem_val, i32* %y, align 4\n\ + \ %y4 = load i32, i32* %y, align 4\n\ + \ %call_printf = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 \ + x i8], [4 x i8]* @int_fmt, i32 0, i32 0), i32 %y4)\n\ + \ %i_plus_1 = add i32 %i_val, 1\n\ + \ store i32 %i_plus_1, i32* %i, align 4\n\ + \ br label %loop_cond\n\n\ + loop_after: ; preds = %loop_cond\n\ + \ ret void\n\ + }\n\n\ + declare i32 @printf(i8*, ...)\n" + in + assert_equal expected actual ~printer) + + ; ("test2" + >:: fun _ -> + let sast = + get_sast "fun function() -> () { let a := [1,2,3]; for y := a { print(y); } }" + in + let mdl = Irgen.translate sast in + let actual = L.string_of_llmodule mdl in + let expected = + "" + in + assert_equal expected actual ~printer) + ] +;; + +let _ = run_test_tt_main tests diff --git a/test/type_checker/dune b/test/type_checker/dune index ae38a74..6f60b7b 100644 --- a/test/type_checker/dune +++ b/test/type_checker/dune @@ -17,3 +17,7 @@ (test (name test_function) (libraries fly_lib ounit2)) + + (test + (name test_for) + (libraries fly_lib ounit2)) diff --git a/test/type_checker/test_for.ml b/test/type_checker/test_for.ml new file mode 100644 index 0000000..1d1a16d --- /dev/null +++ b/test/type_checker/test_for.ml @@ -0,0 +1,25 @@ +open OUnit2 +open Fly_lib + +let check_program source_code = + let lexbuf = Lexing.from_string source_code in + let ast = Parser.program_rule Scanner.tokenize lexbuf in + try + let _ = Semant.check ast.body in + "" + with + | Failure msg -> msg +;; + +let tests = + "testing_loops" + >::: [ ("correct_list" + >:: fun _ -> + let actual = check_program "fun function() -> int { for y := [10, 20, 30] { return 0; } return 1; }" in + let expected = "" in + assert_equal expected actual ~printer:(fun s -> "\"" ^ s ^ "\"")) + + ] +;; + +let _ = run_test_tt_main tests diff --git a/test/type_checker/test_loops.ml b/test/type_checker/test_loops.ml index 20cc70d..6da0da6 100644 --- a/test/type_checker/test_loops.ml +++ b/test/type_checker/test_loops.ml @@ -18,11 +18,6 @@ let tests = let actual = check_program "let x := [1,2,3,4,5]; for i := x {let a := i;}" in let expected = "" in assert_equal expected actual ~printer:(fun s -> "\"" ^ s ^ "\"")) - ; ("correct_tuple" - >:: fun _ -> - let actual = check_program "let x := (1,2,3,4,5); for i := x {let a := i;}" in - let expected = "" in - assert_equal expected actual ~printer:(fun s -> "\"" ^ s ^ "\"")) ; ("loop_var_shadowing" >:: fun _ -> let actual = check_program "let i := 0; let x := [1,2,3,4,5]; for i := x {}" in