diff --git a/bin/fly.ml b/bin/fly.ml index f49989c..80516e3 100644 --- a/bin/fly.ml +++ b/bin/fly.ml @@ -53,7 +53,8 @@ let read_and_print_ir channel = let lexbuf = Lexing.from_channel channel in let ast = Fly_lib.Parser.program_rule Fly_lib.Scanner.tokenize lexbuf in let sast = Semant.check ast.body in - let md = Irgen.translate sast in + let unbound_sast = Unbind.unbind sast in + let md = Irgen.translate unbound_sast in print_endline (L.string_of_llmodule md) ;; @@ -61,7 +62,8 @@ let read_and_compile channel = let lexbuf = Lexing.from_channel channel in let ast = Fly_lib.Parser.program_rule Fly_lib.Scanner.tokenize lexbuf in let sast = Semant.check ast.body in - let md = Irgen.translate sast in + let unbound_sast = Unbind.unbind sast in + let md = Irgen.translate unbound_sast in (* Inititalize triples that llvm needs to create a target *) Llvm_all_backends.initialize (); diff --git a/lib/irgen.ml b/lib/irgen.ml index ad47617..b20b464 100644 --- a/lib/irgen.ml +++ b/lib/irgen.ml @@ -131,13 +131,6 @@ let define_udt_type name members = L.struct_set_body struct_ll_type field_ll_types_array false ;; -(* let define_udt_type name members = *) -(* let field_types = List.map (fun (_, t) -> ltype_of_typ t) members in *) -(* let struct_type = L.struct_type context (Array.of_list field_types) in *) -(* Hashtbl.add udt_structs name struct_type; *) -(* Hashtbl.add udt_field_indices name (List.mapi (fun i (name, _) -> name, i) members) *) -(* ;; *) - let build_udt_access typ var_name field_name vars builder = let struct_ptr = lookup_value vars var_name in let var_typ = @@ -217,6 +210,7 @@ let rec build_expr expr (vars : variable StringMap.t) var_types the_module build | A.Postincr | A.Postdecr -> ll_original_val | A.Preincr | A.Predecr -> ll_new_val | _ -> failwith "Could apply incr/decr to variable") + (* In build_expr, replacing the SFunctionCall case *) | SFunctionCall (func_name, actual_s_exprs_list) -> if func_name = print_func_name then prelude_print (func_name, actual_s_exprs_list) vars var_types the_module builder @@ -229,46 +223,81 @@ let rec build_expr expr (vars : variable StringMap.t) var_types the_module build match L.lookup_function func_name the_module with | Some f -> f | None -> - failwith - (Printf.sprintf - "IRgen build_expr: Function '%s' not found in LLVM module. (Was it \ - declared in Pass 1?)" - func_name) + raise (Failure (Printf.sprintf "IRgen: Function '%s' not declared" func_name)) in - let exp_formals, exp_ret_typ = + let sast_expected_formals, sast_expected_return_typ = try StringMap.find func_name !function_signatures with | Not_found -> - failwith - (Printf.sprintf - "IRgen build_expr: SAST signature for function '%s' not found." - func_name) + raise + (Failure (Printf.sprintf "IRgen: SAST signature for '%s' not found" func_name)) in - if List.length exp_formals <> List.length actual_s_exprs_list + if List.length sast_expected_formals <> List.length actual_s_exprs_list then failwith (Printf.sprintf - "IRgen build_expr: Arity mismatch for function '%s'. Expected %d args, got \ - %d." + "IRgen: Arity mismatch for %s. Expected %d, Got %d" func_name - (List.length exp_formals) + (List.length sast_expected_formals) (List.length actual_s_exprs_list)); - let evaluated_ll_args_list : L.llvalue list = - List.map2 - (fun (sexpr_arg : sexpr) (_formal_name, resolved_type) -> - let act_type = fst sexpr_arg in - assert_types act_type resolved_type; - build_expr sexpr_arg vars var_types the_module builder) + let callee_llvm_func_type = L.element_type (L.type_of callee_lfunc) in + let callee_llvm_param_types_array = L.param_types callee_llvm_func_type in + + let final_ll_args_for_call : L.llvalue list = + List.mapi + (fun i (actual_arg_sexpr : sexpr) -> + let _sast_formal_name, sast_formal_typ = List.nth sast_expected_formals i in + let sast_actual_arg_typ, _ = actual_arg_sexpr in + + assert_types sast_actual_arg_typ sast_formal_typ; + + let ll_val_from_arg_expr = + build_expr actual_arg_sexpr vars var_types the_module builder + in + let expected_llvm_param_type_in_callee = callee_llvm_param_types_array.(i) in + let type_of_ll_val_from_arg_expr = L.type_of ll_val_from_arg_expr in + + match sast_actual_arg_typ with + | RUserType _ -> + if + L.classify_type type_of_ll_val_from_arg_expr = L.TypeKind.Pointer + && L.classify_type expected_llvm_param_type_in_callee = L.TypeKind.Struct + then + L.build_load + ll_val_from_arg_expr + ("load_arg_" ^ _sast_formal_name) + builder + else if type_of_ll_val_from_arg_expr <> expected_llvm_param_type_in_callee + then + failwith + (Printf.sprintf + "IRgen SFunctionCall: UDT Arg LLVM type mismatch for %s. Got %s, \ + Callee expects %s" + _sast_formal_name + (L.string_of_lltype type_of_ll_val_from_arg_expr) + (L.string_of_lltype expected_llvm_param_type_in_callee)) + else ll_val_from_arg_expr + | _ -> + if type_of_ll_val_from_arg_expr <> expected_llvm_param_type_in_callee + then + failwith + (Printf.sprintf + "IRgen SFunctionCall: Non-UDT Arg LLVM type mismatch for %s. Got \ + %s, Callee expects %s" + _sast_formal_name + (L.string_of_lltype type_of_ll_val_from_arg_expr) + (L.string_of_lltype expected_llvm_param_type_in_callee)); + ll_val_from_arg_expr) actual_s_exprs_list - exp_formals in - let evaluated_ll_args_array : L.llvalue array = - Array.of_list evaluated_ll_args_list + let final_ll_args_array = Array.of_list final_ll_args_for_call in + let sast_expr_node_return_typ = fst expr in + assert_types sast_expr_node_return_typ sast_expected_return_typ; + + let call_result_name = + if sast_expected_return_typ = RUnit then "" else func_name ^ "_result" in - let ret_typ = fst expr in - assert_types ret_typ exp_ret_typ; - let call_result_name = if exp_ret_typ = RUnit then "" else func_name ^ "_result" in - L.build_call callee_lfunc evaluated_ll_args_array call_result_name builder) + L.build_call callee_lfunc final_ll_args_array call_result_name builder) | SEnumAccess (enum_name, variant_name) -> let key = extract_id_from_sexpr enum_name ^ "::" ^ variant_name in let vbl = diff --git a/lib/semant.ml b/lib/semant.ml index 30ef5b6..d92c04a 100644 --- a/lib/semant.ml +++ b/lib/semant.ml @@ -71,7 +71,7 @@ and var_dec_helper var_name t envs = ] in if List.exists (fun x -> x) env_checks - then raise (Failure (var_name ^ "already exists")) + then raise (Failure (var_name ^ " already exists")) else StringMap.add var_name t envs.var_env and func_def_helper func_name args rtyp envs = @@ -353,6 +353,7 @@ and check_expr expr envs special_blocks = | Some _ -> let func_sig = find_func (fst udt_func) envs.func_env in let _, def_arg_types = List.split func_sig.args in + let def_arg_types = List.tl def_arg_types in let sexpr_list = List.map (fun e -> check_expr e envs special_blocks) (snd udt_func) in @@ -366,7 +367,7 @@ and check_expr expr envs special_blocks = | None -> raise (Failure - (fst udt_func ^ "is not a method bound to " + (fst udt_func ^ " is not a method bound to " ^ string_of_resolved_type udt_typ)))) | UDTStaticAccess (udt_name, (func_name, args)) -> let udt_typ = find_udt udt_name envs.udt_env in @@ -539,7 +540,7 @@ and check_block block envs special_blocks func_ret_type = let updated_checked_func_body = update_func_body checked_func_body func_name is_unit rtyp envs in - ( updated_envs2 + ( updated_envs1 , updated_special_blocks , rtyp , SFunctionDefinition (rt, func_name, args, updated_checked_func_body) ) @@ -552,26 +553,26 @@ and check_block block envs special_blocks func_ret_type = let new_func_env = func_def_helper func_name args rt envs in (* add function name to environment *) let updated_envs1 = { envs with func_env = new_func_env } in + let new_udt_env = + add_bound_func_def func_name (string_of_resolved_type bound_type) envs + in + let updated_envs2 = { updated_envs1 with udt_env = new_udt_env } in let new_var_env = add_func_args args updated_envs1 in (* add function arguments to environment *) - let updated_envs2 = { updated_envs1 with var_env = new_var_env } in + let updated_envs3 = { updated_envs2 with var_env = new_var_env } in let updated_special_blocks = if rtyp = Unit then StringSet.add "ReturnUnit" special_blocks else StringSet.add "ReturnVal" special_blocks in let checked_func_body = - check_block_list func_body updated_envs2 updated_special_blocks rtyp + check_block_list func_body updated_envs3 updated_special_blocks rtyp in let is_unit = rtyp = Unit in let updated_checked_func_body = update_func_body checked_func_body func_name is_unit rtyp envs in - let new_udt_env = - add_bound_func_def func_name (string_of_resolved_type bound_type) envs - in - let updated_envs3 = { updated_envs2 with udt_env = new_udt_env } in - ( updated_envs3 + ( updated_envs2 , updated_special_blocks , rtyp , SBoundFunctionDefinition (rt, func_name, args, updated_checked_func_body, bound_type) diff --git a/lib/unbind.ml b/lib/unbind.ml new file mode 100644 index 0000000..8c1ea14 --- /dev/null +++ b/lib/unbind.ml @@ -0,0 +1,173 @@ +open Sast +module StringSet = Set.Make (String) + +let rec add_func_args args vars = + match args with + | [] -> vars + | curr :: rest -> + let new_vars = StringSet.add (fst curr) vars in + add_func_args rest new_vars + +and fresh_var_name base_name vars counter = + let candidate = base_name ^ string_of_int counter in + if StringSet.mem candidate vars + then fresh_var_name base_name vars (counter + 1) + else candidate + +and unbind_sexpr se replace_self new_var_name = + let t', se' = se in + match se' with + | SLiteral _ + | SBoolLit _ + | SFloatLit _ + | SCharLit _ + | SStringLit _ + | SUnit + | SEnumAccess _ -> t', se' + | SUnopSideEffect _ | SMatch _ | SWildcard -> failwith "Dropping" + | SBinop (se1, binop, se2) -> + let unbound_se1 = unbind_sexpr se1 replace_self new_var_name in + let unbound_se2 = unbind_sexpr se2 replace_self new_var_name in + t', SBinop (unbound_se1, binop, unbound_se2) + | SUnop (se, unop) -> + let unbound_se = unbind_sexpr se replace_self new_var_name in + t', SUnop (unbound_se, unop) + | SFunctionCall (func_name, func_args) -> + let unbound_args = + List.map (fun arg -> unbind_sexpr arg replace_self new_var_name) func_args + in + t', SFunctionCall (func_name, unbound_args) + | SId id -> if replace_self && id = "self" then t', SId new_var_name else t', SId id + | STuple se_list -> + let unbound_se_list = + List.map (fun elem -> unbind_sexpr elem replace_self new_var_name) se_list + in + t', STuple unbound_se_list + | SUDTInstance (udt_name, udt_members) -> + let unbound_udt_members = + List.map + (fun member -> fst member, unbind_sexpr (snd member) replace_self new_var_name) + udt_members + in + t', SUDTInstance (udt_name, unbound_udt_members) + | SUDTAccess (udt_se, udt_member) -> + let unbound_udt_se = unbind_sexpr udt_se replace_self new_var_name in + (match udt_member with + | SUDTVariable x -> t', SUDTAccess (unbound_udt_se, SUDTVariable x) + | SUDTFunction (udt_func_name, udt_func_params) -> + let unbound_params1 = + List.map + (fun param -> unbind_sexpr param replace_self new_var_name) + udt_func_params + in + let unbound_params2 = unbound_params1 @ [ unbound_udt_se ] in + t', SFunctionCall (udt_func_name, unbound_params2)) + | SUDTStaticAccess (udt_name, udt_static_func) -> + t', SUDTStaticAccess (udt_name, udt_static_func) + | SIndex (indexed_se, index_val) -> + let unbound_indexed_se = unbind_sexpr indexed_se replace_self new_var_name in + let unbound_index_val = unbind_sexpr index_val replace_self new_var_name in + t', SIndex (unbound_indexed_se, unbound_index_val) + | SList se_list -> + let unbound_se_list = + List.map (fun elem -> unbind_sexpr elem replace_self new_var_name) se_list + in + t', SList unbound_se_list + | STypeCast (new_rt, target_se) -> + let unbound_target_se = unbind_sexpr target_se replace_self new_var_name in + t', STypeCast (new_rt, unbound_target_se) + +and unbind_block sblk variables replace_self new_var_name = + match sblk with + | SMutDeclTyped _ | SAssign _ -> failwith "Dropping" + | SDeclTyped (var_name, rt, se) -> + let updated_variables = StringSet.add var_name variables in + updated_variables, SDeclTyped (var_name, rt, unbind_sexpr se replace_self new_var_name) + | SFunctionDefinition (rt, func_name, func_args, body) -> + let updated_variables1 = add_func_args func_args variables in + let updated_variables2, unbound_body = + unbind_block_list body updated_variables1 replace_self new_var_name + in + updated_variables2, SFunctionDefinition (rt, func_name, func_args, unbound_body) + | SBoundFunctionDefinition (rt, func_name, func_args, body, _) -> + let updated_variables1 = add_func_args func_args variables in + let updated_variables2, _ = + unbind_block_list body updated_variables1 replace_self new_var_name + in + let var_name = fresh_var_name "tmp" updated_variables2 0 in + let updated_func_args = + List.map + (fun (arg_name, arg_type) -> + if arg_name = "self" then var_name, arg_type else arg_name, arg_type) + func_args + in + let updated_variables3 = add_func_args updated_func_args variables in + let _, unbound_body = unbind_block_list body updated_variables3 true var_name in + variables, SFunctionDefinition (rt, func_name, updated_func_args, unbound_body) + | SEnumDeclaration (enum_name, enum_variants) -> + variables, SEnumDeclaration (enum_name, enum_variants) + | SUDTDef (udt_name, udt_members) -> variables, SUDTDef (udt_name, udt_members) + | SIfEnd (cond, if_body) -> + let unbound_cond = unbind_sexpr cond replace_self new_var_name in + variables, SIfEnd (unbound_cond, if_body) + | SIfNonEnd (cond, if_body, other) -> + let unbound_cond = unbind_sexpr cond replace_self new_var_name in + let updated_variables1, unbound_if_body = + unbind_block_list if_body variables replace_self new_var_name + in + let updated_variables2, unbound_other = + unbind_block other updated_variables1 replace_self new_var_name + in + updated_variables2, SIfNonEnd (unbound_cond, unbound_if_body, unbound_other) + | SElifNonEnd (cond, elif_body, other) -> + let unbound_cond = unbind_sexpr cond replace_self new_var_name in + let updated_variables1, unbound_elif_body = + unbind_block_list elif_body variables replace_self new_var_name + in + let updated_variables2, unbound_other = + unbind_block other updated_variables1 replace_self new_var_name + in + updated_variables2, SElifNonEnd (unbound_cond, unbound_elif_body, unbound_other) + | SElifEnd (cond, elif_body) -> + let unbound_cond = unbind_sexpr cond replace_self new_var_name in + let updated_variables, unbound_elif_body = + unbind_block_list elif_body variables replace_self new_var_name + in + updated_variables, SElifEnd (unbound_cond, unbound_elif_body) + | SElseEnd else_body -> + let updated_variables, unbound_else_body = + unbind_block_list else_body variables replace_self new_var_name + in + updated_variables, SElseEnd unbound_else_body + | SWhile (se, while_body) -> + let unbound_se = unbind_sexpr se replace_self new_var_name in + variables, SWhile (unbound_se, while_body) + | SFor (iterator, iterable_se, for_body) -> + let unbound_iterable_se = unbind_sexpr iterable_se replace_self new_var_name in + let updated_variables, unbound_for_body = + unbind_block_list for_body variables replace_self new_var_name + in + updated_variables, SFor (iterator, unbound_iterable_se, unbound_for_body) + | SBreak -> variables, SBreak + | SContinue -> variables, SContinue + | SReturnUnit -> variables, SReturnUnit + | SReturnVal se -> variables, SReturnVal (unbind_sexpr se replace_self new_var_name) + | SExpr se -> variables, SExpr (unbind_sexpr se replace_self new_var_name) + +and unbind_block_list sblk_list variables replace_self var_name = + match sblk_list with + | [] -> variables, [] + | sblk :: rest -> + let updated_variables, unbound_sblk = + unbind_block sblk variables replace_self var_name + in + ( updated_variables + , unbound_sblk :: snd (unbind_block_list rest updated_variables replace_self var_name) + ) +;; + +let unbind sblk_list = + let variables = StringSet.empty in + let _, unbound_sblk_list = unbind_block_list sblk_list variables false "" in + unbound_sblk_list +;; diff --git a/test/ir/test_cond.ml b/test/ir/test_cond.ml index e4ea4e7..952d457 100644 --- a/test/ir/test_cond.ml +++ b/test/ir/test_cond.ml @@ -5,9 +5,10 @@ module L = Llvm let get_sast input = try let lexbuf = Lexing.from_string input in - let ast = Parser.program_rule Scanner.tokenize lexbuf in - let sast = Semant.check ast.body in - sast + let ast = Fly_lib.Parser.program_rule Fly_lib.Scanner.tokenize lexbuf in + let sast = Fly_lib.Semant.check ast.body in + let unbound_sast = Fly_lib.Unbind.unbind sast in + unbound_sast with | err -> failwith diff --git a/test/ir/test_enums.ml b/test/ir/test_enums.ml index 3698ea4..c38ed9c 100644 --- a/test/ir/test_enums.ml +++ b/test/ir/test_enums.ml @@ -7,7 +7,8 @@ let get_sast input = let lexbuf = Lexing.from_string input in let ast = Fly_lib.Parser.program_rule Fly_lib.Scanner.tokenize lexbuf in let sast = Fly_lib.Semant.check ast.body in - sast + let unbound_sast = Fly_lib.Unbind.unbind sast in + unbound_sast with | err -> raise diff --git a/test/ir/test_ir.ml b/test/ir/test_ir.ml index 90f6d91..e473cfc 100644 --- a/test/ir/test_ir.ml +++ b/test/ir/test_ir.ml @@ -5,10 +5,10 @@ module L = Llvm let get_sast input = try let lexbuf = Lexing.from_string input in - let ast = Parser.program_rule Scanner.tokenize lexbuf in - (* Make sure these module names are correct *) - let sast = Semant.check ast.body in - sast + let ast = Fly_lib.Parser.program_rule Fly_lib.Scanner.tokenize lexbuf in + let sast = Fly_lib.Semant.check ast.body in + let unbound_sast = Fly_lib.Unbind.unbind sast in + unbound_sast with | err -> raise diff --git a/test/ir/test_list.ml b/test/ir/test_list.ml index 2046486..428a6f2 100644 --- a/test/ir/test_list.ml +++ b/test/ir/test_list.ml @@ -7,7 +7,8 @@ let get_sast input = let lexbuf = Lexing.from_string input in let ast = Fly_lib.Parser.program_rule Fly_lib.Scanner.tokenize lexbuf in let sast = Fly_lib.Semant.check ast.body in - sast + let unbound_sast = Fly_lib.Unbind.unbind sast in + unbound_sast with | err -> raise diff --git a/test/ir/test_prelude.ml b/test/ir/test_prelude.ml index be64c57..b1664d0 100644 --- a/test/ir/test_prelude.ml +++ b/test/ir/test_prelude.ml @@ -9,7 +9,8 @@ let get_sast input = let lexbuf = Lexing.from_string input in let ast = Fly_lib.Parser.program_rule Fly_lib.Scanner.tokenize lexbuf in let sast = Fly_lib.Semant.check ast.body in - sast + let unbound_sast = Fly_lib.Unbind.unbind sast in + unbound_sast with | err -> raise diff --git a/test/ir/test_string.ml b/test/ir/test_string.ml index eb2643c..956478c 100644 --- a/test/ir/test_string.ml +++ b/test/ir/test_string.ml @@ -7,7 +7,8 @@ let get_sast input = let lexbuf = Lexing.from_string input in let ast = Fly_lib.Parser.program_rule Fly_lib.Scanner.tokenize lexbuf in let sast = Fly_lib.Semant.check ast.body in - sast + let unbound_sast = Fly_lib.Unbind.unbind sast in + unbound_sast with | err -> raise diff --git a/test/ir/test_udt.ml b/test/ir/test_udt.ml index 9ee43fe..d2a2ea5 100644 --- a/test/ir/test_udt.ml +++ b/test/ir/test_udt.ml @@ -5,9 +5,10 @@ module L = Llvm let get_sast input = try let lexbuf = Lexing.from_string input in - let ast = Parser.program_rule Scanner.tokenize lexbuf in - let sast = Semant.check ast.body in - sast + let ast = Fly_lib.Parser.program_rule Fly_lib.Scanner.tokenize lexbuf in + let sast = Fly_lib.Semant.check ast.body in + let unbound_sast = Fly_lib.Unbind.unbind sast in + unbound_sast with | err -> raise diff --git a/test/ir/test_vars.ml b/test/ir/test_vars.ml index 7754ad6..093278e 100644 --- a/test/ir/test_vars.ml +++ b/test/ir/test_vars.ml @@ -7,7 +7,8 @@ let get_sast input = let lexbuf = Lexing.from_string input in let ast = Fly_lib.Parser.program_rule Fly_lib.Scanner.tokenize lexbuf in let sast = Fly_lib.Semant.check ast.body in - sast + let unbound_sast = Fly_lib.Unbind.unbind sast in + unbound_sast with | err -> raise diff --git a/test/type_checker/test_bind.ml b/test/type_checker/test_bind.ml new file mode 100644 index 0000000..f7ed71d --- /dev/null +++ b/test/type_checker/test_bind.ml @@ -0,0 +1,64 @@ +open OUnit2 +open Fly_lib + +let check_program source_code = + let lexbuf = Lexing.from_string source_code in + let ast = Parser.program_rule Scanner.tokenize lexbuf in + try + let _ = Semant.check ast.body in + "" + with + | Failure msg -> msg +;; + +let tests = + "testing_return" + >::: [ ("static_bind" + >:: fun _ -> + let actual = + check_program + "type Person {name: string, age: int} bind new(name: string, age: \ + int) -> Person { return Person {name: name, age: age}; }" + in + let expected = "" in + assert_equal expected actual ~printer:(fun s -> "\"" ^ s ^ "\"")) + ; ("non_static_bind" + >:: fun _ -> + let actual = + check_program + "type Person {name: string, age: int} bind get_age(self) -> int { \ + return self.age; }" + in + let expected = "" in + assert_equal expected actual ~printer:(fun s -> "\"" ^ s ^ "\"")) + ; ("primitive_bind" + >:: fun _ -> + let actual = + check_program + "bind add(self, other: int) -> int { return self + other; }" + in + let expected = "" in + assert_equal expected actual ~printer:(fun s -> "\"" ^ s ^ "\"")) + ; ("non_primitive_bind" + >:: fun _ -> + let actual = + check_program + "bind search>(self, target: value) -> bool \n\ + \ { \n\ + for v := self { \n\ + \ if (v == target) {\n\ + \ return true;\n\ + \ }\n\ + \ }\n\ + \ return false;\n\ + \ }\n\n\ + \ let x := [1,2,3,4,5];\n\ + \ let res := x.search(3);\n\ + \ " + in + let expected = "" in + assert_equal expected actual ~printer:(fun s -> "\"" ^ s ^ "\"")) + ] +;; + +let _ = run_test_tt_main tests