From 4de7e7a5fbedb79a76c2079ee028e5c8e054f3fd Mon Sep 17 00:00:00 2001 From: Jordan Mecom Date: Sun, 28 Dec 2025 18:20:36 -0800 Subject: [PATCH 1/2] Add generic types with monomorphization --- capc/src/ast.rs | 8 + capc/src/codegen/emit.rs | 9 + capc/src/hir.rs | 5 + capc/src/parser.rs | 96 +- capc/src/typeck/check.rs | 400 ++++++- capc/src/typeck/collect.rs | 84 +- capc/src/typeck/lower.rs | 47 +- capc/src/typeck/mod.rs | 176 ++- capc/src/typeck/monomorphize.rs | 1050 +++++++++++++++++ capc/tests/parser.rs | 6 + capc/tests/run.rs | 14 + .../parser__snapshot_basic_module.snap | 3 + .../parser__snapshot_doc_comments.snap | 4 + .../parser__snapshot_generics_basic.snap | 669 +++++++++++ .../parser__snapshot_struct_and_match.snap | 6 + .../parser__snapshot_struct_literal.snap | 3 + stdlib/sys/buffer.cap | 4 +- tests/programs/generics_basic.cap | 20 + 18 files changed, 2523 insertions(+), 81 deletions(-) create mode 100644 capc/src/typeck/monomorphize.rs create mode 100644 capc/tests/snapshots/parser__snapshot_generics_basic.snap create mode 100644 tests/programs/generics_basic.cap diff --git a/capc/src/ast.rs b/capc/src/ast.rs index 36dd25f..05bd6f1 100644 --- a/capc/src/ast.rs +++ b/capc/src/ast.rs @@ -59,6 +59,7 @@ pub enum Item { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Function { pub name: Ident, + pub type_params: Vec, pub params: Vec, pub ret: Type, pub body: Block, @@ -70,6 +71,7 @@ pub struct Function { #[derive(Debug, Clone, PartialEq, Eq)] pub struct ExternFunction { pub name: Ident, + pub type_params: Vec, pub params: Vec, pub ret: Type, pub is_pub: bool, @@ -79,6 +81,7 @@ pub struct ExternFunction { #[derive(Debug, Clone, PartialEq, Eq)] pub struct ImplBlock { + pub type_params: Vec, pub target: Type, pub methods: Vec, pub doc: Option, @@ -94,6 +97,7 @@ pub struct Param { #[derive(Debug, Clone, PartialEq, Eq)] pub struct StructDecl { pub name: Ident, + pub type_params: Vec, pub fields: Vec, pub is_pub: bool, pub is_opaque: bool, @@ -107,6 +111,7 @@ 
pub struct StructDecl { #[derive(Debug, Clone, PartialEq, Eq)] pub struct EnumDecl { pub name: Ident, + pub type_params: Vec, pub variants: Vec, pub is_pub: bool, pub doc: Option, @@ -215,6 +220,7 @@ pub enum Expr { #[derive(Debug, Clone, PartialEq, Eq)] pub struct StructLiteralExpr { pub path: Path, + pub type_args: Vec, pub fields: Vec, pub span: Span, } @@ -270,6 +276,7 @@ impl fmt::Display for Path { #[derive(Debug, Clone, PartialEq, Eq)] pub struct CallExpr { pub callee: Box, + pub type_args: Vec, pub args: Vec, pub span: Span, } @@ -285,6 +292,7 @@ pub struct FieldAccessExpr { pub struct MethodCallExpr { pub receiver: Box, pub method: Ident, + pub type_args: Vec, pub args: Vec, pub span: Span, } diff --git a/capc/src/codegen/emit.rs b/capc/src/codegen/emit.rs index ac2ca8e..8b41f9d 100644 --- a/capc/src/codegen/emit.rs +++ b/capc/src/codegen/emit.rs @@ -1216,6 +1216,9 @@ fn store_value_by_ty( module, ) } + Ty::Param(_) => Err(CodegenError::Unsupported( + "generic type parameters must be monomorphized before codegen".to_string(), + )), Ty::Path(name, args) => { if name == "Result" && args.len() == 2 { let ValueRepr::Result { tag, ok, err } = value else { @@ -1368,6 +1371,9 @@ fn load_value_by_ty( module, ) } + Ty::Param(_) => Err(CodegenError::Unsupported( + "generic type parameters must be monomorphized before codegen".to_string(), + )), Ty::Path(name, args) => { if name == "Result" && args.len() == 2 { let AbiType::Result(ok_abi, err_abi) = &ty.abi else { @@ -2212,6 +2218,9 @@ fn zero_value_for_ty( }; zero_value_for_ty(builder, &inner_ty, ptr_ty, _struct_layouts) } + Ty::Param(_) => Err(CodegenError::Unsupported( + "generic type parameters must be monomorphized before codegen".to_string(), + )), Ty::Path(name, args) => { if name == "Result" && args.len() == 2 { let AbiType::Result(ok_abi, err_abi) = &ty.abi else { diff --git a/capc/src/hir.rs b/capc/src/hir.rs index 586e070..98759e4 100644 --- a/capc/src/hir.rs +++ b/capc/src/hir.rs @@ -46,6 +46,7 @@ pub 
struct HirProgram { #[derive(Debug, Clone)] pub struct HirFunction { pub name: String, + pub type_params: Vec, pub params: Vec, pub ret_ty: HirType, pub body: HirBlock, @@ -54,6 +55,7 @@ pub struct HirFunction { #[derive(Debug, Clone)] pub struct HirExternFunction { pub name: String, + pub type_params: Vec, pub params: Vec, pub ret_ty: HirType, } @@ -68,6 +70,7 @@ pub struct HirParam { #[derive(Debug, Clone)] pub struct HirStruct { pub name: String, + pub type_params: Vec, pub fields: Vec, pub is_opaque: bool, } @@ -82,6 +85,7 @@ pub struct HirField { #[derive(Debug, Clone)] pub struct HirEnum { pub name: String, + pub type_params: Vec, pub variants: Vec, } @@ -238,6 +242,7 @@ pub struct HirEnumVariantExpr { #[derive(Debug, Clone)] pub struct HirCall { pub callee: ResolvedCallee, + pub type_args: Vec, pub args: Vec, pub ret_ty: HirType, pub span: Span, diff --git a/capc/src/parser.rs b/capc/src/parser.rs index a54d90d..76c9514 100644 --- a/capc/src/parser.rs +++ b/capc/src/parser.rs @@ -231,6 +231,7 @@ impl Parser { fn parse_impl_block(&mut self, impl_doc: Option) -> Result { let start = self.expect(TokenKind::Impl)?.span.start; + let type_params = self.parse_type_params()?; let target = self.parse_type()?; self.expect(TokenKind::LBrace)?; let mut methods = Vec::new(); @@ -248,6 +249,7 @@ impl Parser { Ok(ImplBlock { target, methods, + type_params, doc: impl_doc, span: Span::new(start, end), }) @@ -257,6 +259,7 @@ impl Parser { let start = self.expect(TokenKind::Extern)?.span.start; self.expect(TokenKind::Fn)?; let name = self.expect_ident()?; + let type_params = self.parse_type_params()?; self.expect(TokenKind::LParen)?; let mut params = Vec::new(); if self.peek_kind() != Some(TokenKind::RParen) { @@ -285,6 +288,7 @@ impl Parser { .map_or(ret.span().end, |t| t.span.end); Ok(ExternFunction { name, + type_params, params, ret, is_pub, @@ -296,6 +300,7 @@ impl Parser { fn parse_function(&mut self, is_pub: bool, doc: Option) -> Result { let start = 
self.expect(TokenKind::Fn)?.span.start; let name = self.expect_ident()?; + let type_params = self.parse_type_params()?; self.expect(TokenKind::LParen)?; let mut params = Vec::new(); if self.peek_kind() != Some(TokenKind::RParen) { @@ -329,6 +334,7 @@ impl Parser { let span = Span::new(start, body.span.end); Ok(Function { name, + type_params, params, ret, body, @@ -349,6 +355,7 @@ impl Parser { ) -> Result { let start = self.expect(TokenKind::Struct)?.span.start; let name = self.expect_ident()?; + let type_params = self.parse_type_params()?; let mut fields = Vec::new(); let end = if self.peek_kind() == Some(TokenKind::LBrace) { self.bump(); @@ -382,6 +389,7 @@ impl Parser { }; Ok(StructDecl { name, + type_params, fields, is_pub, is_opaque, @@ -396,6 +404,7 @@ impl Parser { fn parse_enum(&mut self, is_pub: bool, doc: Option) -> Result { let start = self.expect(TokenKind::Enum)?.span.start; let name = self.expect_ident()?; + let type_params = self.parse_type_params()?; self.expect(TokenKind::LBrace)?; let mut variants = Vec::new(); if self.peek_kind() != Some(TokenKind::RBrace) { @@ -429,6 +438,7 @@ impl Parser { let end = self.expect(TokenKind::RBrace)?.span.end; Ok(EnumDecl { name, + type_params, variants, is_pub, doc, @@ -592,6 +602,11 @@ impl Parser { let start = lhs.span().start; self.bump(); // consume '.' let field = self.expect_ident()?; + let type_args = if self.peek_kind() == Some(TokenKind::LBracket) { + self.parse_type_args()? 
+ } else { + Vec::new() + }; // Check if this is a struct literal (followed by '{') if self.peek_kind() == Some(TokenKind::LBrace) { @@ -606,7 +621,7 @@ impl Parser { }; path.segments.push(field); path.span = Span::new(path.span.start, path.segments.last().unwrap().span.end); - lhs = self.parse_struct_literal(path)?; + lhs = self.parse_struct_literal(path, type_args)?; continue; } @@ -626,12 +641,19 @@ impl Parser { lhs = Expr::MethodCall(MethodCallExpr { receiver: Box::new(lhs), method: field, + type_args, args, span: Span::new(start, end), }); continue; } + if !type_args.is_empty() { + return Err(self.error_current( + "type arguments require a method call or struct literal" + .to_string(), + )); + } // Otherwise, it's a field access let span = Span::new(start, field.span.end); lhs = Expr::FieldAccess(FieldAccessExpr { @@ -642,7 +664,30 @@ impl Parser { continue; } TokenKind::LParen => { - lhs = self.finish_call(lhs)?; + lhs = self.finish_call(lhs, Vec::new())?; + continue; + } + TokenKind::LBracket => { + let type_args = self.parse_type_args()?; + if self.peek_kind() == Some(TokenKind::LBrace) { + let path = match lhs { + Expr::Path(p) => p, + Expr::FieldAccess(ref fa) => self.field_access_to_path(fa)?, + _ => { + return Err(self.error_current( + "expected path before struct literal".to_string(), + )) + } + }; + lhs = self.parse_struct_literal(path, type_args)?; + continue; + } + if self.peek_kind() != Some(TokenKind::LParen) { + return Err(self.error_current( + "type arguments require a call or struct literal".to_string(), + )); + } + lhs = self.finish_call(lhs, type_args)?; continue; } TokenKind::Question => { @@ -811,7 +856,7 @@ impl Parser { }; if self.peek_kind() == Some(TokenKind::LBrace) { - self.parse_struct_literal(path) + self.parse_struct_literal(path, Vec::new()) } else { Ok(Expr::Path(path)) } @@ -999,7 +1044,7 @@ impl Parser { Ok(Type::Path { path, args, span }) } - fn parse_struct_literal(&mut self, path: Path) -> Result { + fn 
parse_struct_literal(&mut self, path: Path, type_args: Vec) -> Result { let start = path.span.start; self.expect(TokenKind::LBrace)?; let mut fields = Vec::new(); @@ -1026,12 +1071,13 @@ impl Parser { let end = self.expect(TokenKind::RBrace)?.span.end; Ok(Expr::StructLiteral(StructLiteralExpr { path, + type_args, fields, span: Span::new(start, end), })) } - fn finish_call(&mut self, callee: Expr) -> Result { + fn finish_call(&mut self, callee: Expr, type_args: Vec) -> Result { let start = callee.span().start; self.expect(TokenKind::LParen)?; let mut args = Vec::new(); @@ -1046,11 +1092,49 @@ impl Parser { let end = self.expect(TokenKind::RParen)?.span.end; Ok(Expr::Call(CallExpr { callee: Box::new(callee), + type_args, args, span: Span::new(start, end), })) } + fn parse_type_params(&mut self) -> Result, ParseError> { + if self.peek_kind() != Some(TokenKind::LBracket) { + return Ok(Vec::new()); + } + self.bump(); + let mut params = Vec::new(); + if self.peek_kind() != Some(TokenKind::RBracket) { + loop { + let ident = self.expect_ident()?; + params.push(ident); + if self.maybe_consume(TokenKind::Comma).is_none() { + break; + } + } + } + self.expect(TokenKind::RBracket)?; + Ok(params) + } + + fn parse_type_args(&mut self) -> Result, ParseError> { + if self.peek_kind() != Some(TokenKind::LBracket) { + return Ok(Vec::new()); + } + self.bump(); + let mut args = Vec::new(); + if self.peek_kind() != Some(TokenKind::RBracket) { + loop { + args.push(self.parse_type()?); + if self.maybe_consume(TokenKind::Comma).is_none() { + break; + } + } + } + self.expect(TokenKind::RBracket)?; + Ok(args) + } + fn expect(&mut self, kind: TokenKind) -> Result { match self.peek_kind() { Some(k) if k == kind => Ok(self.bump().unwrap()), @@ -1131,7 +1215,7 @@ fn infix_binding_power(op: &BinaryOp) -> (u8, u8) { fn postfix_binding_power(kind: &TokenKind) -> Option { match kind { - TokenKind::Dot | TokenKind::LParen | TokenKind::Question => Some(13), + TokenKind::Dot | TokenKind::LParen | 
TokenKind::LBracket | TokenKind::Question => Some(13), _ => None, } } diff --git a/capc/src/typeck/check.rs b/capc/src/typeck/check.rs index fdaedac..ccd62e9 100644 --- a/capc/src/typeck/check.rs +++ b/capc/src/typeck/check.rs @@ -4,10 +4,10 @@ use crate::ast::*; use crate::error::TypeError; use super::{ - is_affine_type, is_numeric_type, is_orderable_type, lower_type, resolve_enum_variant, - resolve_method_target, resolve_path, resolve_type_name, type_contains_ref, type_kind, - BuiltinType, EnumInfo, FunctionSig, MoveState, Scopes, SpanExt, StdlibIndex, StructInfo, Ty, - TypeKind, TypeTable, UseMap, UseMode, + build_type_params, is_affine_type, is_numeric_type, is_orderable_type, lower_type, + resolve_enum_variant, resolve_method_target, resolve_path, resolve_type_name, type_contains_ref, + type_kind, validate_type_args, BuiltinType, EnumInfo, FunctionSig, MoveState, Scopes, + SpanExt, StdlibIndex, StructInfo, Ty, TypeKind, TypeTable, UseMap, UseMode, }; /// Optional recorder for expression types during checking. 
@@ -224,6 +224,7 @@ pub(super) fn check_function( module_name: &str, type_table: Option<&mut TypeTable>, ) -> Result<(), TypeError> { + let type_params = build_type_params(&func.type_params)?; let mut params_map = HashMap::new(); for param in &func.params { let Some(ty) = ¶m.ty else { @@ -250,13 +251,15 @@ pub(super) fn check_function( } } } - let ty = lower_type(ty, use_map, stdlib)?; + let ty = lower_type(ty, use_map, stdlib, &type_params)?; + validate_type_args(&ty, struct_map, enum_map, param.ty.as_ref().unwrap().span())?; params_map.insert(param.name.item.clone(), ty); } let mut scopes = Scopes::from_flat_map(params_map); let mut recorder = TypeRecorder::new(type_table); - let ret_ty = lower_type(&func.ret, use_map, stdlib)?; + let ret_ty = lower_type(&func.ret, use_map, stdlib, &type_params)?; + validate_type_args(&ret_ty, struct_map, enum_map, func.ret.span())?; if let Some(span) = type_contains_ref(&func.ret) { return Err(TypeError::new( "reference types cannot be returned".to_string(), @@ -276,6 +279,7 @@ pub(super) fn check_function( enum_map, stdlib, module_name, + &type_params, )?; } @@ -312,6 +316,7 @@ fn check_stmt( enum_map: &HashMap, stdlib: &StdlibIndex, module_name: &str, + type_params: &HashSet, ) -> Result<(), TypeError> { match stmt { Stmt::Let(let_stmt) => { @@ -342,6 +347,7 @@ fn check_stmt( stdlib, ret_ty, module_name, + type_params, )?; let final_ty = if let Some(annot) = &let_stmt.ty { if let Some(span) = type_contains_ref(annot) { @@ -363,7 +369,8 @@ fn check_stmt( } } } - let annot_ty = lower_type(annot, use_map, stdlib)?; + let annot_ty = lower_type(annot, use_map, stdlib, type_params)?; + validate_type_args( &annot_ty, struct_map, enum_map, annot.span())?; let matches_ref = if let Ty::Ref(inner) = &annot_ty { &expr_ty == inner.as_ref() || &expr_ty == &annot_ty } else { @@ -435,6 +442,7 @@ fn check_stmt( stdlib, ret_ty, module_name, + type_params, )?; if expr_ty != existing { return Err(TypeError::new( @@ -460,6 +468,7 @@ fn check_stmt( 
stdlib, ret_ty, module_name, + type_params, )? } else { Ty::Builtin(BuiltinType::Unit) @@ -485,6 +494,7 @@ fn check_stmt( stdlib, ret_ty, module_name, + type_params, )?; if cond_ty != Ty::Builtin(BuiltinType::Bool) { return Err(TypeError::new( @@ -504,6 +514,7 @@ fn check_stmt( enum_map, stdlib, module_name, + type_params, )?; let mut else_scopes = scopes.clone(); if let Some(block) = &if_stmt.else_block { @@ -518,6 +529,7 @@ fn check_stmt( enum_map, stdlib, module_name, + type_params, )?; } merge_branch_states( @@ -542,6 +554,7 @@ fn check_stmt( stdlib, ret_ty, module_name, + type_params, )?; if cond_ty != Ty::Builtin(BuiltinType::Bool) { return Err(TypeError::new( @@ -561,6 +574,7 @@ fn check_stmt( enum_map, stdlib, module_name, + type_params, )?; ensure_affine_states_match( scopes, @@ -584,6 +598,7 @@ fn check_stmt( stdlib, ret_ty, module_name, + type_params, )?; } else { check_expr( @@ -598,6 +613,7 @@ fn check_stmt( stdlib, ret_ty, module_name, + type_params, )?; } } @@ -618,6 +634,7 @@ fn check_block( enum_map: &HashMap, stdlib: &StdlibIndex, module_name: &str, + type_params: &HashSet, ) -> Result<(), TypeError> { scopes.push_scope(); for stmt in &block.stmts { @@ -632,6 +649,7 @@ fn check_block( enum_map, stdlib, module_name, + type_params, )?; } ensure_linear_scope_consumed(scopes, struct_map, enum_map, block.span)?; @@ -831,6 +849,7 @@ pub(super) fn check_expr( stdlib: &StdlibIndex, ret_ty: &Ty, module_name: &str, + type_params: &HashSet, ) -> Result { let ty = match expr { Expr::Literal(lit) => match &lit.value { @@ -894,6 +913,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, )?; return record_expr_type(recorder, expr, Ty::Builtin(BuiltinType::Unit)); } @@ -916,6 +936,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, )?; if let Ty::Path(ty_name, args) = ret_ty { if ty_name == "Result" && args.len() == 2 { @@ -963,17 +984,32 @@ pub(super) fn check_expr( call.span, )); } - if sig.params.len() != 
call.args.len() { + let explicit_type_args = lower_type_args( + &call.type_args, + use_map, + stdlib, + struct_map, + enum_map, + type_params, + )?; + let subs = build_call_substitution(sig, &explicit_type_args, HashMap::new(), call.span)?; + let instantiated_params: Vec = sig + .params + .iter() + .map(|ty| substitute_type(ty, &subs)) + .collect(); + let instantiated_ret = substitute_type(&sig.ret, &subs); + if instantiated_params.len() != call.args.len() { return Err(TypeError::new( format!( "argument count mismatch: expected {}, found {}", - sig.params.len(), + instantiated_params.len(), call.args.len() ), call.span, )); } - for (arg, expected) in call.args.iter().zip(&sig.params) { + for (arg, expected) in call.args.iter().zip(&instantiated_params) { let (expected_inner, use_mode) = if let Ty::Ref(inner) = expected { (inner.as_ref(), UseMode::Read) } else { @@ -991,6 +1027,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, )?; if !matches!(expected, Ty::Ref(_)) && matches!(arg_ty, Ty::Ref(_)) { return Err(TypeError::new( @@ -1010,7 +1047,7 @@ pub(super) fn check_expr( )); } } - Ok(sig.ret.clone()) + Ok(instantiated_ret) } Expr::MethodCall(method_call) => { fn get_leftmost_segment(expr: &Expr) -> Option<&str> { @@ -1056,17 +1093,33 @@ pub(super) fn check_expr( method_call.span, )); } - if sig.params.len() != method_call.args.len() { + let explicit_type_args = lower_type_args( + &method_call.type_args, + use_map, + stdlib, + struct_map, + enum_map, + type_params, + )?; + let subs = + build_call_substitution(sig, &explicit_type_args, HashMap::new(), method_call.span)?; + let instantiated_params: Vec = sig + .params + .iter() + .map(|ty| substitute_type(ty, &subs)) + .collect(); + let instantiated_ret = substitute_type(&sig.ret, &subs); + if instantiated_params.len() != method_call.args.len() { return Err(TypeError::new( format!( "argument count mismatch: expected {}, found {}", - sig.params.len(), + instantiated_params.len(), 
method_call.args.len() ), method_call.span, )); } - for (arg, expected) in method_call.args.iter().zip(&sig.params) { + for (arg, expected) in method_call.args.iter().zip(&instantiated_params) { let (expected_inner, use_mode) = if let Ty::Ref(inner) = expected { (inner.as_ref(), UseMode::Read) } else { @@ -1084,6 +1137,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, )?; if !matches!(expected, Ty::Ref(_)) && matches!(arg_ty, Ty::Ref(_)) { return Err(TypeError::new( @@ -1105,7 +1159,7 @@ pub(super) fn check_expr( )); } } - return record_expr_type(recorder, expr, sig.ret.clone()); + return record_expr_type(recorder, expr, instantiated_ret); } let receiver_ty = check_expr( @@ -1120,6 +1174,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, )?; if let Ty::Path(name, args) = &receiver_ty { if name == "Result" && args.len() == 2 { @@ -1145,6 +1200,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, )?; if &arg_ty != ok_ty { return Err(TypeError::new( @@ -1175,6 +1231,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, )?; if &arg_ty != err_ty { return Err(TypeError::new( @@ -1217,11 +1274,53 @@ pub(super) fn check_expr( method_call.span, )); } - if sig.params.len() != method_call.args.len() + 1 { + let mut inferred = HashMap::new(); + let expected_receiver = match &sig.params[0] { + Ty::Ref(inner) | Ty::Ptr(inner) => inner.as_ref(), + other => other, + }; + let actual_receiver = match &receiver_ty { + Ty::Ref(inner) | Ty::Ptr(inner) => inner.as_ref(), + other => other, + }; + let normalized_actual_receiver = match (expected_receiver, actual_receiver) { + (Ty::Path(expected_name, _), Ty::Path(actual_name, args)) + if !expected_name.contains('.') + && actual_name + .rsplit_once('.') + .map(|(_, t)| t == expected_name) + .unwrap_or(false) => + { + Ty::Path(expected_name.clone(), args.clone()) + } + _ => actual_receiver.clone(), + }; + match_type_params( + expected_receiver, + 
&normalized_actual_receiver, + &mut inferred, + method_call.receiver.span(), + )?; + let explicit_type_args = lower_type_args( + &method_call.type_args, + use_map, + stdlib, + struct_map, + enum_map, + type_params, + )?; + let subs = build_call_substitution(sig, &explicit_type_args, inferred, method_call.span)?; + let instantiated_params: Vec = sig + .params + .iter() + .map(|ty| substitute_type(ty, &subs)) + .collect(); + let instantiated_ret = substitute_type(&sig.ret, &subs); + if instantiated_params.len() != method_call.args.len() + 1 { return Err(TypeError::new( format!( "argument count mismatch: expected {}, found {}", - sig.params.len() - 1, + instantiated_params.len() - 1, method_call.args.len() ), method_call.span, @@ -1236,9 +1335,21 @@ pub(super) fn check_expr( let receiver_ref_unqualified = Ty::Ref(Box::new(receiver_unqualified.clone())); let receiver_ptr = Ty::Ptr(Box::new(receiver_base.clone())); let receiver_ptr_unqualified = Ty::Ptr(Box::new(receiver_unqualified.clone())); + let expected_qualified = match &instantiated_params[0] { + Ty::Path(name, args) if !name.contains('.') => { + Some(Ty::Path(format!("{method_module}.{name}"), args.clone())) + } + _ => None, + }; + let expected_ref_qualified = expected_qualified + .as_ref() + .map(|ty| Ty::Ref(Box::new(ty.clone()))); + let expected_ptr_qualified = expected_qualified + .as_ref() + .map(|ty| Ty::Ptr(Box::new(ty.clone()))); - let expects_ref = matches!(sig.params[0], Ty::Ref(_)); - let expects_ptr = matches!(sig.params[0], Ty::Ptr(_)); + let expects_ref = matches!(instantiated_params[0], Ty::Ref(_)); + let expects_ptr = matches!(instantiated_params[0], Ty::Ptr(_)); if matches!(receiver_ty, Ty::Ref(_)) && !expects_ref { return Err(TypeError::new( @@ -1253,22 +1364,25 @@ pub(super) fn check_expr( )); } - if sig.params[0] != receiver_ty - && sig.params[0] != receiver_unqualified - && sig.params[0] != receiver_ref - && sig.params[0] != receiver_ref_unqualified - && sig.params[0] != receiver_ptr - && 
sig.params[0] != receiver_ptr_unqualified + if instantiated_params[0] != receiver_ty + && expected_qualified.as_ref() != Some(&receiver_ty) + && instantiated_params[0] != receiver_unqualified + && instantiated_params[0] != receiver_ref + && expected_ref_qualified.as_ref() != Some(&receiver_ref) + && instantiated_params[0] != receiver_ref_unqualified + && instantiated_params[0] != receiver_ptr + && expected_ptr_qualified.as_ref() != Some(&receiver_ptr) + && instantiated_params[0] != receiver_ptr_unqualified { return Err(TypeError::new( format!( "method receiver type mismatch: expected {expected:?}, found {receiver_ty:?}", - expected = sig.params[0] + expected = instantiated_params[0] ), method_call.receiver.span(), )); } - if sig.params[0] != receiver_ref && sig.params[0] != receiver_ref_unqualified { + if instantiated_params[0] != receiver_ref && instantiated_params[0] != receiver_ref_unqualified { let _ = check_expr( &method_call.receiver, functions, @@ -1281,9 +1395,10 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, )?; } - for (arg, expected) in method_call.args.iter().zip(&sig.params[1..]) { + for (arg, expected) in method_call.args.iter().zip(&instantiated_params[1..]) { let (expected_inner, use_mode) = if let Ty::Ref(inner) = expected { (inner.as_ref(), UseMode::Read) } else { @@ -1301,6 +1416,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, )?; if !matches!(expected, Ty::Ref(_)) && matches!(arg_ty, Ty::Ref(_)) { return Err(TypeError::new( @@ -1320,7 +1436,7 @@ pub(super) fn check_expr( )); } } - Ok(sig.ret.clone()) + Ok(instantiated_ret) } Expr::StructLiteral(lit) => check_struct_literal( lit, @@ -1333,6 +1449,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, ), Expr::Unary(unary) => { let expr_ty = check_expr( @@ -1347,6 +1464,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, )?; match unary.op { UnaryOp::Neg => { @@ -1386,6 +1504,7 @@ pub(super) fn 
check_expr( stdlib, ret_ty, module_name, + type_params, )?; let right = check_expr( &binary.right, @@ -1399,6 +1518,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, )?; match binary.op { BinaryOp::Add @@ -1477,6 +1597,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, ), Expr::Try(try_expr) => { let inner_ty = check_expr( @@ -1491,6 +1612,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, )?; let (ok_ty, err_ty) = match inner_ty { Ty::Path(name, args) if name == "Result" && args.len() == 2 => { @@ -1537,6 +1659,7 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, ), Expr::FieldAccess(field_access) => { fn get_leftmost_path_segment(expr: &Expr) -> Option<&str> { @@ -1575,8 +1698,9 @@ pub(super) fn check_expr( stdlib, ret_ty, module_name, + type_params, )?; - let Ty::Path(struct_name, _) = object_ty else { + let Ty::Path(struct_name, struct_args) = object_ty else { return Err(TypeError::new( "field access requires a struct value".to_string(), field_access.span, @@ -1606,7 +1730,8 @@ pub(super) fn check_expr( field_access.field.span, ) })?; - let field_ty = field_ty.clone(); + let substitutions = build_type_substitution(&info.type_params, &struct_args, field_access.span)?; + let field_ty = substitute_type(field_ty, &substitutions); if is_affine_type(&field_ty, struct_map, enum_map) { match use_mode { UseMode::Read => { @@ -1655,6 +1780,7 @@ fn check_match_stmt( stdlib: &StdlibIndex, ret_ty: &Ty, module_name: &str, + type_params: &HashSet, ) -> Result { let match_ty = check_expr( &match_expr.expr, @@ -1668,6 +1794,7 @@ fn check_match_stmt( stdlib, ret_ty, module_name, + type_params, )?; let mut arm_scopes = Vec::with_capacity(match_expr.arms.len()); for arm in &match_expr.arms { @@ -1685,6 +1812,7 @@ fn check_match_stmt( enum_map, stdlib, module_name, + type_params, )?; arm_scope.pop_scope(); arm_scopes.push(arm_scope); @@ -1714,6 +1842,7 @@ fn check_match_expr_value( 
stdlib: &StdlibIndex, ret_ty: &Ty, module_name: &str, + type_params: &HashSet, ) -> Result { let match_ty = check_expr( &match_expr.expr, @@ -1727,6 +1856,7 @@ fn check_match_expr_value( stdlib, ret_ty, module_name, + type_params, )?; let mut result_ty: Option = None; let mut arm_scopes = Vec::with_capacity(match_expr.arms.len()); @@ -1745,6 +1875,7 @@ fn check_match_expr_value( stdlib, ret_ty, module_name, + type_params, )?; arm_scope.pop_scope(); arm_scopes.push(arm_scope); @@ -1783,6 +1914,7 @@ fn check_match_arm_value( stdlib: &StdlibIndex, ret_ty: &Ty, module_name: &str, + type_params: &HashSet, ) -> Result { let Some((last, prefix)) = block.stmts.split_last() else { return Err(TypeError::new( @@ -1808,6 +1940,7 @@ fn check_match_arm_value( enum_map, stdlib, module_name, + type_params, )?; } match last { @@ -1823,6 +1956,7 @@ fn check_match_arm_value( stdlib, ret_ty, module_name, + type_params, ), _ => Err(TypeError::new( "match arm must end with expression".to_string(), @@ -1919,7 +2053,7 @@ fn check_match_exhaustive( for arm in arms { if let Pattern::Path(path) = &arm.pattern { if let Some(ty) = resolve_enum_variant(path, use_map, enum_map, module_name) { - if &ty == match_ty { + if same_type_constructor(&ty, match_ty) { if let Some(seg) = path.segments.last() { seen.insert(seg.item.clone()); } @@ -1947,6 +2081,13 @@ fn check_match_exhaustive( Ok(()) } +fn same_type_constructor(left: &Ty, right: &Ty) -> bool { + match (left, right) { + (Ty::Path(left_name, _), Ty::Path(right_name, _)) => left_name == right_name, + _ => left == right, + } +} + /// Check a struct literal and ensure all fields are present and typed. 
fn check_struct_literal( lit: &StructLiteralExpr, @@ -1959,7 +2100,16 @@ fn check_struct_literal( stdlib: &StdlibIndex, ret_ty: &Ty, module_name: &str, + type_params: &HashSet, ) -> Result { + let type_args = lower_type_args( + &lit.type_args, + use_map, + stdlib, + struct_map, + enum_map, + type_params, + )?; let type_name = resolve_type_name(&lit.path, use_map, stdlib); let key = if lit.path.segments.len() == 1 { if stdlib.types.contains_key(&lit.path.segments[0].item) { @@ -1973,6 +2123,25 @@ fn check_struct_literal( let info = struct_map.get(&key).ok_or_else(|| { TypeError::new(format!("unknown struct `{}`", key), lit.span) })?; + if info.type_params.is_empty() { + if !type_args.is_empty() { + return Err(TypeError::new( + format!("type `{}` does not accept type arguments", key), + lit.span, + )); + } + } else if type_args.len() != info.type_params.len() { + return Err(TypeError::new( + format!( + "type `{}` expects {} type argument(s), found {}", + key, + info.type_params.len(), + type_args.len() + ), + lit.span, + )); + } + let substitutions = build_type_substitution(&info.type_params, &type_args, lit.span)?; if info.is_opaque && info.module != module_name { return Err(TypeError::new( format!( @@ -1991,6 +2160,7 @@ fn check_struct_literal( field.span, ) })?; + let expected = substitute_type(&expected, &substitutions); let actual = check_expr( &field.expr, functions, @@ -2003,6 +2173,7 @@ fn check_struct_literal( stdlib, ret_ty, module_name, + type_params, )?; if actual != expected { return Err(TypeError::new( @@ -2018,7 +2189,168 @@ fn check_struct_literal( )); } - Ok(Ty::Path(type_name, Vec::new())) + Ok(Ty::Path(type_name, type_args)) +} + +fn lower_type_args( + args: &[Type], + use_map: &UseMap, + stdlib: &StdlibIndex, + struct_map: &HashMap, + enum_map: &HashMap, + type_params: &HashSet, +) -> Result, TypeError> { + let mut out = Vec::with_capacity(args.len()); + for arg in args { + let ty = lower_type(arg, use_map, stdlib, type_params)?; + 
validate_type_args(&ty, struct_map, enum_map, arg.span())?; + out.push(ty); + } + Ok(out) +} + +fn build_type_substitution( + params: &[String], + args: &[Ty], + span: Span, +) -> Result, TypeError> { + if params.len() != args.len() { + return Err(TypeError::new( + format!( + "expected {} type argument(s), found {}", + params.len(), + args.len() + ), + span, + )); + } + let mut map = HashMap::new(); + for (param, arg) in params.iter().zip(args.iter()) { + map.insert(param.clone(), arg.clone()); + } + Ok(map) +} + +fn substitute_type(ty: &Ty, subs: &HashMap) -> Ty { + match ty { + Ty::Param(name) => subs.get(name).cloned().unwrap_or_else(|| ty.clone()), + Ty::Builtin(_) => ty.clone(), + Ty::Ptr(inner) => Ty::Ptr(Box::new(substitute_type(inner, subs))), + Ty::Ref(inner) => Ty::Ref(Box::new(substitute_type(inner, subs))), + Ty::Path(name, args) => Ty::Path( + name.clone(), + args.iter().map(|arg| substitute_type(arg, subs)).collect(), + ), + } +} + +fn match_type_params( + expected: &Ty, + actual: &Ty, + subs: &mut HashMap, + span: Span, +) -> Result<(), TypeError> { + match expected { + Ty::Param(name) => { + if let Some(existing) = subs.get(name) { + if existing != actual { + return Err(TypeError::new( + format!( + "conflicting type arguments for `{}`: {existing:?} vs {actual:?}", + name + ), + span, + )); + } + } else { + subs.insert(name.clone(), actual.clone()); + } + Ok(()) + } + Ty::Builtin(_) => { + if expected != actual { + return Err(TypeError::new( + format!("type mismatch: expected {expected:?}, found {actual:?}"), + span, + )); + } + Ok(()) + } + Ty::Ptr(inner) => match actual { + Ty::Ptr(actual_inner) => match_type_params(inner, actual_inner, subs, span), + _ => Err(TypeError::new( + format!("type mismatch: expected {expected:?}, found {actual:?}"), + span, + )), + }, + Ty::Ref(inner) => match actual { + Ty::Ref(actual_inner) => match_type_params(inner, actual_inner, subs, span), + _ => Err(TypeError::new( + format!("type mismatch: expected {expected:?}, 
found {actual:?}"), + span, + )), + }, + Ty::Path(name, args) => match actual { + Ty::Path(actual_name, actual_args) => { + if name != actual_name || args.len() != actual_args.len() { + return Err(TypeError::new( + format!("type mismatch: expected {expected:?}, found {actual:?}"), + span, + )); + } + for (arg, actual_arg) in args.iter().zip(actual_args.iter()) { + match_type_params(arg, actual_arg, subs, span)?; + } + Ok(()) + } + _ => Err(TypeError::new( + format!("type mismatch: expected {expected:?}, found {actual:?}"), + span, + )), + }, + } +} + +fn build_call_substitution( + sig: &FunctionSig, + explicit_args: &[Ty], + inferred: HashMap, + span: Span, +) -> Result, TypeError> { + if sig.type_params.is_empty() { + if !explicit_args.is_empty() { + return Err(TypeError::new( + format!( + "function does not accept type arguments (found {})", + explicit_args.len() + ), + span, + )); + } + return Ok(inferred); + } + + let mut subs = inferred; + let mut remaining = Vec::new(); + for name in &sig.type_params { + if !subs.contains_key(name) { + remaining.push(name.clone()); + } + } + if explicit_args.len() != remaining.len() { + return Err(TypeError::new( + format!( + "expected {} type argument(s), found {}", + remaining.len(), + explicit_args.len() + ), + span, + )); + } + for (name, arg) in remaining.into_iter().zip(explicit_args.iter()) { + subs.insert(name, arg.clone()); + } + Ok(subs) } /// Bind locals introduced by a match pattern. 
@@ -2065,7 +2397,7 @@ fn bind_pattern( } Pattern::Path(path) => { if let Some(ty) = resolve_enum_variant(path, use_map, enum_map, module_name) { - if &ty != match_ty { + if !same_type_constructor(&ty, match_ty) { return Err(TypeError::new( format!("pattern type mismatch: expected {match_ty:?}, found {ty:?}"), path.span, diff --git a/capc/src/typeck/collect.rs b/capc/src/typeck/collect.rs index 6cc87ed..136467d 100644 --- a/capc/src/typeck/collect.rs +++ b/capc/src/typeck/collect.rs @@ -4,8 +4,8 @@ use crate::ast::*; use crate::error::TypeError; use super::{ - desugar_impl_methods, lower_type, type_contains_ref, EnumInfo, FunctionSig, StructInfo, TypeKind, - UseMap, StdlibIndex, + build_type_params, desugar_impl_methods, lower_type, type_contains_ref, EnumInfo, FunctionSig, + StructInfo, TypeKind, UseMap, StdlibIndex, RESERVED_TYPE_PARAMS, validate_type_args, }; /// Build the stdlib type index for name resolution. @@ -46,11 +46,13 @@ pub(super) fn collect_functions( let mut impl_methods = std::collections::HashSet::new(); for item in &module.items { let mut add_function = |name: &Ident, + type_params: &[Ident], params: &[Param], ret: &Type, span: Span, is_pub: bool| -> Result<(), TypeError> { + let type_param_set = build_type_params(type_params)?; for param in params { if param.ty.is_none() { return Err(TypeError::new( @@ -60,6 +62,7 @@ pub(super) fn collect_functions( } } let sig = FunctionSig { + type_params: type_params.iter().map(|p| p.item.clone()).collect(), params: params .iter() .map(|p| { @@ -67,10 +70,11 @@ pub(super) fn collect_functions( p.ty.as_ref().expect("param type checked"), &local_use, stdlib, + &type_param_set, ) }) .collect::>()?, - ret: lower_type(ret, &local_use, stdlib)?, + ret: lower_type(ret, &local_use, stdlib, &type_param_set)?, module: module_name.clone(), is_pub, }; @@ -101,7 +105,14 @@ pub(super) fn collect_functions( func.name.span, )); } - add_function(&func.name, &func.params, &func.ret, func.name.span, func.is_pub)?; + 
add_function( + &func.name, + &func.type_params, + &func.params, + &func.ret, + func.name.span, + func.is_pub, + )?; } Item::ExternFunction(func) => { if func.name.item.contains("__") { @@ -110,7 +121,14 @@ pub(super) fn collect_functions( func.name.span, )); } - add_function(&func.name, &func.params, &func.ret, func.name.span, func.is_pub)?; + add_function( + &func.name, + &func.type_params, + &func.params, + &func.ret, + func.name.span, + func.is_pub, + )?; } Item::Impl(impl_block) => { let methods = desugar_impl_methods( @@ -135,6 +153,7 @@ pub(super) fn collect_functions( } add_function( &method.name, + &method.type_params, &method.params, &method.ret, method.name.span, @@ -155,21 +174,19 @@ pub(super) fn collect_structs( entry_name: &str, stdlib: &StdlibIndex, ) -> Result, TypeError> { - let reserved = [ - "i32", "i64", "u32", "u8", "bool", "string", "unit", "Result", - ]; let mut structs = HashMap::new(); for module in modules { let module_name = module.name.to_string(); let local_use = UseMap::new(module); for item in &module.items { if let Item::Struct(decl) = item { - if reserved.contains(&decl.name.item.as_str()) { + if RESERVED_TYPE_PARAMS.contains(&decl.name.item.as_str()) { return Err(TypeError::new( format!("type name `{}` is reserved", decl.name.item), decl.name.span, )); } + let type_param_set = build_type_params(&decl.type_params)?; let mut fields = HashMap::new(); for field in &decl.fields { if let Some(span) = type_contains_ref(&field.ty) { @@ -178,7 +195,7 @@ pub(super) fn collect_structs( span, )); } - let ty = lower_type(&field.ty, &local_use, stdlib)?; + let ty = lower_type(&field.ty, &local_use, stdlib, &type_param_set)?; if fields.insert(field.name.item.clone(), ty).is_some() { return Err(TypeError::new( format!("duplicate field `{}`", field.name.item), @@ -203,6 +220,7 @@ pub(super) fn collect_structs( TypeKind::Unrestricted }; let info = StructInfo { + type_params: decl.type_params.iter().map(|p| p.item.clone()).collect(), fields, 
is_opaque: decl.is_opaque || decl.is_capability, is_capability: decl.is_capability, @@ -232,21 +250,19 @@ pub(super) fn collect_enums( entry_name: &str, stdlib: &StdlibIndex, ) -> Result, TypeError> { - let reserved = [ - "i32", "i64", "u32", "u8", "bool", "string", "unit", "Result", - ]; let mut enums = HashMap::new(); for module in modules { let module_name = module.name.to_string(); let local_use = UseMap::new(module); for item in &module.items { if let Item::Enum(decl) = item { - if reserved.contains(&decl.name.item.as_str()) { + if RESERVED_TYPE_PARAMS.contains(&decl.name.item.as_str()) { return Err(TypeError::new( format!("type name `{}` is reserved", decl.name.item), decl.name.span, )); } + let type_param_set = build_type_params(&decl.type_params)?; let mut variants = Vec::new(); let mut payloads = HashMap::new(); for variant in &decl.variants { @@ -267,7 +283,7 @@ pub(super) fn collect_enums( span, )); } - Some(lower_type(payload, &local_use, stdlib)?) + Some(lower_type(payload, &local_use, stdlib, &type_param_set)?) } else { None }; @@ -281,6 +297,7 @@ pub(super) fn collect_enums( )); } let info = EnumInfo { + type_params: decl.type_params.iter().map(|p| p.item.clone()).collect(), variants: variants.clone(), payloads: payloads.clone(), }; @@ -301,6 +318,40 @@ pub(super) fn collect_enums( Ok(enums) } +/// Validate type argument arity within struct fields and enum payloads. 
+pub(super) fn validate_type_defs( + modules: &[&Module], + stdlib: &StdlibIndex, + struct_map: &HashMap, + enum_map: &HashMap, +) -> Result<(), TypeError> { + for module in modules { + let local_use = UseMap::new(module); + for item in &module.items { + match item { + Item::Struct(decl) => { + let type_param_set = build_type_params(&decl.type_params)?; + for field in &decl.fields { + let ty = lower_type(&field.ty, &local_use, stdlib, &type_param_set)?; + validate_type_args(&ty, struct_map, enum_map, field.ty.span())?; + } + } + Item::Enum(decl) => { + let type_param_set = build_type_params(&decl.type_params)?; + for variant in &decl.variants { + if let Some(payload) = &variant.payload { + let ty = lower_type(payload, &local_use, stdlib, &type_param_set)?; + validate_type_args(&ty, struct_map, enum_map, payload.span())?; + } + } + } + _ => {} + } + } + } + Ok(()) +} + /// Enforce that `copy struct` declarations contain only unrestricted fields. pub(super) fn validate_copy_structs( modules: &[&Module], @@ -315,8 +366,9 @@ pub(super) fn validate_copy_structs( if !decl.is_copy { continue; } + let type_param_set = build_type_params(&decl.type_params)?; for field in &decl.fields { - let ty = lower_type(&field.ty, &local_use, stdlib)?; + let ty = lower_type(&field.ty, &local_use, stdlib, &type_param_set)?; if super::type_kind(&ty, struct_map, enum_map) != TypeKind::Unrestricted { return Err(TypeError::new( "copy struct cannot contain move-only fields".to_string(), diff --git a/capc/src/typeck/lower.rs b/capc/src/typeck/lower.rs index 466103a..7b40847 100644 --- a/capc/src/typeck/lower.rs +++ b/capc/src/typeck/lower.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use crate::ast::*; use crate::error::TypeError; @@ -12,9 +12,9 @@ use crate::hir::{ }; use super::{ - check, function_key, lower_type, resolve_enum_variant, resolve_method_target, resolve_type_name, - EnumInfo, FunctionSig, FunctionTypeTables, SpanExt, StdlibIndex, 
StructInfo, Ty, TypeTable, - UseMap, + build_type_params, check, function_key, lower_type, resolve_enum_variant, resolve_method_target, + resolve_type_name, EnumInfo, FunctionSig, FunctionTypeTables, SpanExt, StdlibIndex, StructInfo, + Ty, TypeTable, UseMap, }; /// Context for HIR lowering (uses the type checker as source of truth). @@ -33,6 +33,7 @@ struct LoweringCtx<'a> { /// Maps variable names to their types (needed for type checking during lowering) local_types: HashMap, local_counter: usize, + type_params: HashSet, } impl<'a> LoweringCtx<'a> { @@ -59,6 +60,7 @@ impl<'a> LoweringCtx<'a> { local_map: HashMap::new(), local_types: HashMap::new(), local_counter: 0, + type_params: HashSet::new(), } } @@ -129,6 +131,7 @@ pub(super) fn lower_module( } } Item::ExternFunction(func) => { + let type_params = build_type_params(&func.type_params)?; let params: Result, TypeError> = func .params .iter() @@ -139,7 +142,7 @@ pub(super) fn lower_module( param.name.span, )); }; - let lowered = lower_type(ty, use_map, stdlib)?; + let lowered = lower_type(ty, use_map, stdlib, &type_params)?; Ok(HirParam { local_id: LocalId(0), ty: hir_type_for(lowered, &ctx, ty.span())?, @@ -148,9 +151,10 @@ pub(super) fn lower_module( .collect(); hir_extern_functions.push(HirExternFunction { name: func.name.item.clone(), + type_params: func.type_params.iter().map(|p| p.item.clone()).collect(), params: params?, ret_ty: { - let lowered = lower_type(&func.ret, use_map, stdlib)?; + let lowered = lower_type(&func.ret, use_map, stdlib, &type_params)?; hir_type_for(lowered, &ctx, func.ret.span())? 
}, }); @@ -162,11 +166,12 @@ pub(super) fn lower_module( let mut hir_structs = Vec::new(); for item in &module.items { if let Item::Struct(decl) = item { + let type_params = build_type_params(&decl.type_params)?; let fields: Result, TypeError> = decl .fields .iter() .map(|f| { - let lowered = lower_type(&f.ty, use_map, stdlib)?; + let lowered = lower_type(&f.ty, use_map, stdlib, &type_params)?; Ok(HirField { name: f.name.item.clone(), ty: hir_type_for(lowered, &ctx, f.ty.span())?, @@ -175,6 +180,7 @@ pub(super) fn lower_module( .collect(); hir_structs.push(HirStruct { name: decl.name.item.clone(), + type_params: decl.type_params.iter().map(|p| p.item.clone()).collect(), fields: fields?, is_opaque: decl.is_opaque || decl.is_capability, }); @@ -184,6 +190,7 @@ pub(super) fn lower_module( let mut hir_enums = Vec::new(); for item in &module.items { if let Item::Enum(decl) = item { + let type_params = build_type_params(&decl.type_params)?; let variants: Result, TypeError> = decl .variants .iter() @@ -192,7 +199,7 @@ pub(super) fn lower_module( .payload .as_ref() .map(|ty| { - let lowered = lower_type(ty, use_map, stdlib)?; + let lowered = lower_type(ty, use_map, stdlib, &type_params)?; hir_type_for(lowered, &ctx, ty.span()) }) .transpose()?; @@ -204,6 +211,7 @@ pub(super) fn lower_module( .collect(); hir_enums.push(HirEnum { name: decl.name.item.clone(), + type_params: decl.type_params.iter().map(|p| p.item.clone()).collect(), variants: variants?, }); } @@ -226,6 +234,8 @@ fn lower_function(func: &Function, ctx: &mut LoweringCtx) -> Result, TypeError> = func .params @@ -237,7 +247,7 @@ fn lower_function(func: &Function, ctx: &mut LoweringCtx) -> Result Result Result, TypeError> { + args.iter() + .map(|arg| lower_type(arg, ctx.use_map, ctx.stdlib, &ctx.type_params)) + .collect::, _>>() +} + fn abi_type_for(ty: &Ty, ctx: &LoweringCtx, span: Span) -> Result { use super::BuiltinType; match ty { @@ -417,6 +436,7 @@ fn abi_type_for(ty: &Ty, ctx: &LoweringCtx, span: Span) -> 
Result Ok(AbiType::Ptr), Ty::Ref(inner) => abi_type_for(inner, ctx, span), + Ty::Param(_) => Ok(AbiType::Ptr), Ty::Path(name, args) => { if name == "Result" && args.len() == 2 { let ok = abi_type_for(&args[0], ctx, span)?; @@ -515,6 +535,7 @@ fn lower_expr(expr: &Expr, ctx: &mut LoweringCtx, ret_ty: &Ty) -> Result Result Result Result Result { + let struct_ty = type_of_ast_expr(expr, ctx, ret_ty)?; let type_name = resolve_type_name(&lit.path, ctx.use_map, ctx.stdlib); let key = if lit.path.segments.len() == 1 { if ctx.stdlib.types.contains_key(&lit.path.segments[0].item) { @@ -843,7 +868,6 @@ fn lower_expr(expr: &Expr, ctx: &mut LoweringCtx, ret_ty: &Ty) -> Result), + /// Generic type parameter. + Param(String), +} + +pub(super) fn build_type_params(params: &[Ident]) -> Result, TypeError> { + let mut set = HashSet::new(); + for param in params { + let name = param.item.as_str(); + if RESERVED_TYPE_PARAMS.contains(&name) { + return Err(TypeError::new( + format!("type parameter `{}` is reserved", param.item), + param.span, + )); + } + if !set.insert(param.item.clone()) { + return Err(TypeError::new( + format!("duplicate type parameter `{}`", param.item), + param.span, + )); + } + } + Ok(set) +} + +fn merge_type_params( + base: &HashSet, + params: &[Ident], +) -> Result, TypeError> { + let mut set = base.clone(); + for param in params { + let name = param.item.as_str(); + if RESERVED_TYPE_PARAMS.contains(&name) { + return Err(TypeError::new( + format!("type parameter `{}` is reserved", param.item), + param.span, + )); + } + if set.contains(¶m.item) { + return Err(TypeError::new( + format!("duplicate type parameter `{}`", param.item), + param.span, + )); + } + set.insert(param.item.clone()); + } + Ok(set) } /// Built-in primitive types. @@ -91,6 +142,7 @@ type FunctionTypeTables = HashMap; /// Resolved signature for a function. 
#[derive(Debug, Clone)] struct FunctionSig { + type_params: Vec, params: Vec, ret: Ty, module: String, @@ -100,6 +152,7 @@ struct FunctionSig { /// Metadata about a struct needed by the type checker. #[derive(Debug, Clone)] struct StructInfo { + type_params: Vec, fields: HashMap, is_opaque: bool, is_capability: bool, @@ -110,6 +163,7 @@ struct StructInfo { /// Metadata about an enum needed by the type checker. #[derive(Debug, Clone)] struct EnumInfo { + type_params: Vec, variants: Vec, payloads: HashMap>, } @@ -366,10 +420,11 @@ fn resolve_impl_target( use_map: &UseMap, stdlib: &StdlibIndex, struct_map: &HashMap, + type_params: &HashSet, module_name: &str, span: Span, ) -> Result<(String, String, Ty), TypeError> { - let target_ty = lower_type(target, use_map, stdlib)?; + let target_ty = lower_type(target, use_map, stdlib, type_params)?; let (impl_module, type_name) = match &target_ty { Ty::Path(target_name, _target_args) => { if let Some(info) = struct_map.get(target_name) { @@ -421,6 +476,7 @@ fn validate_impl_method( stdlib: &StdlibIndex, struct_map: &HashMap, enum_map: &HashMap, + type_params: &HashSet, _span: Span, ) -> Result, TypeError> { if method.name.item.contains("__") { @@ -450,7 +506,7 @@ fn validate_impl_method( let mut receiver_is_ref = false; if let Some(ty) = &first_param.ty { - let lowered = lower_type(ty, use_map, stdlib)?; + let lowered = lower_type(ty, use_map, stdlib, type_params)?; if lowered != expected && lowered != expected_ptr && lowered != expected_ref { return Err(TypeError::new( format!( @@ -473,7 +529,7 @@ fn validate_impl_method( } } - let ret_ty = lower_type(&method.ret, use_map, stdlib)?; + let ret_ty = lower_type(&method.ret, use_map, stdlib, type_params)?; if receiver_is_ref && type_contains_capability(&ret_ty, struct_map, enum_map) { return Err(TypeError::new( "methods returning capabilities must take `self` by value".to_string(), @@ -484,13 +540,19 @@ fn validate_impl_method( Ok(params) } -fn desugar_impl_method(type_name: &str, 
method: &Function, params: Vec) -> Function { +fn desugar_impl_method( + type_name: &str, + method: &Function, + params: Vec, + type_params: Vec, +) -> Function { let name = Spanned::new( format!("{type_name}__{}", method.name.item), method.name.span, ); Function { name, + type_params, params, ret: method.ret.clone(), body: method.body.clone(), @@ -508,11 +570,13 @@ fn desugar_impl_methods( struct_map: &HashMap, enum_map: &HashMap, ) -> Result, TypeError> { + let impl_type_params = build_type_params(&impl_block.type_params)?; let (_impl_module, type_name, target_ty) = resolve_impl_target( &impl_block.target, use_map, stdlib, struct_map, + &impl_type_params, module_name, impl_block.span, )?; @@ -525,6 +589,9 @@ fn desugar_impl_methods( method.name.span, )); } + let method_type_params = merge_type_params(&impl_type_params, &method.type_params)?; + let mut combined_type_params = impl_block.type_params.clone(); + combined_type_params.extend(method.type_params.clone()); let params = validate_impl_method( &type_name, &target_ty, @@ -535,26 +602,56 @@ fn desugar_impl_methods( stdlib, struct_map, enum_map, + &method_type_params, method.span, )?; - methods.push(desugar_impl_method(&type_name, method, params)); + methods.push(desugar_impl_method( + &type_name, + method, + params, + combined_type_params, + )); } Ok(methods) } /// Convert AST types into resolved Ty (builtins + fully qualified paths). -fn lower_type(ty: &Type, use_map: &UseMap, stdlib: &StdlibIndex) -> Result { +fn lower_type( + ty: &Type, + use_map: &UseMap, + stdlib: &StdlibIndex, + type_params: &HashSet, +) -> Result { match ty { - Type::Ptr { target, .. } => Ok(Ty::Ptr(Box::new(lower_type(target, use_map, stdlib)?))), - Type::Ref { target, .. } => Ok(Ty::Ref(Box::new(lower_type(target, use_map, stdlib)?))), + Type::Ptr { target, .. } => Ok(Ty::Ptr(Box::new(lower_type( + target, + use_map, + stdlib, + type_params, + )?))), + Type::Ref { target, .. 
} => Ok(Ty::Ref(Box::new(lower_type( + target, + use_map, + stdlib, + type_params, + )?))), Type::Path { path, args, .. } => { let resolved = resolve_path(path, use_map); let path_segments = resolved.iter().map(|seg| seg.as_str()).collect::>(); - let args = args + let args: Vec = args .iter() - .map(|arg| lower_type(arg, use_map, stdlib)) + .map(|arg| lower_type(arg, use_map, stdlib, type_params)) .collect::>()?; if path_segments.len() == 1 { + if type_params.contains(path_segments[0]) { + if !args.is_empty() { + return Err(TypeError::new( + format!("type parameter `{}` cannot take arguments", path_segments[0]), + path.span, + )); + } + return Ok(Ty::Param(path_segments[0].to_string())); + } let builtin = match path_segments[0] { "i32" => Some(BuiltinType::I32), "i64" => Some(BuiltinType::I64), @@ -655,6 +752,7 @@ fn type_contains_capability_inner( ) -> bool { match ty { Ty::Builtin(_) | Ty::Ptr(_) | Ty::Ref(_) => false, + Ty::Param(_) => true, Ty::Path(name, args) => { if name == "Result" { return args @@ -736,6 +834,7 @@ fn type_kind_inner( ) -> TypeKind { match ty { Ty::Builtin(_) | Ty::Ptr(_) | Ty::Ref(_) => TypeKind::Unrestricted, + Ty::Param(_) => TypeKind::Affine, Ty::Path(name, args) => { if name == "Result" { return args.iter().fold(TypeKind::Unrestricted, |acc, arg| { @@ -770,6 +869,56 @@ fn type_kind_inner( } } +fn validate_type_args( + ty: &Ty, + struct_map: &HashMap, + enum_map: &HashMap, + span: Span, +) -> Result<(), TypeError> { + match ty { + Ty::Builtin(_) | Ty::Param(_) => Ok(()), + Ty::Ptr(inner) | Ty::Ref(inner) => validate_type_args(inner, struct_map, enum_map, span), + Ty::Path(name, args) => { + if name == "Result" { + if args.len() != 2 { + return Err(TypeError::new( + format!("Result expects 2 type arguments, found {}", args.len()), + span, + )); + } + } else if let Some(info) = struct_map.get(name) { + if args.len() != info.type_params.len() { + return Err(TypeError::new( + format!( + "type `{}` expects {} type argument(s), found {}", + 
name, + info.type_params.len(), + args.len() + ), + span, + )); + } + } else if let Some(info) = enum_map.get(name) { + if args.len() != info.type_params.len() { + return Err(TypeError::new( + format!( + "type `{}` expects {} type argument(s), found {}", + name, + info.type_params.len(), + args.len() + ), + span, + )); + } + } + for arg in args { + validate_type_args(arg, struct_map, enum_map, span)?; + } + Ok(()) + } + } +} + pub fn type_check(module: &Module) -> Result { type_check_program(module, &[], &[]) } @@ -798,6 +947,8 @@ pub fn type_check_program( .map_err(|err| err.with_context("while collecting structs"))?; let enum_map = collect::collect_enums(&modules, &module_name, &stdlib_index) .map_err(|err| err.with_context("while collecting enums"))?; + collect::validate_type_defs(&modules, &stdlib_index, &struct_map, &enum_map) + .map_err(|err| err.with_context("while validating type arguments"))?; collect::validate_copy_structs(&modules, &struct_map, &enum_map, &stdlib_index) .map_err(|err| err.with_context("while validating copy structs"))?; let functions = collect::collect_functions( @@ -915,11 +1066,12 @@ pub fn type_check_program( ) .map_err(|err| err.with_context(format!("in module `{}`", module.name)))?; - Ok(crate::hir::HirProgram { + let hir_program = crate::hir::HirProgram { entry: hir_entry, user_modules: hir_user_modules?, stdlib: hir_stdlib?, - }) + }; + monomorphize::monomorphize_program(hir_program) } pub(super) trait SpanExt { diff --git a/capc/src/typeck/monomorphize.rs b/capc/src/typeck/monomorphize.rs new file mode 100644 index 0000000..763223d --- /dev/null +++ b/capc/src/typeck/monomorphize.rs @@ -0,0 +1,1050 @@ +use std::collections::{HashMap, HashSet}; + +use crate::abi::AbiType; +use crate::error::TypeError; +use crate::hir::*; +use crate::typeck::Ty; +use crate::ast::Span; + +const DUMMY_SPAN: Span = Span { start: 0, end: 0 }; + +#[derive(Clone)] +struct ModuleOut { + name: String, + functions: Vec, + extern_functions: Vec, + structs: 
Vec, + enums: Vec, +} + +impl ModuleOut { + fn new(name: String) -> Self { + Self { + name, + functions: Vec::new(), + extern_functions: Vec::new(), + structs: Vec::new(), + enums: Vec::new(), + } + } +} + +#[derive(Clone)] +struct FunctionInstance { + module: String, + base_name: String, + type_args: Vec, +} + +struct MonoCtx { + program: HirProgram, + functions: HashMap, + externs: HashMap, + structs: HashMap, + enums: HashMap, + out_modules: HashMap, + generated_functions: HashSet, + generated_externs: HashSet, + generated_structs: HashSet, + generated_enums: HashSet, + queue: Vec, +} + +pub(super) fn monomorphize_program(program: HirProgram) -> Result { + let mut ctx = MonoCtx::new(program)?; + ctx.seed_roots()?; + ctx.process_queue()?; + Ok(ctx.into_program()) +} + +impl MonoCtx { + fn new(program: HirProgram) -> Result { + let mut out_modules = HashMap::new(); + let mut functions = HashMap::new(); + let mut externs = HashMap::new(); + let mut structs = HashMap::new(); + let mut enums = HashMap::new(); + + let mut register_module = |module: &HirModule| { + out_modules + .entry(module.name.clone()) + .or_insert_with(|| ModuleOut::new(module.name.clone())); + for func in &module.functions { + let key = qualify(&module.name, &func.name); + functions.insert(key, func.clone()); + } + for func in &module.extern_functions { + let key = qualify(&module.name, &func.name); + externs.insert(key, func.clone()); + } + for decl in &module.structs { + let key = qualify(&module.name, &decl.name); + structs.insert(key, decl.clone()); + } + for decl in &module.enums { + let key = qualify(&module.name, &decl.name); + enums.insert(key, decl.clone()); + } + }; + + register_module(&program.entry); + for module in &program.user_modules { + register_module(module); + } + for module in &program.stdlib { + register_module(module); + } + + Ok(Self { + program, + functions, + externs, + structs, + enums, + out_modules, + generated_functions: HashSet::new(), + generated_externs: 
HashSet::new(), + generated_structs: HashSet::new(), + generated_enums: HashSet::new(), + queue: Vec::new(), + }) + } + + fn seed_roots(&mut self) -> Result<(), TypeError> { + let modules = self.all_modules().into_iter().cloned().collect::>(); + for module in modules { + for func in &module.functions { + if !func.type_params.is_empty() { + continue; + } + let name = func.name.clone(); + let mono = self.mono_function(&module.name, func, &HashMap::new(), name)?; + self.push_function(&module.name, mono); + } + for func in &module.extern_functions { + if !func.type_params.is_empty() { + continue; + } + let name = func.name.clone(); + let mono = self.mono_extern(&module.name, func, &HashMap::new(), name)?; + self.push_extern(&module.name, mono); + } + for decl in &module.structs { + if !decl.type_params.is_empty() { + continue; + } + let name = decl.name.clone(); + let mono = self.mono_struct(&module.name, decl, &HashMap::new(), name)?; + self.push_struct(&module.name, mono); + } + for decl in &module.enums { + if !decl.type_params.is_empty() { + continue; + } + let name = decl.name.clone(); + let mono = self.mono_enum(&module.name, decl, &HashMap::new(), name)?; + self.push_enum(&module.name, mono); + } + } + Ok(()) + } + + fn all_modules(&self) -> Vec<&HirModule> { + let mut modules = Vec::new(); + modules.push(&self.program.entry); + for module in &self.program.user_modules { + modules.push(module); + } + for module in &self.program.stdlib { + modules.push(module); + } + modules + } + + fn process_queue(&mut self) -> Result<(), TypeError> { + while let Some(instance) = self.queue.pop() { + let key = qualify(&instance.module, &instance.base_name); + if let Some(func) = self.functions.get(&key).cloned() { + let new_name = mangle_name(&instance.base_name, &instance.type_args); + if self.generated_functions.contains(&qualify(&instance.module, &new_name)) { + continue; + } + let subs = build_substitution(&func.type_params, &instance.type_args, DUMMY_SPAN)?; + let mono = 
self.mono_function(&instance.module, &func, &subs, new_name)?; + self.push_function(&instance.module, mono); + continue; + } + if let Some(func) = self.externs.get(&key).cloned() { + let new_name = mangle_name(&instance.base_name, &instance.type_args); + if self.generated_externs.contains(&qualify(&instance.module, &new_name)) { + continue; + } + let subs = build_substitution(&func.type_params, &instance.type_args, DUMMY_SPAN)?; + let mono = self.mono_extern(&instance.module, &func, &subs, new_name)?; + self.push_extern(&instance.module, mono); + continue; + } + return Err(TypeError::new( + format!("unknown function `{}`", key), + DUMMY_SPAN, + )); + } + Ok(()) + } + + fn into_program(self) -> HirProgram { + let mut modules = self.out_modules.into_values().collect::>(); + modules.sort_by(|a, b| a.name.cmp(&b.name)); + let mut entry = None; + let mut user_modules = Vec::new(); + let mut stdlib = Vec::new(); + for module in modules { + if module.name == self.program.entry.name { + entry = Some(module); + continue; + } + if self + .program + .stdlib + .iter() + .any(|m| m.name == module.name) + { + stdlib.push(module); + } else { + user_modules.push(module); + } + } + let entry = entry.unwrap_or_else(|| ModuleOut::new(self.program.entry.name)); + HirProgram { + entry: entry.into(), + user_modules: user_modules.into_iter().map(Into::into).collect(), + stdlib: stdlib.into_iter().map(Into::into).collect(), + } + } + + fn push_function(&mut self, module: &str, func: HirFunction) { + let key = qualify(module, &func.name); + if self.generated_functions.insert(key) { + if let Some(out) = self.out_modules.get_mut(module) { + out.functions.push(func); + } + } + } + + fn push_extern(&mut self, module: &str, func: HirExternFunction) { + let key = qualify(module, &func.name); + if self.generated_externs.insert(key) { + if let Some(out) = self.out_modules.get_mut(module) { + out.extern_functions.push(func); + } + } + } + + fn push_struct(&mut self, module: &str, decl: HirStruct) { 
+ let key = qualify(module, &decl.name); + if self.generated_structs.insert(key) { + self.structs.insert(qualify(module, &decl.name), decl.clone()); + if let Some(out) = self.out_modules.get_mut(module) { + out.structs.push(decl); + } + } + } + + fn push_enum(&mut self, module: &str, decl: HirEnum) { + let key = qualify(module, &decl.name); + if self.generated_enums.insert(key) { + self.enums.insert(qualify(module, &decl.name), decl.clone()); + if let Some(out) = self.out_modules.get_mut(module) { + out.enums.push(decl); + } + } + } + + fn mono_function( + &mut self, + module: &str, + func: &HirFunction, + subs: &HashMap, + new_name: String, + ) -> Result { + let params: Result, TypeError> = func + .params + .iter() + .map(|param| { + let ty = self.mono_hir_type(module, ¶m.ty, subs)?; + Ok(HirParam { + local_id: param.local_id, + ty, + }) + }) + .collect(); + let ret_ty = self.mono_hir_type(module, &func.ret_ty, subs)?; + let body = self.mono_block(module, &func.body, subs)?; + Ok(HirFunction { + name: new_name, + type_params: Vec::new(), + params: params?, + ret_ty, + body, + }) + } + + fn mono_extern( + &mut self, + module: &str, + func: &HirExternFunction, + subs: &HashMap, + new_name: String, + ) -> Result { + let params: Result, TypeError> = func + .params + .iter() + .map(|param| { + let ty = self.mono_hir_type(module, ¶m.ty, subs)?; + Ok(HirParam { + local_id: param.local_id, + ty, + }) + }) + .collect(); + let ret_ty = self.mono_hir_type(module, &func.ret_ty, subs)?; + Ok(HirExternFunction { + name: new_name, + type_params: Vec::new(), + params: params?, + ret_ty, + }) + } + + fn mono_struct( + &mut self, + module: &str, + decl: &HirStruct, + subs: &HashMap, + new_name: String, + ) -> Result { + let fields: Result, TypeError> = decl + .fields + .iter() + .map(|field| { + let ty = self.mono_hir_type(module, &field.ty, subs)?; + Ok(HirField { + name: field.name.clone(), + ty, + }) + }) + .collect(); + Ok(HirStruct { + name: new_name, + type_params: 
Vec::new(), + fields: fields?, + is_opaque: decl.is_opaque, + }) + } + + fn mono_enum( + &mut self, + module: &str, + decl: &HirEnum, + subs: &HashMap, + new_name: String, + ) -> Result { + let variants: Result, TypeError> = decl + .variants + .iter() + .map(|variant| { + let payload = match &variant.payload { + Some(payload) => Some(self.mono_hir_type(module, payload, subs)?), + None => None, + }; + Ok(HirEnumVariant { + name: variant.name.clone(), + payload, + }) + }) + .collect(); + Ok(HirEnum { + name: new_name, + type_params: Vec::new(), + variants: variants?, + }) + } + + fn mono_block( + &mut self, + module: &str, + block: &HirBlock, + subs: &HashMap, + ) -> Result { + let stmts: Result, TypeError> = block + .stmts + .iter() + .map(|stmt| self.mono_stmt(module, stmt, subs)) + .collect(); + Ok(HirBlock { stmts: stmts? }) + } + + fn mono_stmt( + &mut self, + module: &str, + stmt: &HirStmt, + subs: &HashMap, + ) -> Result { + match stmt { + HirStmt::Let(let_stmt) => { + let expr = self.mono_expr(module, &let_stmt.expr, subs)?; + let ty = self.mono_hir_type(module, &let_stmt.ty, subs)?; + Ok(HirStmt::Let(HirLetStmt { + local_id: let_stmt.local_id, + ty, + expr, + span: let_stmt.span, + })) + } + HirStmt::Assign(assign) => { + let expr = self.mono_expr(module, &assign.expr, subs)?; + Ok(HirStmt::Assign(HirAssignStmt { + local_id: assign.local_id, + expr, + span: assign.span, + })) + } + HirStmt::Return(ret) => { + let expr = ret + .expr + .as_ref() + .map(|expr| self.mono_expr(module, expr, subs)) + .transpose()?; + Ok(HirStmt::Return(HirReturnStmt { + expr, + span: ret.span, + })) + } + HirStmt::If(if_stmt) => { + let cond = self.mono_expr(module, &if_stmt.cond, subs)?; + let then_block = self.mono_block(module, &if_stmt.then_block, subs)?; + let else_block = if_stmt + .else_block + .as_ref() + .map(|block| self.mono_block(module, block, subs)) + .transpose()?; + Ok(HirStmt::If(HirIfStmt { + cond, + then_block, + else_block, + span: if_stmt.span, + })) + } + 
HirStmt::While(while_stmt) => {
+                let cond = self.mono_expr(module, &while_stmt.cond, subs)?;
+                let body = self.mono_block(module, &while_stmt.body, subs)?;
+                Ok(HirStmt::While(HirWhileStmt {
+                    cond,
+                    body,
+                    span: while_stmt.span,
+                }))
+            }
+            HirStmt::Expr(expr_stmt) => {
+                let expr = self.mono_expr(module, &expr_stmt.expr, subs)?;
+                Ok(HirStmt::Expr(HirExprStmt {
+                    expr,
+                    span: expr_stmt.span,
+                }))
+            }
+        }
+    }
+
+    /// Monomorphizes one expression tree.
+    ///
+    /// Every type annotation carried by the expression is rewritten with
+    /// `mono_hir_type`, sub-expressions are rewritten recursively, and calls to
+    /// non-intrinsic functions are redirected to the name and symbol chosen by
+    /// `mono_callee`. Spans, local ids, operators and literal values are copied
+    /// unchanged.
+    fn mono_expr(
+        &mut self,
+        module: &str,
+        expr: &HirExpr,
+        subs: &HashMap<String, Ty>,
+    ) -> Result<HirExpr, TypeError> {
+        match expr {
+            HirExpr::Literal(lit) => Ok(HirExpr::Literal(HirLiteral {
+                value: lit.value.clone(),
+                ty: self.mono_hir_type(module, &lit.ty, subs)?,
+                span: lit.span,
+            })),
+            HirExpr::Local(local) => Ok(HirExpr::Local(HirLocal {
+                local_id: local.local_id,
+                ty: self.mono_hir_type(module, &local.ty, subs)?,
+                span: local.span,
+            })),
+            HirExpr::EnumVariant(variant) => {
+                let enum_ty = self.mono_hir_type(module, &variant.enum_ty, subs)?;
+                let payload = variant
+                    .payload
+                    .as_ref()
+                    .map(|expr| self.mono_expr(module, expr, subs))
+                    .transpose()?
+                    .map(Box::new);
+                Ok(HirExpr::EnumVariant(HirEnumVariantExpr {
+                    enum_ty,
+                    variant_name: variant.variant_name.clone(),
+                    payload,
+                    span: variant.span,
+                }))
+            }
+            HirExpr::Call(call) => {
+                // Arguments and the return type are written in the caller, so
+                // they are rewritten relative to the caller's `module`.
+                let args = call
+                    .args
+                    .iter()
+                    .map(|arg| self.mono_expr(module, arg, subs))
+                    .collect::<Result<Vec<_>, TypeError>>()?;
+                let ret_ty = self.mono_hir_type(module, &call.ret_ty, subs)?;
+                let (callee, type_args) = match &call.callee {
+                    // Intrinsics are never specialized.
+                    ResolvedCallee::Intrinsic(id) => (ResolvedCallee::Intrinsic(*id), Vec::new()),
+                    ResolvedCallee::Function {
+                        module: callee_module,
+                        name,
+                        ..
+                    } => {
+                        // From here on the callee's own module is used: it names
+                        // the lookup key, the symbol and the queued instance.
+                        let key = qualify(callee_module, name);
+                        let resolved = if let Some(func) = self.functions.get(&key).cloned() {
+                            self.mono_callee(callee_module, &func, &args, &call.type_args, subs)?
+                        } else if let Some(func) = self.externs.get(&key).cloned() {
+                            self.mono_callee(callee_module, &func, &args, &call.type_args, subs)?
+                        } else {
+                            // Point the diagnostic at the offending call.
+                            return Err(TypeError::new(
+                                format!("unknown function `{}`", key),
+                                call.span,
+                            ));
+                        };
+                        let (new_name, symbol, type_args) = resolved;
+                        let callee = ResolvedCallee::Function {
+                            module: callee_module.clone(),
+                            name: new_name,
+                            symbol,
+                        };
+                        (callee, type_args)
+                    }
+                };
+                Ok(HirExpr::Call(HirCall {
+                    callee,
+                    type_args,
+                    args,
+                    ret_ty,
+                    span: call.span,
+                }))
+            }
+            HirExpr::FieldAccess(access) => Ok(HirExpr::FieldAccess(HirFieldAccess {
+                object: Box::new(self.mono_expr(module, &access.object, subs)?),
+                field_name: access.field_name.clone(),
+                field_ty: self.mono_hir_type(module, &access.field_ty, subs)?,
+                span: access.span,
+            })),
+            HirExpr::StructLiteral(literal) => {
+                let struct_ty = self.mono_hir_type(module, &literal.struct_ty, subs)?;
+                let fields: Result<Vec<HirStructLiteralField>, TypeError> = literal
+                    .fields
+                    .iter()
+                    .map(|field| {
+                        Ok(HirStructLiteralField {
+                            name: field.name.clone(),
+                            expr: self.mono_expr(module, &field.expr, subs)?,
+                        })
+                    })
+                    .collect();
+                Ok(HirExpr::StructLiteral(HirStructLiteral {
+                    struct_ty,
+                    fields: fields?,
+                    span: literal.span,
+                }))
+            }
+            HirExpr::Unary(unary) => Ok(HirExpr::Unary(HirUnary {
+                op: unary.op.clone(),
+                expr: Box::new(self.mono_expr(module, &unary.expr, subs)?),
+                ty: self.mono_hir_type(module, &unary.ty, subs)?,
+                span: unary.span,
+            })),
+            HirExpr::Binary(binary) => Ok(HirExpr::Binary(HirBinary {
+                op: binary.op.clone(),
+                left: Box::new(self.mono_expr(module, &binary.left, subs)?),
+                right: Box::new(self.mono_expr(module, &binary.right, subs)?),
+                ty: self.mono_hir_type(module, &binary.ty, subs)?,
+                span: binary.span,
+            })),
+            HirExpr::Match(m) => {
+                let expr = self.mono_expr(module, &m.expr, subs)?;
+                let arms: Result<Vec<HirMatchArm>, TypeError> = m
+                    .arms
+                    .iter()
+                    .map(|arm| {
+                        Ok(HirMatchArm {
+                            // NOTE(review): patterns are cloned as-is; confirm they
+                            // never embed a type that would need rewriting.
+                            pattern: arm.pattern.clone(),
+                            body: self.mono_block(module, &arm.body, subs)?,
+                        })
+                    })
+                    .collect();
+                Ok(HirExpr::Match(HirMatch {
+                    expr: Box::new(expr),
+                    arms: arms?,
+                    result_ty: self.mono_hir_type(module, &m.result_ty, subs)?,
+                    span: m.span,
+                }))
+            }
+            HirExpr::Try(t) => {
+                let expr = self.mono_expr(module, &t.expr, subs)?;
+                Ok(HirExpr::Try(HirTry {
+                    expr: Box::new(expr),
+                    ok_ty: self.mono_hir_type(module, &t.ok_ty, subs)?,
+                    ret_ty: self.mono_hir_type(module, &t.ret_ty, subs)?,
+                    span: t.span,
+                }))
+            }
+        }
+    }
+
+    /// Picks the name and linker symbol a call should target after
+    /// monomorphization.
+    ///
+    /// `module` is used to build the symbol, to resolve the explicit type
+    /// arguments, and as the module of the queued instance. `args` are the
+    /// already-monomorphized call arguments.
+    ///
+    /// Returns `(name, symbol, type_args)`. `type_args` is always empty: the
+    /// chosen instantiation is encoded in the mangled `name`. For a generic
+    /// callee the instantiation is also pushed onto `self.queue`.
+    ///
+    /// Errors when a non-generic callee is given type arguments, when the number
+    /// of explicit type arguments differs from the number of type parameters, or
+    /// when a type parameter cannot be determined.
+    fn mono_callee(
+        &mut self,
+        module: &str,
+        func: &impl GenericSig,
+        args: &[HirExpr],
+        explicit_type_args: &[Ty],
+        subs: &HashMap<String, Ty>,
+    ) -> Result<(String, String, Vec<Ty>), TypeError> {
+        let type_params = func.type_params();
+        if type_params.is_empty() {
+            if !explicit_type_args.is_empty() {
+                return Err(TypeError::new(
+                    "function does not accept type arguments".to_string(),
+                    DUMMY_SPAN,
+                ));
+            }
+            let name = func.name().to_string();
+            let symbol = function_symbol(module, &name);
+            return Ok((name, symbol, Vec::new()));
+        }
+
+        let mut inferred: HashMap<String, Ty> = HashMap::new();
+        if explicit_type_args.is_empty() {
+            // NOTE(review): `args` already carry rewritten types, so a parameter
+            // declared as a generic struct/enum (e.g. `Box[T]`) cannot match its
+            // mangled argument type (e.g. `Box__i32`). Such calls need explicit
+            // type arguments until inference runs on pre-rewrite types.
+            for (param, arg) in func.params().iter().zip(args.iter()) {
+                match_type_params(&param.ty.ty, &arg.ty().ty, &mut inferred, DUMMY_SPAN)?;
+            }
+        } else {
+            // `zip` would silently drop surplus arguments, so check arity first.
+            if explicit_type_args.len() != type_params.len() {
+                return Err(TypeError::new(
+                    format!(
+                        "expected {} type argument(s), found {}",
+                        type_params.len(),
+                        explicit_type_args.len()
+                    ),
+                    DUMMY_SPAN,
+                ));
+            }
+            // The explicit list fixes every parameter. The argument types are
+            // deliberately not re-matched here: they are already mangled and
+            // would spuriously mismatch generic parameter types.
+            for (name, arg) in type_params.iter().zip(explicit_type_args) {
+                let arg = self.mono_ty(module, arg, subs)?;
+                inferred.insert(name.clone(), arg);
+            }
+        }
+
+        // Order the bindings like the declaration; a missing one is an error.
+        let ordered_args = type_params
+            .iter()
+            .map(|name| {
+                inferred.get(name).cloned().ok_or_else(|| {
+                    TypeError::new(format!("missing type argument for `{name}`"), DUMMY_SPAN)
+                })
+            })
+            .collect::<Result<Vec<_>, TypeError>>()?;
+        let new_name = mangle_name(func.name(), &ordered_args);
+        let symbol = function_symbol(module, &new_name);
+        self.queue.push(FunctionInstance {
+            module: module.to_string(),
+            base_name: func.name().to_string(),
+            type_args: ordered_args,
+        });
+        Ok((new_name, symbol, Vec::new()))
+    }
+
+    /// Monomorphizes the `Ty` inside `ty` and recomputes the ABI classification
+    /// for the rewritten type.
+    fn mono_hir_type(
+        &mut self,
+        module: &str,
+        ty: &HirType,
+        subs: &HashMap<String, Ty>,
+    ) -> Result<HirType, TypeError> {
+        let mono_ty = self.mono_ty(module, &ty.ty, subs)?;
+        let abi = self.abi_type_for(module, &mono_ty)?;
+        Ok(HirType { ty: mono_ty, abi })
+    }
+
+    /// Substitutes type parameters in `ty` using `subs` and replaces every use of
+    /// a generic struct or enum with the name of its specialized instance,
+    /// creating that instance on first use.
+    ///
+    /// `Result` keeps its type arguments; any other path that names a generic
+    /// struct or enum comes back with an empty argument list. A `Ty::Param`
+    /// missing from `subs` is an error.
+    fn mono_ty(
+        &mut self,
+        module: &str,
+        ty: &Ty,
+        subs: &HashMap<String, Ty>,
+    ) -> Result<Ty, TypeError> {
+        match ty {
+            Ty::Param(name) => subs
+                .get(name)
+                .cloned()
+                .ok_or_else(|| TypeError::new(format!("unbound type parameter `{name}`"), DUMMY_SPAN)),
+            Ty::Builtin(_) => Ok(ty.clone()),
+            Ty::Ptr(inner) => Ok(Ty::Ptr(Box::new(self.mono_ty(module, inner, subs)?))),
+            Ty::Ref(inner) => Ok(Ty::Ref(Box::new(self.mono_ty(module, inner, subs)?))),
+            Ty::Path(name, args) => {
+                if name == "Result" {
+                    let args = args
+                        .iter()
+                        .map(|arg| self.mono_ty(module, arg, subs))
+                        .collect::<Result<Vec<_>, _>>()?;
+                    return Ok(Ty::Path(name.clone(), args));
+                }
+                let (type_module, base_name, qualified) = split_name(module, name);
+                let args = args
+                    .iter()
+                    .map(|arg| self.mono_ty(module, arg, subs))
+                    .collect::<Result<Vec<_>, _>>()?;
+                let qualified_key = qualify(&type_module, &base_name);
+                if let Some(struct_def) = self.structs.get(&qualified_key).cloned() {
+                    if !struct_def.type_params.is_empty() {
+                        let new_name = self.ensure_struct_instance(&type_module, &struct_def, &args)?;
+                        let name = if qualified {
+                            qualify(&type_module, &new_name)
+                        } else {
+                            new_name
+                        };
+                        return Ok(Ty::Path(name, Vec::new()));
+                    }
+                }
+                if let Some(enum_def) = self.enums.get(&qualified_key).cloned() {
+                    if !enum_def.type_params.is_empty() {
+                        let new_name = 
self.ensure_enum_instance(&type_module, &enum_def, &args)?;
+                        let name = if qualified {
+                            qualify(&type_module, &new_name)
+                        } else {
+                            new_name
+                        };
+                        return Ok(Ty::Path(name, Vec::new()));
+                    }
+                }
+                Ok(Ty::Path(name.clone(), args))
+            }
+        }
+    }
+
+    /// Returns the name of the specialization of struct `decl` for `args`,
+    /// generating and registering it on first use.
+    ///
+    /// A non-generic struct is returned under its own name. The returned name is
+    /// unqualified; `module` is only used for the "already generated" lookup and
+    /// for registration. Errors when `args` does not have one entry per type
+    /// parameter.
+    ///
+    /// NOTE(review): the generated-set is only read here. If `push_struct` is
+    /// what records the name, a struct whose fields mention the same
+    /// instantiation would re-enter this function without bound; confirm.
+    fn ensure_struct_instance(
+        &mut self,
+        module: &str,
+        decl: &HirStruct,
+        args: &[Ty],
+    ) -> Result<String, TypeError> {
+        if decl.type_params.is_empty() {
+            return Ok(decl.name.clone());
+        }
+        if decl.type_params.len() != args.len() {
+            return Err(TypeError::new(
+                format!(
+                    "type `{}` expects {} type argument(s), found {}",
+                    decl.name,
+                    decl.type_params.len(),
+                    args.len()
+                ),
+                DUMMY_SPAN,
+            ));
+        }
+        let name = mangle_name(&decl.name, args);
+        let qualified = qualify(module, &name);
+        if self.generated_structs.contains(&qualified) {
+            return Ok(name);
+        }
+        let subs = build_substitution(&decl.type_params, args, DUMMY_SPAN)?;
+        let mono = self.mono_struct(module, decl, &subs, name.clone())?;
+        self.push_struct(module, mono);
+        Ok(name)
+    }
+
+    /// Returns the name of the specialization of enum `decl` for `args`,
+    /// generating and registering it on first use.
+    ///
+    /// Mirrors `ensure_struct_instance`: a non-generic enum keeps its own name,
+    /// the returned name is unqualified, and an arity mismatch is an error.
+    ///
+    /// NOTE(review): `mono_enum` runs before `push_enum`, so an enum whose
+    /// payload mentions the same instantiation (e.g. behind a pointer) would
+    /// re-enter this function before the generated-set can contain it; confirm
+    /// whether recursive generic enums are meant to be supported.
+    fn ensure_enum_instance(
+        &mut self,
+        module: &str,
+        decl: &HirEnum,
+        args: &[Ty],
+    ) -> Result<String, TypeError> {
+        if decl.type_params.is_empty() {
+            return Ok(decl.name.clone());
+        }
+        if decl.type_params.len() != args.len() {
+            return Err(TypeError::new(
+                format!(
+                    "type `{}` expects {} type argument(s), found {}",
+                    decl.name,
+                    decl.type_params.len(),
+                    args.len()
+                ),
+                DUMMY_SPAN,
+            ));
+        }
+        let name = mangle_name(&decl.name, args);
+        let qualified = qualify(module, &name);
+        if self.generated_enums.contains(&qualified) {
+            return Ok(name);
+        }
+        let subs = build_substitution(&decl.type_params, args, DUMMY_SPAN)?;
+        let mono = self.mono_enum(module, decl, &subs, name.clone())?;
+        self.push_enum(module, mono);
+        Ok(name)
+    }
+
+    /// Classifies a monomorphic type for the backend.
+    ///
+    /// Builtins map one-to-one except `i64`, which is rejected. A pointer is
+    /// `Ptr` and a reference takes the ABI type of its referent. `Result` with
+    /// exactly two arguments becomes `AbiType::Result`. A known struct is
+    /// `Handle` when opaque and `Ptr` otherwise, and a known enum is `I32`.
+    /// A leftover `Ty::Param` or an unknown type name is an error.
+    fn abi_type_for(&self, module: &str, ty: &Ty) -> Result<AbiType, TypeError> {
+        use crate::typeck::BuiltinType;
+        match ty {
+            Ty::Builtin(b) => match b {
+                BuiltinType::I32 => Ok(AbiType::I32),
+                BuiltinType::I64 => Err(TypeError::new(
+                    "i64 is not supported by the current codegen 
backend".to_string(),
+                    DUMMY_SPAN,
+                )),
+                BuiltinType::U32 => Ok(AbiType::U32),
+                BuiltinType::U8 => Ok(AbiType::U8),
+                BuiltinType::Bool => Ok(AbiType::Bool),
+                BuiltinType::String => Ok(AbiType::String),
+                BuiltinType::Unit => Ok(AbiType::Unit),
+            },
+            Ty::Ptr(_) => Ok(AbiType::Ptr),
+            Ty::Ref(inner) => self.abi_type_for(module, inner),
+            Ty::Param(_) => Err(TypeError::new(
+                "generic type parameters must be monomorphized before codegen".to_string(),
+                DUMMY_SPAN,
+            )),
+            Ty::Path(name, args) => {
+                if name == "Result" && args.len() == 2 {
+                    let ok = self.abi_type_for(module, &args[0])?;
+                    let err = self.abi_type_for(module, &args[1])?;
+                    return Ok(AbiType::Result(Box::new(ok), Box::new(err)));
+                }
+                let (type_module, base_name, _qualified) = split_name(module, name);
+                let qualified = qualify(&type_module, &base_name);
+                if let Some(info) = self.structs.get(&qualified) {
+                    return Ok(if info.is_opaque {
+                        AbiType::Handle
+                    } else {
+                        AbiType::Ptr
+                    });
+                }
+                // Only the enum's existence matters for its ABI type.
+                if self.enums.contains_key(&qualified) {
+                    return Ok(AbiType::I32);
+                }
+                Err(TypeError::new(
+                    format!("unknown type `{}`", name),
+                    DUMMY_SPAN,
+                ))
+            }
+        }
+    }
+}
+
+/// Moves the collected contents of a `ModuleOut` into a `HirModule`, field by
+/// field.
+impl From<ModuleOut> for HirModule {
+    fn from(module: ModuleOut) -> Self {
+        Self {
+            name: module.name,
+            functions: module.functions,
+            extern_functions: module.extern_functions,
+            structs: module.structs,
+            enums: module.enums,
+        }
+    }
+}
+
+/// The part of a function signature `mono_callee` needs, so regular and extern
+/// functions can be specialized through one code path.
+trait GenericSig {
+    /// Declared (unmangled) function name.
+    fn name(&self) -> &str;
+    /// Names of the declared type parameters, in declaration order.
+    fn type_params(&self) -> &Vec<String>;
+    /// Declared value parameters, in declaration order.
+    // NOTE(review): the element type was lost in this copy of the patch and is
+    // reconstructed as `HirParam`; confirm against `hir.rs`.
+    fn params(&self) -> &Vec<HirParam>;
+}
+
+impl GenericSig for HirFunction {
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn type_params(&self) -> &Vec<String> {
+        &self.type_params
+    }
+
+    fn params(&self) -> &Vec<HirParam> {
+        &self.params
+    }
+}
+
+impl GenericSig for HirExternFunction {
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn type_params(&self) -> &Vec<String> {
+        &self.type_params
+    }
+
+    fn params(&self) -> &Vec<HirParam> {
+        &self.params
+    }
+}
+
+/// Splits a possibly module-qualified type name at its last `.`.
+///
+/// Returns `(module, base_name, was_qualified)`; a name without a `.` is
+/// attributed to the `module` argument.
+fn split_name(module: &str, name: &str) -> (String, String, bool) {
+    if let 
Some((mod_part, type_part)) = name.rsplit_once('.') {
+        (mod_part.to_string(), type_part.to_string(), true)
+    } else {
+        (module.to_string(), name.to_string(), false)
+    }
+}
+
+/// Joins a module path and an item name into the `module.name` lookup key.
+fn qualify(module: &str, name: &str) -> String {
+    format!("{module}.{name}")
+}
+
+/// Builds the linker symbol for a function: `capable_` followed by the
+/// qualified name with every `.` replaced by `_`.
+fn function_symbol(module: &str, name: &str) -> String {
+    format!("capable_{}", qualify(module, name).replace('.', "_"))
+}
+
+/// Pairs each type parameter name with the argument at the same position.
+///
+/// Errors, reported at `span`, when the two lists differ in length.
+fn build_substitution(
+    params: &[String],
+    args: &[Ty],
+    span: Span,
+) -> Result<HashMap<String, Ty>, TypeError> {
+    if params.len() != args.len() {
+        return Err(TypeError::new(
+            format!(
+                "expected {} type argument(s), found {}",
+                params.len(),
+                args.len()
+            ),
+            span,
+        ));
+    }
+    Ok(params.iter().cloned().zip(args.iter().cloned()).collect())
+}
+
+/// Structurally matches a declared type against an actual type and records the
+/// binding of every type parameter met along the way in `subs`.
+///
+/// A parameter that is already bound must be bound to an equal type. Builtins
+/// must be equal; pointers and references must wrap matching types; paths must
+/// agree in name and argument count, and their arguments are matched pairwise.
+/// Any other combination is a type mismatch reported at `span`.
+fn match_type_params(
+    expected: &Ty,
+    actual: &Ty,
+    subs: &mut HashMap<String, Ty>,
+    span: Span,
+) -> Result<(), TypeError> {
+    // Every structural disagreement produces the same diagnostic.
+    let mismatch = || {
+        TypeError::new(
+            format!("type mismatch: expected {expected:?}, found {actual:?}"),
+            span,
+        )
+    };
+    match (expected, actual) {
+        (Ty::Param(name), _) => {
+            if let Some(existing) = subs.get(name) {
+                if existing != actual {
+                    return Err(TypeError::new(
+                        format!(
+                            "conflicting type arguments for `{}`: {existing:?} vs {actual:?}",
+                            name
+                        ),
+                        span,
+                    ));
+                }
+            } else {
+                subs.insert(name.clone(), actual.clone());
+            }
+            Ok(())
+        }
+        (Ty::Builtin(_), _) => {
+            if expected == actual {
+                Ok(())
+            } else {
+                Err(mismatch())
+            }
+        }
+        (Ty::Ptr(inner), Ty::Ptr(actual_inner)) | (Ty::Ref(inner), Ty::Ref(actual_inner)) => {
+            match_type_params(inner, actual_inner, subs, span)
+        }
+        (Ty::Path(name, args), Ty::Path(actual_name, actual_args))
+            if name == actual_name && args.len() == actual_args.len() =>
+        {
+            for (arg, actual_arg) in args.iter().zip(actual_args.iter()) {
+                match_type_params(arg, actual_arg, subs, span)?;
+            }
+            Ok(())
+        }
+        // Pointer, reference or path against anything of a different shape.
+        _ => Err(mismatch()),
+    }
+}
+
+/// Builds the name of a specialization: `base` followed by `__` and the mangled
+/// type arguments joined with `__`. With no arguments `base` is returned as is.
+fn mangle_name(base: &str, args: &[Ty]) -> String {
+    if args.is_empty() {
+        return base.to_string();
+    }
+    let suffix = args
+        .iter()
+        .map(mangle_type)
+        .collect::<Vec<_>>()
+        .join("__");
+    format!("{base}__{suffix}")
+}
+
+/// Renders a type as an identifier fragment for `mangle_name`.
+///
+/// Builtins use their lowercase names, pointers and references get a `ptr_` /
+/// `ref_` prefix, and paths have `.` replaced by `_` and are then mangled like
+/// a specialization of that name.
+fn mangle_type(ty: &Ty) -> String {
+    use crate::typeck::BuiltinType;
+    match ty {
+        Ty::Builtin(b) => {
+            let name = match b {
+                BuiltinType::I32 => "i32",
+                BuiltinType::I64 => "i64",
+                BuiltinType::U32 => "u32",
+                BuiltinType::U8 => "u8",
+                BuiltinType::Bool => "bool",
+                BuiltinType::String => "string",
+                BuiltinType::Unit => "unit",
+            };
+            name.to_string()
+        }
+        Ty::Ptr(inner) => format!("ptr_{}", mangle_type(inner)),
+        Ty::Ref(inner) => format!("ref_{}", mangle_type(inner)),
+        Ty::Param(name) => format!("param_{name}"),
+        // Same `base__arg__arg` scheme as a top-level specialization.
+        Ty::Path(name, args) => mangle_name(&name.replace('.', "_"), args),
+    }
+}
diff --git a/capc/tests/parser.rs b/capc/tests/parser.rs
index e8fdda1..48d632b 100644
--- a/capc/tests/parser.rs
+++ b/capc/tests/parser.rs
@@ -37,3 +37,9 @@ fn snapshot_doc_comments() {
     insta::assert_debug_snapshot!(module);
 }
 
+#[test]
+fn snapshot_generics_basic() {
+    let source = load_program("generics_basic.cap");
+    let module = parse_module(&source).expect("parse module");
+    insta::assert_debug_snapshot!(module);
+}
diff --git a/capc/tests/run.rs b/capc/tests/run.rs
index 
3dd4222..94cb3be 100644 --- a/capc/tests/run.rs +++ b/capc/tests/run.rs @@ -487,6 +487,20 @@ fn run_net_helpers() { assert!(stdout.contains("net err ok"), "stdout was: {stdout:?}"); } +#[test] +fn run_generics_basic() { + let out_dir = make_out_dir("generics_basic"); + let out_dir = out_dir.to_str().expect("utf8 out dir"); + let (code, stdout, _stderr) = run_capc(&[ + "run", + "--out-dir", + out_dir, + "tests/programs/generics_basic.cap", + ]); + assert_eq!(code, 0); + assert!(stdout.contains("42"), "stdout was: {stdout:?}"); +} + #[test] fn run_string_split() { let out_dir = make_out_dir("string_split"); diff --git a/capc/tests/snapshots/parser__snapshot_basic_module.snap b/capc/tests/snapshots/parser__snapshot_basic_module.snap index 552215a..186dda6 100644 --- a/capc/tests/snapshots/parser__snapshot_basic_module.snap +++ b/capc/tests/snapshots/parser__snapshot_basic_module.snap @@ -59,6 +59,7 @@ Module { end: 41, }, }, + type_params: [], params: [ Param { name: Spanned { @@ -154,6 +155,7 @@ Module { end: 89, }, }, + type_args: [], args: [], span: Span { start: 74, @@ -195,6 +197,7 @@ Module { end: 103, }, }, + type_args: [], args: [ Literal( LiteralExpr { diff --git a/capc/tests/snapshots/parser__snapshot_doc_comments.snap b/capc/tests/snapshots/parser__snapshot_doc_comments.snap index 23c9b5e..89b40b6 100644 --- a/capc/tests/snapshots/parser__snapshot_doc_comments.snap +++ b/capc/tests/snapshots/parser__snapshot_doc_comments.snap @@ -37,6 +37,7 @@ Module { end: 100, }, }, + type_params: [], fields: [ Field { name: Spanned { @@ -91,6 +92,7 @@ Module { end: 182, }, }, + type_params: [], variants: [ EnumVariant { name: Spanned { @@ -138,6 +140,7 @@ Module { end: 240, }, }, + type_params: [], params: [ Param { name: Spanned { @@ -247,6 +250,7 @@ Module { end: 326, }, }, + type_params: [], params: [], ret: Path { path: Path { diff --git a/capc/tests/snapshots/parser__snapshot_generics_basic.snap b/capc/tests/snapshots/parser__snapshot_generics_basic.snap new file 
mode 100644 index 0000000..8a8c139 --- /dev/null +++ b/capc/tests/snapshots/parser__snapshot_generics_basic.snap @@ -0,0 +1,669 @@ +--- +source: capc/tests/parser.rs +expression: module +--- +Module { + package: Safe, + name: Path { + segments: [ + Spanned { + item: "generics_basic", + span: Span { + start: 20, + end: 34, + }, + }, + ], + span: Span { + start: 20, + end: 34, + }, + }, + uses: [ + UseDecl { + path: Path { + segments: [ + Spanned { + item: "sys", + span: Span { + start: 39, + end: 42, + }, + }, + Spanned { + item: "console", + span: Span { + start: 44, + end: 51, + }, + }, + ], + span: Span { + start: 39, + end: 51, + }, + }, + span: Span { + start: 35, + end: 51, + }, + }, + UseDecl { + path: Path { + segments: [ + Spanned { + item: "sys", + span: Span { + start: 56, + end: 59, + }, + }, + Spanned { + item: "system", + span: Span { + start: 61, + end: 67, + }, + }, + ], + span: Span { + start: 56, + end: 67, + }, + }, + span: Span { + start: 52, + end: 67, + }, + }, + ], + items: [ + Struct( + StructDecl { + name: Spanned { + item: "Box", + span: Span { + start: 76, + end: 79, + }, + }, + type_params: [ + Spanned { + item: "T", + span: Span { + start: 80, + end: 81, + }, + }, + ], + fields: [ + Field { + name: Spanned { + item: "value", + span: Span { + start: 87, + end: 92, + }, + }, + ty: Path { + path: Path { + segments: [ + Spanned { + item: "T", + span: Span { + start: 94, + end: 95, + }, + }, + ], + span: Span { + start: 94, + end: 95, + }, + }, + args: [], + span: Span { + start: 94, + end: 95, + }, + }, + }, + ], + is_pub: false, + is_opaque: false, + is_linear: false, + is_copy: false, + is_capability: false, + doc: None, + span: Span { + start: 69, + end: 98, + }, + }, + ), + Function( + Function { + name: Spanned { + item: "id", + span: Span { + start: 103, + end: 105, + }, + }, + type_params: [ + Spanned { + item: "T", + span: Span { + start: 106, + end: 107, + }, + }, + ], + params: [ + Param { + name: Spanned { + item: "value", + span: 
Span { + start: 109, + end: 114, + }, + }, + ty: Some( + Path { + path: Path { + segments: [ + Spanned { + item: "T", + span: Span { + start: 116, + end: 117, + }, + }, + ], + span: Span { + start: 116, + end: 117, + }, + }, + args: [], + span: Span { + start: 116, + end: 117, + }, + }, + ), + }, + ], + ret: Path { + path: Path { + segments: [ + Spanned { + item: "T", + span: Span { + start: 122, + end: 123, + }, + }, + ], + span: Span { + start: 122, + end: 123, + }, + }, + args: [], + span: Span { + start: 122, + end: 123, + }, + }, + body: Block { + stmts: [ + Return( + ReturnStmt { + expr: Some( + Path( + Path { + segments: [ + Spanned { + item: "value", + span: Span { + start: 135, + end: 140, + }, + }, + ], + span: Span { + start: 135, + end: 140, + }, + }, + ), + ), + span: Span { + start: 128, + end: 142, + }, + }, + ), + ], + span: Span { + start: 124, + end: 142, + }, + }, + is_pub: false, + doc: None, + span: Span { + start: 100, + end: 142, + }, + }, + ), + Function( + Function { + name: Spanned { + item: "main", + span: Span { + start: 151, + end: 155, + }, + }, + type_params: [], + params: [ + Param { + name: Spanned { + item: "rc", + span: Span { + start: 156, + end: 158, + }, + }, + ty: Some( + Path { + path: Path { + segments: [ + Spanned { + item: "RootCap", + span: Span { + start: 160, + end: 167, + }, + }, + ], + span: Span { + start: 160, + end: 167, + }, + }, + args: [], + span: Span { + start: 160, + end: 167, + }, + }, + ), + }, + ], + ret: Path { + path: Path { + segments: [ + Spanned { + item: "i32", + span: Span { + start: 172, + end: 175, + }, + }, + ], + span: Span { + start: 172, + end: 175, + }, + }, + args: [], + span: Span { + start: 172, + end: 175, + }, + }, + body: Block { + stmts: [ + Let( + LetStmt { + name: Spanned { + item: "c", + span: Span { + start: 184, + end: 185, + }, + }, + ty: None, + expr: MethodCall( + MethodCallExpr { + receiver: Path( + Path { + segments: [ + Spanned { + item: "rc", + span: Span { + start: 188, + 
end: 190, + }, + }, + ], + span: Span { + start: 188, + end: 190, + }, + }, + ), + method: Spanned { + item: "mint_console", + span: Span { + start: 191, + end: 203, + }, + }, + type_args: [], + args: [], + span: Span { + start: 188, + end: 205, + }, + }, + ), + span: Span { + start: 180, + end: 205, + }, + }, + ), + Let( + LetStmt { + name: Spanned { + item: "b", + span: Span { + start: 212, + end: 213, + }, + }, + ty: None, + expr: StructLiteral( + StructLiteralExpr { + path: Path { + segments: [ + Spanned { + item: "Box", + span: Span { + start: 216, + end: 219, + }, + }, + ], + span: Span { + start: 216, + end: 219, + }, + }, + type_args: [ + Path { + path: Path { + segments: [ + Spanned { + item: "i32", + span: Span { + start: 220, + end: 223, + }, + }, + ], + span: Span { + start: 220, + end: 223, + }, + }, + args: [], + span: Span { + start: 220, + end: 223, + }, + }, + ], + fields: [ + StructLiteralField { + name: Spanned { + item: "value", + span: Span { + start: 226, + end: 231, + }, + }, + expr: Literal( + LiteralExpr { + value: Int( + 42, + ), + span: Span { + start: 233, + end: 235, + }, + }, + ), + span: Span { + start: 216, + end: 235, + }, + }, + ], + span: Span { + start: 216, + end: 237, + }, + }, + ), + span: Span { + start: 208, + end: 237, + }, + }, + ), + Let( + LetStmt { + name: Spanned { + item: "v", + span: Span { + start: 244, + end: 245, + }, + }, + ty: None, + expr: Call( + CallExpr { + callee: Path( + Path { + segments: [ + Spanned { + item: "id", + span: Span { + start: 248, + end: 250, + }, + }, + ], + span: Span { + start: 248, + end: 250, + }, + }, + ), + type_args: [ + Path { + path: Path { + segments: [ + Spanned { + item: "i32", + span: Span { + start: 251, + end: 254, + }, + }, + ], + span: Span { + start: 251, + end: 254, + }, + }, + args: [], + span: Span { + start: 251, + end: 254, + }, + }, + ], + args: [ + FieldAccess( + FieldAccessExpr { + object: Path( + Path { + segments: [ + Spanned { + item: "b", + span: Span { + 
start: 256, + end: 257, + }, + }, + ], + span: Span { + start: 256, + end: 257, + }, + }, + ), + field: Spanned { + item: "value", + span: Span { + start: 258, + end: 263, + }, + }, + span: Span { + start: 256, + end: 263, + }, + }, + ), + ], + span: Span { + start: 248, + end: 264, + }, + }, + ), + span: Span { + start: 240, + end: 264, + }, + }, + ), + Expr( + ExprStmt { + expr: MethodCall( + MethodCallExpr { + receiver: Path( + Path { + segments: [ + Spanned { + item: "c", + span: Span { + start: 267, + end: 268, + }, + }, + ], + span: Span { + start: 267, + end: 268, + }, + }, + ), + method: Spanned { + item: "print_i32", + span: Span { + start: 269, + end: 278, + }, + }, + type_args: [], + args: [ + Path( + Path { + segments: [ + Spanned { + item: "v", + span: Span { + start: 279, + end: 280, + }, + }, + ], + span: Span { + start: 279, + end: 280, + }, + }, + ), + ], + span: Span { + start: 267, + end: 281, + }, + }, + ), + span: Span { + start: 267, + end: 281, + }, + }, + ), + Return( + ReturnStmt { + expr: Some( + Literal( + LiteralExpr { + value: Int( + 0, + ), + span: Span { + start: 291, + end: 292, + }, + }, + ), + ), + span: Span { + start: 284, + end: 294, + }, + }, + ), + ], + span: Span { + start: 176, + end: 294, + }, + }, + is_pub: true, + doc: None, + span: Span { + start: 148, + end: 294, + }, + }, + ), + ], + span: Span { + start: 0, + end: 294, + }, +} diff --git a/capc/tests/snapshots/parser__snapshot_struct_and_match.snap b/capc/tests/snapshots/parser__snapshot_struct_and_match.snap index 80c5ec1..9b8b4c0 100644 --- a/capc/tests/snapshots/parser__snapshot_struct_and_match.snap +++ b/capc/tests/snapshots/parser__snapshot_struct_and_match.snap @@ -59,6 +59,7 @@ Module { end: 43, }, }, + type_params: [], params: [ Param { name: Spanned { @@ -154,6 +155,7 @@ Module { end: 91, }, }, + type_args: [], args: [], span: Span { start: 76, @@ -203,6 +205,7 @@ Module { end: 120, }, }, + type_args: [], args: [ Literal( LiteralExpr { @@ -258,6 +261,7 @@ 
Module { end: 160, }, }, + type_args: [], args: [ Literal( LiteralExpr { @@ -339,6 +343,7 @@ Module { end: 198, }, }, + type_args: [], args: [ Path( Path { @@ -463,6 +468,7 @@ Module { end: 239, }, }, + type_args: [], args: [ Literal( LiteralExpr { diff --git a/capc/tests/snapshots/parser__snapshot_struct_literal.snap b/capc/tests/snapshots/parser__snapshot_struct_literal.snap index e44682f..a721af5 100644 --- a/capc/tests/snapshots/parser__snapshot_struct_literal.snap +++ b/capc/tests/snapshots/parser__snapshot_struct_literal.snap @@ -30,6 +30,7 @@ Module { end: 34, }, }, + type_params: [], fields: [ Field { name: Spanned { @@ -115,6 +116,7 @@ Module { end: 69, }, }, + type_params: [], params: [], ret: Path { path: Path { @@ -160,6 +162,7 @@ Module { end: 95, }, }, + type_args: [], fields: [ StructLiteralField { name: Spanned { diff --git a/stdlib/sys/buffer.cap b/stdlib/sys/buffer.cap index 3e07b6c..469a549 100644 --- a/stdlib/sys/buffer.cap +++ b/stdlib/sys/buffer.cap @@ -4,8 +4,8 @@ use sys::vec pub copy opaque struct Alloc pub copy opaque struct Buffer -pub copy opaque struct Slice -pub copy opaque struct MutSlice +pub copy opaque struct Slice[T] +pub copy opaque struct MutSlice[T] pub enum AllocErr { Oom diff --git a/tests/programs/generics_basic.cap b/tests/programs/generics_basic.cap new file mode 100644 index 0000000..56ba111 --- /dev/null +++ b/tests/programs/generics_basic.cap @@ -0,0 +1,20 @@ +package safe +module generics_basic +use sys::console +use sys::system + +struct Box[T] { + value: T, +} + +fn id[T](value: T) -> T { + return value +} + +pub fn main(rc: RootCap) -> i32 { + let c = rc.mint_console() + let b = Box[i32]{ value: 42 } + let v = id[i32](b.value) + c.print_i32(v) + return 0 +} From 6e00c6feacedd15fa9433aaf22cad31c459af8e3 Mon Sep 17 00:00:00 2001 From: Jordan Mecom Date: Sun, 28 Dec 2025 18:24:21 -0800 Subject: [PATCH 2/2] Document generics implementation --- docs/generics.md | 94 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file 
changed, 94 insertions(+) create mode 100644 docs/generics.md diff --git a/docs/generics.md b/docs/generics.md new file mode 100644 index 0000000..f094ec9 --- /dev/null +++ b/docs/generics.md @@ -0,0 +1,94 @@ +# Generics + +Capable supports explicit generics for structs, enums, and functions. Generics are +fully monomorphized during compilation: each concrete instantiation produces a +specialized definition in HIR before codegen. + +## Syntax + +Type parameters are declared with `[...]` after the name: + +```cap +struct Box[T] { + value: T, +} + +fn id[T](value: T) -> T { + return value +} +``` + +Type arguments are supplied at use sites: + +```cap +let b = Box[i32]{ value: 42 } +let v = id[i32](b.value) +``` + +## Where generics are allowed + +- Struct declarations: `struct Pair[T, U] { ... }` +- Enum declarations: `enum Option[T] { None, Some(T) }` +- Function declarations: `fn map[T, U](...) -> ...` +- Impl blocks: `impl[T] Box[T] { ... }` +- Calls and literals: `foo[T](...)`, `Type[T]{...}` + +## Rules and constraints + +- Type arguments are mandatory when a type or function has type parameters. +- Type parameters cannot take type arguments of their own (`T[U]` is invalid). +- Type arguments are only allowed on calls or struct literals, not bare paths. +- `&T` references cannot be stored in structs or enums and cannot be returned. +- Reference types are only allowed as direct parameter types. + +## Type checking and specialization + +The type checker resolves all type parameters and validates arity at definition +and use sites. Lowering uses the typed HIR, and monomorphization runs before +codegen: + +1) Type check + record fully typed HIR. +2) Monomorphize: create specialized functions/structs/enums per instantiation. +3) Codegen only sees monomorphic types. + +If a generic type escapes into codegen without specialization, it is a compiler +error (generic parameters must be fully resolved before codegen). 
+ +## Implementation notes + +This is implemented as an explicit monomorphization pass, not a polymorphic +backend. + +- **AST/HIR surface**: type parameters are stored on functions, structs, enums, + and impl blocks. Call and struct-literal nodes carry explicit type arguments. +- **Type checker**: `Ty::Param` represents generic parameters; all type-parameter and + type-argument arity is validated in `typeck::collect` and `typeck::check`. + Type substitution is applied for field access, struct literals, and calls. +- **Typed lowering**: lowering reads the recorded type table and reuses the + resolved `Ty` for expression types so HIR is fully typed. +- **Monomorphizer** (`capc/src/typeck/monomorphize.rs`): + - Enqueues function instantiations encountered in calls. + - Specializes structs/enums referenced by concrete type arguments. + - Rewrites call targets and `Ty::Path` nodes to the specialized names. +- **Name mangling**: specialized definitions get a stable, type-based suffix + (e.g., `Box__i32`). +- **Codegen**: assumes monomorphic input; `Ty::Param` is rejected as a hard + error if it reaches codegen. + +## Example + +```cap +module generics_basic +use sys::system + +struct Box[T] { value: T } + +fn id[T](value: T) -> T { return value } + +pub fn main(rc: RootCap) -> i32 { + let b = Box[i32]{ value: 42 } + let v = id[i32](b.value) + rc.mint_console().print_i32(v) + return 0 +} +```