From dad48e2ead478f831abd65844e4cd3504b828133 Mon Sep 17 00:00:00 2001 From: AlexZhu2 <52172864+AlexZhu2@users.noreply.github.com> Date: Fri, 16 May 2025 14:07:20 -0400 Subject: [PATCH 1/2] list + udt/enum --- lib/ast.ml | 4 ++-- lib/irgen.ml | 49 +++++++++++++++++++++++++++++++++------------ lib/parser.mly | 6 +++--- lib/sast.ml | 4 ++-- lib/semant.ml | 35 ++++++++++++++++++++++++-------- lib/utils/prints.ml | 11 ++++++---- 6 files changed, 77 insertions(+), 32 deletions(-) diff --git a/lib/ast.ml b/lib/ast.ml index a683674..180cabc 100644 --- a/lib/ast.ml +++ b/lib/ast.ml @@ -64,9 +64,9 @@ type expr = | UnopSideEffect of string * op (* this is for postincr, postdecr, preincr, postdecr *) | FunctionCall of func | UDTInstance of string * kv_list - | UDTAccess of string * udt_access + | UDTAccess of expr * udt_access | UDTStaticAccess of string * func - | EnumAccess of string * string + | EnumAccess of expr * string | Index of expr * expr | List of expr list | Match of expr * (pattern * expr) list diff --git a/lib/irgen.ml b/lib/irgen.ml index 8898146..c1c61ca 100644 --- a/lib/irgen.ml +++ b/lib/irgen.ml @@ -46,6 +46,14 @@ let int_format_str builder = L.build_global_stringptr "%d\n" "int_fmt" builder let str_format_str builder = L.build_global_stringptr "%s\n" "str_fmt" builder let float_format_str builder = L.build_global_stringptr "%f\n" "float_fmt" builder +let rec extract_id_from_sexpr (sexpr : sexpr) : string = + match snd sexpr with + | SId id -> id + | SUDTAccess (base, _) -> extract_id_from_sexpr base + | SIndex (base, _) -> extract_id_from_sexpr base + | _ -> raise (Failure "Expected an identifier or access expression") +;; + (* Creates a binding to the llvm libc "printf" function *) let l_printf : L.lltype = L.var_arg_function_type l_int [| l_str |] let print_func the_module : L.llvalue = L.declare_function "printf" l_printf the_module @@ -123,14 +131,14 @@ let rec build_expr expr (vars : variable StringMap.t) var_types the_module build | SFloatLit f -> L.const_float l_float f | SId var -> let vbl = lookup vars var in - (* strings are pointers, they should not be load-ed like other variables - Now, local strings already exist in a variable in the function scope, - and build_load is okay here as we're loading from the variable, not - the raw pointer. - Therefore, we have this special case for Global strings - *) if vbl.v_scope == Global && vbl.v_type == RString then vbl.v_value + else if + vbl.v_type + |> function + | RUserType _ -> true + | _ -> false + then vbl.v_value (* For structs, return the pointer, not a load *) else L.build_load vbl.v_value var builder | SUnop (e, op) -> let llval = build_expr e vars var_types the_module builder in @@ -189,14 +197,14 @@ let rec build_expr expr (vars : variable StringMap.t) var_types the_module build then prelude_input func vars var_types the_module builder else raise (Failure "function calls not implemented") | SEnumAccess (enum_name, variant_name) -> - let key = enum_name ^ "::" ^ variant_name in + let key = extract_id_from_sexpr enum_name ^ "::" ^ variant_name in let vbl = try StringMap.find key vars with | Not_found -> failwith (Printf.sprintf "IRgen: Enum variant %s::%s not found in vars map during SEnumAccess" - enum_name + (extract_id_from_sexpr enum_name) variant_name) in if L.type_of vbl.v_value <> L.pointer_type l_int @@ -294,17 +302,23 @@ let rec build_expr expr (vars : variable StringMap.t) var_types the_module build fields; instance | SUDTAccess (id, SUDTVariable field) -> - let struct_ptr = lookup_value vars id in - let id_typ = lookup_type var_types id in + let struct_ptr = build_expr id vars var_types the_module builder in + let id_typ = fst id in let type_name = match id_typ with | RUserType n -> n - | _ -> raise (Failure ("Expected user type for variable: " ^ id)) + | _ -> + raise (Failure ("Expected user type for variable: " ^ extract_id_from_sexpr id)) in + let field_indices = Hashtbl.find udt_field_indices type_name in let idx = List.assoc field field_indices in - let field_ptr = L.build_struct_gep struct_ptr idx (id ^ "_" ^ field) builder in - let field_val = L.build_load field_ptr (field ^ "_val") builder in + let field_ptr = + L.build_struct_gep struct_ptr idx (extract_id_from_sexpr id ^ "_" ^ field) builder + in + let field_val = + L.build_load field_ptr (extract_id_from_sexpr id ^ "_" ^ field ^ "_val") builder + in field_val | SList list -> let typ = fst (List.hd list) in @@ -320,6 +334,15 @@ let rec build_expr expr (vars : variable StringMap.t) var_types the_module build list; llist | SStringLit s -> L.build_global_stringptr s "str" builder + | SIndex (list_expr, index_expr) -> + let list_val = build_expr list_expr vars var_types the_module builder in + let index_val = build_expr index_expr vars var_types the_module builder in + let elem_ptr = L.build_gep list_val [| index_val |] "elem_ptr" builder in + (match fst expr with + | RInt | RFloat | RBool | RChar | REnumType _ -> + L.build_load elem_ptr "elem_val" builder + | RUserType _ -> elem_ptr + | _ -> failwith "Unsupported list element type for indexing") | e -> raise (Failure (Printf.sprintf "expr not implemented: %s" (Utils.string_of_sexpr e))) diff --git a/lib/parser.mly b/lib/parser.mly index d9c5529..1ed532a 100644 --- a/lib/parser.mly +++ b/lib/parser.mly @@ -66,10 +66,10 @@ side_effect_expr: | DECR ID { UnopSideEffect($2, Predecr) } access_expr: - | ID DOT udt_access { UDTAccess($1, $3) } + | expr DOT udt_access { UDTAccess($1, $3) } | ID DCOLON func_call { UDTStaticAccess($1, $3) } - | SELF DOT udt_access { UDTAccess ("self", $3) } - | ID DCOLON ID { EnumAccess($1, $3) } + | SELF DOT udt_access { UDTAccess (Id("self"), $3) } + | ID DCOLON ID { EnumAccess(Id($1), $3) } | expr LBRACKET expr RBRACKET { Index($1, $3) } udt_access: diff --git a/lib/sast.ml b/lib/sast.ml index 484e36c..b1f0e27 100644 --- a/lib/sast.ml +++ b/lib/sast.ml @@ -30,9 +30,9 @@ and sx = | SUnopSideEffect of string * op (* this is for postincr, postdecr, preincr, postdecr *) | SFunctionCall of sfunc | SUDTInstance of string * skv_list - | SUDTAccess of string * sudt_access + | SUDTAccess of sexpr * sudt_access | SUDTStaticAccess of string * sfunc - | SEnumAccess of string * string + | SEnumAccess of sexpr * string | SIndex of sexpr * sexpr | SList of sexpr list | SMatch of sexpr * (pattern * sexpr) list diff --git a/lib/semant.ml b/lib/semant.ml index 74a8163..5e115c0 100644 --- a/lib/semant.ml +++ b/lib/semant.ml @@ -234,8 +234,15 @@ and check_expr expr envs special_blocks = | StringLit s -> RString, SStringLit s | Unit -> RUnit, SUnit | Id id_name -> - let t = find_var id_name envs.var_env in - t, SId id_name + if StringMap.mem id_name envs.var_env + then ( + let t = find_var id_name envs.var_env in + t, SId id_name) + else if StringMap.mem id_name envs.enum_env + then REnumType id_name, SId id_name + else if StringMap.mem id_name envs.udt_env + then RUserType id_name, SId id_name + else raise (Failure ("Undeclared variable or type " ^ id_name)) | Tuple expr_list -> let sexpr_list = List.map (fun e -> check_expr e envs special_blocks) expr_list in let typs, _ = List.split sexpr_list in @@ -276,13 +283,14 @@ and check_expr expr envs special_blocks = let t = find_func func_name envs.func_env in (* t is return type of this function call *) t.rtyp, SFunctionCall (func_name, sfunc_args) - | UDTAccess (id_name, udt_accessed_member) -> - let udt_typ = find_var id_name envs.var_env in + | UDTAccess (udt_expr, udt_accessed_member) -> + let udt_typ, checked_udt_expr = check_expr udt_expr envs special_blocks in let udt_def = find_udt (string_of_resolved_type udt_typ) envs.udt_env in (match udt_accessed_member with | UDTVariable udt_var -> (match List.assoc_opt udt_var udt_def.members with - | Some accessed_type -> accessed_type, SUDTAccess (id_name, SUDTVariable udt_var) + | Some accessed_type -> + accessed_type, SUDTAccess ((udt_typ, checked_udt_expr), SUDTVariable udt_var) | None -> raise (Failure (udt_var ^ "is not in " ^ string_of_resolved_type udt_typ))) | UDTFunction udt_func -> @@ -295,7 +303,10 @@ and check_expr expr envs special_blocks = in let arg_types, _ = List.split sexpr_list in if arg_types = def_arg_types - then func_sig.rtyp, SUDTAccess (id_name, SUDTFunction (fst udt_func, sexpr_list)) + then + ( func_sig.rtyp + , SUDTAccess + ((udt_typ, checked_udt_expr), SUDTFunction (fst udt_func, sexpr_list)) ) else raise (Failure "Incorrect types passed to this method") | None -> raise @@ -314,7 +325,15 @@ and check_expr expr envs special_blocks = then func_sig.rtyp, SUDTStaticAccess (udt_name, (func_name, sexpr_list)) else raise (Failure "Incorrect types passed to this method") | None -> raise (Failure (func_name ^ "is not a method bound to " ^ udt_name))) - | EnumAccess (enum_name, variant) -> + | EnumAccess (enum_expr, variant) -> + let t_enum, checked_enum_expr = check_expr enum_expr envs special_blocks in + let enum_name = + match enum_expr with + | Id name -> name + | _ -> + Printf.eprintf "EnumAccess base: %s\n" (Utils.string_of_expr enum_expr); + failwith "EnumAccess base must be an identifier" + in let enum_variants = try StringMap.find enum_name envs.enum_env with | Not_found -> raise (Failure ("Undefined enum " ^ enum_name)) @@ -328,7 +347,7 @@ and check_expr expr envs special_blocks = in if not variant_exists then raise (Failure ("Undefined variant " ^ variant ^ " in enum " ^ enum_name)) - else REnumType enum_name, SEnumAccess (enum_name, variant) + else REnumType enum_name, SEnumAccess ((t_enum, checked_enum_expr), variant) | Index (e1, e2) -> let t1, e1' = check_expr e1 envs special_blocks in let t2, e2' = check_expr e2 envs special_blocks in diff --git a/lib/utils/prints.ml b/lib/utils/prints.ml index 54ec52d..7a36801 100644 --- a/lib/utils/prints.ml +++ b/lib/utils/prints.ml @@ -165,7 +165,8 @@ let rec string_of_expr = function func_name ^ "(" ^ String.concat ", " (List.map string_of_expr func_args) ^ ")" | UDTInstance (udt_name, udt_members) -> udt_name ^ "{" ^ string_of_udt_instance udt_members ^ "}" - | UDTAccess (udt_name, udt_access) -> udt_name ^ "." ^ string_of_udt_access udt_access + | UDTAccess (udt_expr, udt_access) -> + string_of_expr udt_expr ^ "." ^ string_of_udt_access udt_access | UDTStaticAccess (udt_name, udt_function) -> udt_name ^ "::" ^ fst udt_function ^ "(" ^ String.concat ", " (List.map string_of_expr (snd udt_function)) @@ -175,7 +176,7 @@ let rec string_of_expr = function | Match (e1, case_list) -> "match (" ^ string_of_expr e1 ^ ") {\n" ^ string_of_case_list case_list ^ "}" | Wildcard -> "_" - | EnumAccess (enum_name, enum_variant) -> enum_name ^ "::" ^ enum_variant + | EnumAccess (enum_expr, enum_variant) -> string_of_expr enum_expr ^ "::" ^ enum_variant | TypeCast (type_name, e) -> string_of_expr e ^ " as " ^ string_of_type type_name and string_of_pattern = function @@ -323,7 +324,8 @@ let rec string_of_sexpr = function ^ ")" | SUDTInstance (udt_name, udt_members) -> udt_name ^ "{" ^ string_of_sudt_instance udt_members ^ "}" - | SUDTAccess (udt_name, udt_access) -> udt_name ^ "." ^ string_of_udt_access udt_access + | SUDTAccess (udt_expr, udt_access) -> + string_of_sexpr (snd udt_expr) ^ "." ^ string_of_udt_access udt_access | SUDTStaticAccess (udt_name, udt_function) -> udt_name ^ "::" ^ fst udt_function ^ "(" ^ String.concat ", " (List.map string_of_sexpr (List.map snd (snd udt_function))) @@ -333,7 +335,8 @@ let rec string_of_sexpr = function | SMatch (e1, case_list) -> "match (" ^ string_of_sexpr (snd e1) ^ ") {\n" ^ string_of_scase_list case_list ^ "}" | SWildcard -> "_" - | SEnumAccess (enum_name, enum_variant) -> enum_name ^ "::" ^ enum_variant + | SEnumAccess (enum_expr, enum_variant) -> + string_of_sexpr (snd enum_expr) ^ "::" ^ enum_variant | STypeCast (type_name, e) -> string_of_sexpr (snd e) ^ " as " ^ string_of_resolved_type type_name From 4f0eb2843f300247c3adaa1ba910dc326ce04fbb Mon Sep 17 00:00:00 2001 From: AlexZhu2 <52172864+AlexZhu2@users.noreply.github.com> Date: Fri, 16 May 2025 15:00:29 -0400 Subject: [PATCH 2/2] added tests --- test/ir/dune | 2 +- test/ir/test_list.ml | 67 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/test/ir/dune b/test/ir/dune index ee9b955..22d382e 100644 --- a/test/ir/dune +++ b/test/ir/dune @@ -28,4 +28,4 @@ (test (name test_list) - (libraries fly_lib ounit2)) + (libraries fly_lib ounit2 str)) diff --git a/test/ir/test_list.ml b/test/ir/test_list.ml index 2773e80..1c8b4fe 100644 --- a/test/ir/test_list.ml +++ b/test/ir/test_list.ml @@ -123,6 +123,73 @@ let tests = in (* _write_to_file actual "actual.out"; *) assert_equal expected actual ~printer) + ; ("list_of_structs_index_field" + >:: fun _ -> + let sast = + get_sast + "type Point { x:int, y:int } fun main() -> int { let p1 := Point{x:1, \ + y:2}; let p2 := Point{x:3, y:4}; let arr := [p1, p2]; return arr[1].y; }" + in + let mdl = Irgen.translate sast in + let actual = L.string_of_llmodule mdl in + let expected = + "; ModuleID = 'Fly'\n\ + source_filename = \"Fly\"\n\n\ + define i32 @main() {\n\ + entry:\n\ + \ %Point_inst = alloca { i32, i32 }, align 8\n\ + \ %Point_x = getelementptr inbounds { i32, i32 }, { i32, i32 }* \ + %Point_inst, i32 0, i32 0\n\ + \ store i32 1, i32* %Point_x, align 4\n\ + \ %Point_y = getelementptr inbounds { i32, i32 }, { i32, i32 }* \ + %Point_inst, i32 0, i32 1\n\ + \ store i32 2, i32* %Point_y, align 4\n\ + \ %Point_inst1 = alloca { i32, i32 }, align 8\n\ + \ %Point_x2 = getelementptr inbounds { i32, i32 }, { i32, i32 }* \ + %Point_inst1, i32 0, i32 0\n\ + \ store i32 3, i32* %Point_x2, align 4\n\ + \ %Point_y3 = getelementptr inbounds { i32, i32 }, { i32, i32 }* \ + %Point_inst1, i32 0, i32 1\n\ + \ store i32 4, i32* %Point_y3, align 4\n\ + \ %list = alloca { i32, i32 }, i32 2, align 8\n\ + \ %index = getelementptr inbounds { i32, i32 }, { i32, i32 }* %list, i32 0\n\ + \ store { i32, i32 }* %Point_inst, { i32, i32 }* %index, align 8\n\ + \ %index4 = getelementptr inbounds { i32, i32 }, { i32, i32 }* %list, i32 1\n\ + \ store { i32, i32 }* %Point_inst1, { i32, i32 }* %index4, align 8\n\ + \ %arr = alloca { i32, i32 }*, align 8\n\ + \ store { i32, i32 }* %list, { i32, i32 }** %arr, align 8\n\ + \ %arr5 = load { i32, i32 }*, { i32, i32 }** %arr, align 8\n\ + \ %elem_ptr = getelementptr { i32, i32 }, { i32, i32 }* %arr5, i32 1\n\ + \ %arr_y = getelementptr inbounds { i32, i32 }, { i32, i32 }* %elem_ptr, \ + i32 0, i32 1\n\ + \ %arr_y_val = load i32, i32* %arr_y, align 4\n\ + \ ret i32 %arr_y_val\n\ + }\n" + in + assert_equal expected actual ~printer) + ; ("list_of_enums_index" + >:: fun _ -> + let sast = + get_sast + "enum Color { Red, Green, Blue } fun main() -> Color { let arr := \ + [Color::Red, Color::Blue]; return arr[1]; }" + in + let mdl = Irgen.translate sast in + let actual = L.string_of_llmodule mdl in + assert_bool + "alloca for enum list" + (try + ignore (Str.search_forward (Str.regexp "alloca i32, i32 2") actual 0); + true + with + | Not_found -> false); + assert_bool + "return enum value" + (try + ignore (Str.search_forward (Str.regexp "ret i32") actual 0); + true + with + | Not_found -> false)) ] ;;