diff --git a/lib/ast.ml b/lib/ast.ml index a683674..180cabc 100644 --- a/lib/ast.ml +++ b/lib/ast.ml @@ -64,9 +64,9 @@ type expr = | UnopSideEffect of string * op (* this is for postincr, postdecr, preincr, postdecr *) | FunctionCall of func | UDTInstance of string * kv_list - | UDTAccess of string * udt_access + | UDTAccess of expr * udt_access | UDTStaticAccess of string * func - | EnumAccess of string * string + | EnumAccess of expr * string | Index of expr * expr | List of expr list | Match of expr * (pattern * expr) list diff --git a/lib/irgen.ml b/lib/irgen.ml index 8f273cd..d02f06d 100644 --- a/lib/irgen.ml +++ b/lib/irgen.ml @@ -46,6 +46,14 @@ let int_format_str builder = L.build_global_stringptr "%d\n" "int_fmt" builder let str_format_str builder = L.build_global_stringptr "%s\n" "str_fmt" builder let float_format_str builder = L.build_global_stringptr "%f\n" "float_fmt" builder +let rec extract_id_from_sexpr (sexpr : sexpr) : string = + match snd sexpr with + | SId id -> id + | SUDTAccess (base, _) -> extract_id_from_sexpr base + | SIndex (base, _) -> extract_id_from_sexpr base + | _ -> raise (Failure "Expected an identifier or access expression") +;; + (* Creates a binding to the llvm libc "printf" function *) let l_printf : L.lltype = L.var_arg_function_type l_int [| l_str |] let print_func the_module : L.llvalue = L.declare_function "printf" l_printf the_module @@ -123,14 +131,14 @@ let rec build_expr expr (vars : variable StringMap.t) var_types the_module build | SFloatLit f -> L.const_float l_float f | SId var -> let vbl = lookup vars var in - (* strings are pointers, they should not be load-ed like other variables - Now, local strings already exist in a variable in the function scope, - and build_load is okay here as we're loading from the variable, not - the raw pointer. - Therefore, we have this special case for Global strings - *) if vbl.v_scope == Global && vbl.v_type == RString then vbl.v_value + else if + vbl.v_type + |> function + | RUserType _ -> true + | _ -> false + then vbl.v_value (* For structs, return the pointer, not a load *) else L.build_load vbl.v_value var builder | SUnop (e, op) -> let llval = build_expr e vars var_types the_module builder in @@ -189,14 +197,14 @@ let rec build_expr expr (vars : variable StringMap.t) var_types the_module build then prelude_input func vars var_types the_module builder else raise (Failure "function calls not implemented") | SEnumAccess (enum_name, variant_name) -> - let key = enum_name ^ "::" ^ variant_name in + let key = extract_id_from_sexpr enum_name ^ "::" ^ variant_name in let vbl = try StringMap.find key vars with | Not_found -> failwith (Printf.sprintf "IRgen: Enum variant %s::%s not found in vars map during SEnumAccess" - enum_name + (extract_id_from_sexpr enum_name) variant_name) in if L.type_of vbl.v_value <> L.pointer_type l_int @@ -294,17 +302,23 @@ let rec build_expr expr (vars : variable StringMap.t) var_types the_module build fields; instance | SUDTAccess (id, SUDTVariable field) -> - let struct_ptr = lookup_value vars id in - let id_typ = lookup_type var_types id in + let struct_ptr = build_expr id vars var_types the_module builder in + let id_typ = fst id in let type_name = match id_typ with | RUserType n -> n - | _ -> raise (Failure ("Expected user type for variable: " ^ id)) + | _ -> + raise (Failure ("Expected user type for variable: " ^ extract_id_from_sexpr id)) in + let field_indices = Hashtbl.find udt_field_indices type_name in let idx = List.assoc field field_indices in - let field_ptr = L.build_struct_gep struct_ptr idx (id ^ "_" ^ field) builder in - let field_val = L.build_load field_ptr (field ^ "_val") builder in + let field_ptr = + L.build_struct_gep struct_ptr idx (extract_id_from_sexpr id ^ "_" ^ field) builder + in + let field_val = + L.build_load field_ptr (extract_id_from_sexpr id ^ "_" ^ field ^ "_val") builder + in field_val | SList list -> let typ = fst (List.hd list) in @@ -324,12 +338,11 @@ let rec build_expr expr (vars : variable StringMap.t) var_types the_module build let list_val = build_expr list_expr vars var_types the_module builder in let index_val = build_expr index_expr vars var_types the_module builder in let elem_ptr = L.build_gep list_val [| index_val |] "elem_ptr" builder in - (match fst expr with | RInt | RFloat | RString | RBool | REnumType _ -> L.build_load elem_ptr "elem_val" builder | RUserType _ -> elem_ptr - | _ -> failwith "Unsupported list type for indexing") + | _ -> failwith "Unsupported list element type for indexing") | e -> raise (Failure (Printf.sprintf "expr not implemented: %s" (Utils.string_of_sexpr e))) diff --git a/lib/parser.mly b/lib/parser.mly index d9c5529..1ed532a 100644 --- a/lib/parser.mly +++ b/lib/parser.mly @@ -66,10 +66,10 @@ side_effect_expr: | DECR ID { UnopSideEffect($2, Predecr) } access_expr: - | ID DOT udt_access { UDTAccess($1, $3) } + | expr DOT udt_access { UDTAccess($1, $3) } | ID DCOLON func_call { UDTStaticAccess($1, $3) } - | SELF DOT udt_access { UDTAccess ("self", $3) } - | ID DCOLON ID { EnumAccess($1, $3) } + | SELF DOT udt_access { UDTAccess (Id("self"), $3) } + | ID DCOLON ID { EnumAccess(Id($1), $3) } | expr LBRACKET expr RBRACKET { Index($1, $3) } udt_access: diff --git a/lib/sast.ml b/lib/sast.ml index 484e36c..b1f0e27 100644 --- a/lib/sast.ml +++ b/lib/sast.ml @@ -30,9 +30,9 @@ and sx = | SUnopSideEffect of string * op (* this is for postincr, postdecr, preincr, postdecr *) | SFunctionCall of sfunc | SUDTInstance of string * skv_list - | SUDTAccess of string * sudt_access + | SUDTAccess of sexpr * sudt_access | SUDTStaticAccess of string * sfunc - | SEnumAccess of string * string + | SEnumAccess of sexpr * string | SIndex of sexpr * sexpr | SList of sexpr list | SMatch of sexpr * (pattern * sexpr) list diff --git a/lib/semant.ml b/lib/semant.ml index 7a397a2..30ef5b6 100644 --- a/lib/semant.ml +++ b/lib/semant.ml @@ -257,9 +257,9 @@ and update_func_body checked_func_body func_name is_unit rtyp envs = let enum_variant = List.hd (StringMap.find name envs.enum_env) in (match enum_variant with | EnumVariantDefault variant_name -> - REnumType name, SEnumAccess (name, variant_name) + REnumType name, SEnumAccess ((REnumType name, SId name), variant_name) | EnumVariantExplicit (variant_name, _) -> - REnumType name, SEnumAccess (name, variant_name)) + REnumType name, SEnumAccess ((REnumType name, SId name), variant_name)) | Sast.RUserType name -> let udt_info = StringMap.find name envs.udt_env in let udt_members = udt_info.members in @@ -289,8 +289,15 @@ and check_expr expr envs special_blocks = | StringLit s -> RString, SStringLit s | Unit -> RUnit, SUnit | Id id_name -> - let t = find_var id_name envs.var_env in - t, SId id_name + if StringMap.mem id_name envs.var_env + then ( + let t = find_var id_name envs.var_env in + t, SId id_name) + else if StringMap.mem id_name envs.enum_env + then REnumType id_name, SId id_name + else if StringMap.mem id_name envs.udt_env + then RUserType id_name, SId id_name + else raise (Failure ("Undeclared variable or type " ^ id_name)) | Tuple expr_list -> let sexpr_list = List.map (fun e -> check_expr e envs special_blocks) expr_list in let typs, _ = List.split sexpr_list in @@ -331,13 +338,14 @@ and check_expr expr envs special_blocks = let t = find_func func_name envs.func_env in (* t is return type of this function call *) t.rtyp, SFunctionCall (func_name, sfunc_args) - | UDTAccess (id_name, udt_accessed_member) -> - let udt_typ = find_var id_name envs.var_env in + | UDTAccess (udt_expr, udt_accessed_member) -> + let udt_typ, checked_udt_expr = check_expr udt_expr envs special_blocks in let udt_def = find_udt (string_of_resolved_type udt_typ) envs.udt_env in (match udt_accessed_member with | UDTVariable udt_var -> (match List.assoc_opt udt_var udt_def.members with - | Some accessed_type -> accessed_type, SUDTAccess (id_name, SUDTVariable udt_var) + | Some accessed_type -> + accessed_type, SUDTAccess ((udt_typ, checked_udt_expr), SUDTVariable udt_var) | None -> raise (Failure (udt_var ^ "is not in " ^ string_of_resolved_type udt_typ))) | UDTFunction udt_func -> @@ -350,7 +358,10 @@ and check_expr expr envs special_blocks = in let arg_types, _ = List.split sexpr_list in if arg_types = def_arg_types - then func_sig.rtyp, SUDTAccess (id_name, SUDTFunction (fst udt_func, sexpr_list)) + then + ( func_sig.rtyp + , SUDTAccess + ((udt_typ, checked_udt_expr), SUDTFunction (fst udt_func, sexpr_list)) ) else raise (Failure "Incorrect types passed to this method") | None -> raise @@ -369,7 +380,15 @@ and check_expr expr envs special_blocks = then func_sig.rtyp, SUDTStaticAccess (udt_name, (func_name, sexpr_list)) else raise (Failure "Incorrect types passed to this method") | None -> raise (Failure (func_name ^ "is not a method bound to " ^ udt_name))) - | EnumAccess (enum_name, variant) -> + | EnumAccess (enum_expr, variant) -> + let t_enum, checked_enum_expr = check_expr enum_expr envs special_blocks in + let enum_name = + match enum_expr with + | Id name -> name + | _ -> + Printf.eprintf "EnumAccess base: %s\n" (Utils.string_of_expr enum_expr); + failwith "EnumAccess base must be an identifier" + in let enum_variants = try StringMap.find enum_name envs.enum_env with | Not_found -> raise (Failure ("Undefined enum " ^ enum_name)) @@ -383,7 +402,7 @@ and check_expr expr envs special_blocks = in if not variant_exists then raise (Failure ("Undefined variant " ^ variant ^ " in enum " ^ enum_name)) - else REnumType enum_name, SEnumAccess (enum_name, variant) + else REnumType enum_name, SEnumAccess ((t_enum, checked_enum_expr), variant) | Index (e1, e2) -> let t1, e1' = check_expr e1 envs special_blocks in let t2, e2' = check_expr e2 envs special_blocks in diff --git a/lib/utils/prints.ml b/lib/utils/prints.ml index 54ec52d..7a36801 100644 --- a/lib/utils/prints.ml +++ b/lib/utils/prints.ml @@ -165,7 +165,8 @@ let rec string_of_expr = function func_name ^ "(" ^ String.concat ", " (List.map string_of_expr func_args) ^ ")" | UDTInstance (udt_name, udt_members) -> udt_name ^ "{" ^ string_of_udt_instance udt_members ^ "}" - | UDTAccess (udt_name, udt_access) -> udt_name ^ "." ^ string_of_udt_access udt_access + | UDTAccess (udt_expr, udt_access) -> + string_of_expr udt_expr ^ "." ^ string_of_udt_access udt_access | UDTStaticAccess (udt_name, udt_function) -> udt_name ^ "::" ^ fst udt_function ^ "(" ^ String.concat ", " (List.map string_of_expr (snd udt_function)) @@ -175,7 +176,7 @@ let rec string_of_expr = function | Match (e1, case_list) -> "match (" ^ string_of_expr e1 ^ ") {\n" ^ string_of_case_list case_list ^ "}" | Wildcard -> "_" - | EnumAccess (enum_name, enum_variant) -> enum_name ^ "::" ^ enum_variant + | EnumAccess (enum_expr, enum_variant) -> string_of_expr enum_expr ^ "::" ^ enum_variant | TypeCast (type_name, e) -> string_of_expr e ^ " as " ^ string_of_type type_name and string_of_pattern = function @@ -323,7 +324,8 @@ let rec string_of_sexpr = function ^ ")" | SUDTInstance (udt_name, udt_members) -> udt_name ^ "{" ^ string_of_sudt_instance udt_members ^ "}" - | SUDTAccess (udt_name, udt_access) -> udt_name ^ "." ^ string_of_udt_access udt_access + | SUDTAccess (udt_expr, udt_access) -> + string_of_sexpr (snd udt_expr) ^ "." ^ string_of_udt_access udt_access | SUDTStaticAccess (udt_name, udt_function) -> udt_name ^ "::" ^ fst udt_function ^ "(" ^ String.concat ", " (List.map string_of_sexpr (List.map snd (snd udt_function))) @@ -333,7 +335,8 @@ let rec string_of_sexpr = function | SMatch (e1, case_list) -> "match (" ^ string_of_sexpr (snd e1) ^ ") {\n" ^ string_of_scase_list case_list ^ "}" | SWildcard -> "_" - | SEnumAccess (enum_name, enum_variant) -> enum_name ^ "::" ^ enum_variant + | SEnumAccess (enum_expr, enum_variant) -> + string_of_sexpr (snd enum_expr) ^ "::" ^ enum_variant | STypeCast (type_name, e) -> string_of_sexpr (snd e) ^ " as " ^ string_of_resolved_type type_name diff --git a/test/ir/dune b/test/ir/dune index ee9b955..22d382e 100644 --- a/test/ir/dune +++ b/test/ir/dune @@ -28,4 +28,4 @@ (test (name test_list) - (libraries fly_lib ounit2)) + (libraries fly_lib ounit2 str)) diff --git a/test/ir/test_list.ml b/test/ir/test_list.ml index 4e260fa..1c8b4fe 100644 --- a/test/ir/test_list.ml +++ b/test/ir/test_list.ml @@ -123,38 +123,73 @@ let tests = in (* _write_to_file actual "actual.out"; *) assert_equal expected actual ~printer) - ; ("local_string_list_index" + ; ("list_of_structs_index_field" >:: fun _ -> let sast = get_sast - "fun function() -> string {let a := [\"hello\", \"world\"]; return a[0];}" + "type Point { x:int, y:int } fun main() -> int { let p1 := Point{x:1, \ + y:2}; let p2 := Point{x:3, y:4}; let arr := [p1, p2]; return arr[1].y; }" in let mdl = Irgen.translate sast in let actual = L.string_of_llmodule mdl in let expected = "; ModuleID = 'Fly'\n\ source_filename = \"Fly\"\n\n\ - @str = private unnamed_addr constant [6 x i8] c\"hello\\00\", align 1\n\ - @str.1 = private unnamed_addr constant [6 x i8] c\"world\\00\", align 1\n\n\ - define i8* @function() {\n\ + define i32 @main() {\n\ entry:\n\ - \ %list = alloca i8*, i32 2, align 8\n\ - \ %index = getelementptr inbounds i8*, i8** %list, i32 0\n\ - \ store i8* getelementptr inbounds ([6 x i8], [6 x i8]* @str, i32 0, i32 \ - 0), i8** %index, align 8\n\ - \ %index1 = getelementptr inbounds i8*, i8** %list, i32 1\n\ - \ store i8* getelementptr inbounds ([6 x i8], [6 x i8]* @str.1, i32 0, i32 \ - 0), i8** %index1, align 8\n\ - \ %a = alloca i8**, align 8\n\ - \ store i8** %list, i8*** %a, align 8\n\ - \ %a2 = load i8**, i8*** %a, align 8\n\ - \ %elem_ptr = getelementptr i8*, i8** %a2, i32 0\n\ - \ %elem_val = load i8*, i8** %elem_ptr, align 8\n\ - \ ret i8* %elem_val\n\ + \ %Point_inst = alloca { i32, i32 }, align 8\n\ + \ %Point_x = getelementptr inbounds { i32, i32 }, { i32, i32 }* \ + %Point_inst, i32 0, i32 0\n\ + \ store i32 1, i32* %Point_x, align 4\n\ + \ %Point_y = getelementptr inbounds { i32, i32 }, { i32, i32 }* \ + %Point_inst, i32 0, i32 1\n\ + \ store i32 2, i32* %Point_y, align 4\n\ + \ %Point_inst1 = alloca { i32, i32 }, align 8\n\ + \ %Point_x2 = getelementptr inbounds { i32, i32 }, { i32, i32 }* \ + %Point_inst1, i32 0, i32 0\n\ + \ store i32 3, i32* %Point_x2, align 4\n\ + \ %Point_y3 = getelementptr inbounds { i32, i32 }, { i32, i32 }* \ + %Point_inst1, i32 0, i32 1\n\ + \ store i32 4, i32* %Point_y3, align 4\n\ + \ %list = alloca { i32, i32 }, i32 2, align 8\n\ + \ %index = getelementptr inbounds { i32, i32 }, { i32, i32 }* %list, i32 0\n\ + \ store { i32, i32 }* %Point_inst, { i32, i32 }* %index, align 8\n\ + \ %index4 = getelementptr inbounds { i32, i32 }, { i32, i32 }* %list, i32 1\n\ + \ store { i32, i32 }* %Point_inst1, { i32, i32 }* %index4, align 8\n\ + \ %arr = alloca { i32, i32 }*, align 8\n\ + \ store { i32, i32 }* %list, { i32, i32 }** %arr, align 8\n\ + \ %arr5 = load { i32, i32 }*, { i32, i32 }** %arr, align 8\n\ + \ %elem_ptr = getelementptr { i32, i32 }, { i32, i32 }* %arr5, i32 1\n\ + \ %arr_y = getelementptr inbounds { i32, i32 }, { i32, i32 }* %elem_ptr, \ + i32 0, i32 1\n\ + \ %arr_y_val = load i32, i32* %arr_y, align 4\n\ + \ ret i32 %arr_y_val\n\ }\n" in - (* _write_to_file actual "actual.out"; *) assert_equal expected actual ~printer) + ; ("list_of_enums_index" + >:: fun _ -> + let sast = + get_sast + "enum Color { Red, Green, Blue } fun main() -> Color { let arr := \ + [Color::Red, Color::Blue]; return arr[1]; }" + in + let mdl = Irgen.translate sast in + let actual = L.string_of_llmodule mdl in + assert_bool + "alloca for enum list" + (try + ignore (Str.search_forward (Str.regexp "alloca i32, i32 2") actual 0); + true + with + | Not_found -> false); + assert_bool + "return enum value" + (try + ignore (Str.search_forward (Str.regexp "ret i32") actual 0); + true + with + | Not_found -> false)) ] ;;