From 502f0b1302212d819e6df04bf77b10e2f993c3a8 Mon Sep 17 00:00:00 2001 From: LunaStev Date: Tue, 6 Jan 2026 20:11:46 +0900 Subject: [PATCH] feat: improve inline asm parsing and type safety Enhance inline assembly block parsing and type handling for greater robustness and expressiveness. Changes: - **ASM Parsing**: - `parse_asm_block` now supports generic `Expression` for inputs/outputs, allowing variable references, literals (decimal, hex, binary), pointers (`&x`), and dereferences (`deref x`). - Added `parse_asm_operand` to handle various operand types. - Implemented `parse_asm_inout_clause` for cleaner `in`/`out` parsing. - Enforced `is_assignable` check for `out` operands. - **LLVM Codegen**: - `gen_asm_stmt_ir` refactored to support multiple outputs via struct return values. - `asm_operand_to_value` converts AST expressions to LLVM values, handling literals (with correct radix) and address-of operations. - `store_asm_output` handles writing assembly results back to variables, including pointer casts. - Added `coerce_basic_value` for explicit type conversions (e.g., int to pointer for syscalls) and implicit widening. - **Type Safety**: - Implemented `coerce_to_expected` in function calls to automatically widen integers (e.g., `i32` -> `i64`) and cast pointers where safe. - `gen_variable_ir` now uses `coerce_basic_value` for initializers. - **Tests**: - Updated `test56.wave` (syscalls) to use new asm features and type-safe wrappers (`syscall4i`, `syscall4p`). - Fixed array size in `test66.wave`. - Added explicit type for overflow test in `test69.wave`. - Marked `test56.wave` as a known timeout in test runner (server test). - **Error Reporting**: - Integrated `colorex` crate for colored error messages in `front/error`. - Lexer now produces `IntLiteral` token type, storing the raw string representation for later parsing (preserving radix info). This update makes inline assembly more versatile and integrates it better with the type system. 
Signed-off-by: LunaStev --- front/error/Cargo.toml | 2 +- front/error/src/error.rs | 46 +- front/lexer/src/lexer/scan.rs | 8 +- front/lexer/src/token.rs | 2 +- front/parser/src/ast.rs | 8 +- front/parser/src/format.rs | 169 ++----- front/parser/src/parser/asm.rs | 261 +++++----- front/parser/src/parser/decl.rs | 99 ++-- front/parser/src/parser/functions.rs | 4 +- front/parser/src/parser/mod.rs | 18 +- front/parser/src/parser/types.rs | 8 +- .../llvm_temporary/expression/rvalue/asm.rs | 129 +++-- .../llvm_temporary/expression/rvalue/calls.rs | 80 +++- .../expression/rvalue/literals.rs | 89 +++- .../llvm_temporary/expression/rvalue/mod.rs | 26 +- .../src/llvm_temporary/llvm_codegen/consts.rs | 46 +- .../src/llvm_temporary/statement/asm.rs | 214 +++++++-- .../src/llvm_temporary/statement/mod.rs | 12 +- .../src/llvm_temporary/statement/variable.rs | 444 ++++-------------- test/test21.wave | 2 - test/test56.wave | 42 +- test/test66.wave | 2 +- test/test69.wave | 2 +- tools/run_tests.py | 1 + 24 files changed, 872 insertions(+), 842 deletions(-) diff --git a/front/error/Cargo.toml b/front/error/Cargo.toml index 3fb3ef5e..b29f53ee 100644 --- a/front/error/Cargo.toml +++ b/front/error/Cargo.toml @@ -4,4 +4,4 @@ version = "0.1.0" edition = "2021" [dependencies] -colored = "2.0" +colorex = "0.1.1" diff --git a/front/error/src/error.rs b/front/error/src/error.rs index 1c3d4397..51fc0c9d 100644 --- a/front/error/src/error.rs +++ b/front/error/src/error.rs @@ -178,13 +178,13 @@ impl WaveError { /// Display error in Rust-style format pub fn display(&self) { - use colored::*; + use colorex::Colorize; let severity_str = match self.severity { - ErrorSeverity::Error => "error".red().bold(), - ErrorSeverity::Warning => "warning".yellow().bold(), - ErrorSeverity::Note => "note".cyan().bold(), - ErrorSeverity::Help => "help".green().bold(), + ErrorSeverity::Error => "error".color("255,71,71").bold(), + ErrorSeverity::Warning => "warning".color("145,161,2").bold(), + ErrorSeverity::Note 
=> "note".color("0,255,255").bold(), + ErrorSeverity::Help => "help".color("38,139,235").bold(), }; // Main error message @@ -193,52 +193,52 @@ impl WaveError { // Location eprintln!( " {} {}:{}:{}", - "-->".blue().bold(), + "-->".color("38,139,235").bold(), self.file, self.line, self.column ); - eprintln!(" {}", "|".blue().bold()); + eprintln!(" {}", "|".color("38,139,235").bold()); // Source code with highlighting if let Some(source_line) = &self.source { eprintln!( "{:>3} {} {}", - self.line.to_string().blue().bold(), - "|".blue().bold(), + self.line.to_string().color("38,139,235").bold(), + "|".color("38,139,235").bold(), source_line ); // Arrow pointing to the error let spaces = " ".repeat(self.column.saturating_sub(1)); let arrow = match self.severity { - ErrorSeverity::Error => "^".red().bold(), - ErrorSeverity::Warning => "^".yellow().bold(), - ErrorSeverity::Note => "^".cyan().bold(), - ErrorSeverity::Help => "^".green().bold(), + ErrorSeverity::Error => "^".color("255,71,71").bold(), + ErrorSeverity::Warning => "^".color("145,161,2").bold(), + ErrorSeverity::Note => "^".color("0,255,255").bold(), + ErrorSeverity::Help => "^".color("38,139,235").bold(), }; if let Some(label) = &self.label { eprintln!( " {} {}{} {}", - "|".blue().bold(), + "|".color("38,139,235").bold(), spaces, arrow, - label.dimmed() + label.dim() ); } else { - eprintln!(" {} {}{}", "|".blue().bold(), spaces, arrow); + eprintln!(" {} {}{}", "|".color("38,139,235").bold(), spaces, arrow); } } - eprintln!(" {}", "|".blue().bold()); + eprintln!(" {}", "|".color("38,139,235").bold()); // Additional information if let Some(note) = &self.note { eprintln!( " {} {}: {}", - "=".blue().bold(), - "note".cyan().bold(), + "=".color("38,139,235").bold(), + "note".color("0,255,255").bold(), note ); } @@ -246,8 +246,8 @@ impl WaveError { if let Some(help) = &self.help { eprintln!( " {} {}: {}", - "=".blue().bold(), - "help".green().bold(), + "=".color("38,139,235").bold(), + 
"help".color("38,139,235").bold(), help ); } @@ -256,8 +256,8 @@ impl WaveError { for suggestion in &self.suggestions { eprintln!( " {} {}: {}", - "=".blue().bold(), - "suggestion".green().bold(), + "=".color("38,139,235").bold(), + "suggestion".color("38,139,235").bold(), suggestion ); } diff --git a/front/lexer/src/lexer/scan.rs b/front/lexer/src/lexer/scan.rs index 9e20b36f..e32c6115 100644 --- a/front/lexer/src/lexer/scan.rs +++ b/front/lexer/src/lexer/scan.rs @@ -345,10 +345,10 @@ impl<'a> Lexer<'a> { let value = i64::from_str_radix(&bin_str, 2).unwrap_or(0); return Token { - token_type: TokenType::Number(value), + token_type: TokenType::IntLiteral(format!("0b{}", bin_str)), lexeme: format!("0b{}", bin_str), line: self.line, - } + }; } if c == '0' && (self.peek() == 'x' || self.peek() == 'X') { @@ -362,7 +362,7 @@ impl<'a> Lexer<'a> { let value = i64::from_str_radix(&hex_str, 16).unwrap_or(0); return Token { - token_type: TokenType::Number(value), + token_type: TokenType::IntLiteral(format!("0x{}", hex_str)), lexeme: format!("0x{}", hex_str), line: self.line, }; @@ -387,7 +387,7 @@ impl<'a> Lexer<'a> { let token_type = if is_float { num_str.parse::().map(TokenType::Float).unwrap() } else { - num_str.parse::().map(TokenType::Number).unwrap() + TokenType::IntLiteral(num_str.clone()) }; Token { diff --git a/front/lexer/src/token.rs b/front/lexer/src/token.rs index e7d6e163..960563bc 100644 --- a/front/lexer/src/token.rs +++ b/front/lexer/src/token.rs @@ -124,7 +124,7 @@ pub enum TokenType { TypeArray(Box, u32), Identifier(String), String(String), - Number(i64), + IntLiteral(String), Float(f64), Plus, // + Increment, // ++ diff --git a/front/parser/src/ast.rs b/front/parser/src/ast.rs index 6ccf4867..47a9bcf9 100644 --- a/front/parser/src/ast.rs +++ b/front/parser/src/ast.rs @@ -142,7 +142,7 @@ pub enum Expression { #[derive(Debug, Clone)] pub enum Literal { - Number(i64), + Int(String), Float(f64), String(String), Bool(bool), @@ -226,8 +226,8 @@ pub enum 
StatementNode { }, AsmBlock { instructions: Vec, - inputs: Vec<(String, String)>, - outputs: Vec<(String, String)>, + inputs: Vec<(String, Expression)>, + outputs: Vec<(String, Expression)>, }, Break, Continue, @@ -281,7 +281,7 @@ impl Expression { .unwrap_or_else(|| panic!("Variable '{}' not found", name)) .ty .clone(), - Expression::Literal(Literal::Number(_)) => WaveType::Int(32), // 기본 int + Expression::Literal(Literal::Int(_)) => WaveType::Int(32), Expression::Literal(Literal::Float(_)) => WaveType::Float(32), Expression::Literal(Literal::String(_)) => WaveType::String, Expression::MethodCall { .. } => { diff --git a/front/parser/src/format.rs b/front/parser/src/format.rs index 6a1e96f5..04383e1b 100644 --- a/front/parser/src/format.rs +++ b/front/parser/src/format.rs @@ -4,6 +4,7 @@ use lexer::Token; use lexer::token::TokenType; use std::iter::Peekable; use std::slice::Iter; +use crate::asm::{parse_asm_inout_clause, parse_asm_operand}; pub fn parse_format_string(s: &str) -> Vec { let mut parts = Vec::new(); @@ -34,7 +35,7 @@ pub fn parse_format_string(s: &str) -> Vec { parts } -fn is_assignable(expr: &Expression) -> bool { +pub fn is_assignable(expr: &Expression) -> bool { match expr { Expression::Variable(_) => true, Expression::Deref(_) => true, @@ -60,7 +61,7 @@ fn desugar_incdec(line: usize, target: Expression, is_inc: bool) -> Option { + let tok = tokens.next()?; // '-' + let inner = parse_unary_expression(tokens)?; + + match inner { + Expression::Literal(Literal::Int(s)) => { + return Some(Expression::Literal(Literal::Int(format!("-{}", s)))); + } + Expression::Literal(Literal::Float(f)) => { + return Some(Expression::Literal(Literal::Float(-f))); + } + _ => { + println!("Error: unary '-' only supports numeric literals (line {})", tok.line); + return None; + } + } + } + + TokenType::Plus => { + tokens.next(); // consume '+' + let inner = parse_unary_expression(tokens)?; + return Some(inner); + } _ => {} } } @@ -390,9 +414,9 @@ where let token = 
(*tokens.peek()?).clone(); let mut expr = match &token.token_type { - TokenType::Number(value) => { + TokenType::IntLiteral(s) => { tokens.next(); - Some(Expression::Literal(Literal::Number(*value))) + Some(Expression::Literal(Literal::Int(s.clone()))) } TokenType::Float(value) => { tokens.next(); @@ -574,101 +598,25 @@ where break; } - TokenType::In | TokenType::Out => { - let is_input = matches!(token.token_type, TokenType::In); - tokens.next(); - - if tokens.peek().map(|t| t.token_type.clone()) != Some(TokenType::Lparen) { - println!("Expected '(' after in/out"); - return None; - } - tokens.next(); - - let reg_token = tokens.next(); - let reg = match reg_token { - Some(Token { - token_type: TokenType::String(s), - .. - }) => s.clone(), - Some(Token { - token_type: TokenType::Identifier(s), - .. - }) => s.clone(), - Some(other) => { - println!( - "Expected register string or identifier, got {:?}", - other.token_type - ); - return None; - } - None => { - println!("Expected register in in/out(...)"); - return None; - } - }; - - if tokens.peek().map(|t| t.token_type.clone()) != Some(TokenType::Rparen) { - println!("Expected ')' after in/out"); - return None; - } - tokens.next(); - - let value = parse_asm_value(tokens)?; - - let value_expr = parse_expression(tokens)?; - if is_input { - inputs.push((reg, value_expr)); - } else { - outputs.push((reg, value_expr)); - } + TokenType::In => { + tokens.next(); // consume 'in' + parse_asm_inout_clause(tokens, true, &mut inputs, &mut outputs)?; } - TokenType::Identifier(s) if s == "in" || s == "out" => { - let is_input = s == "in"; - tokens.next(); - - if tokens.peek().map(|t| t.token_type.clone()) != Some(TokenType::Lparen) { - println!("Expected '(' after in/out"); - return None; - } - tokens.next(); - - let reg_token = tokens.next(); - let reg = match reg_token { - Some(Token { - token_type: TokenType::String(s), - .. - }) => s.clone(), - Some(Token { - token_type: TokenType::Identifier(s), - .. 
- }) => s.clone(), - Some(other) => { - println!( - "Expected register string or identifier, got {:?}", - other.token_type - ); - return None; - } - None => { - println!("Expected register in in/out(...)"); - return None; - } - }; + TokenType::Out => { + tokens.next(); // consume 'out' + parse_asm_inout_clause(tokens, false, &mut inputs, &mut outputs)?; + } - if tokens.peek().map(|t| t.token_type.clone()) != Some(TokenType::Rparen) { - println!("Expected ')' after in/out(...)"); - return None; - } - tokens.next(); - let value = parse_asm_value(tokens)?; + TokenType::Identifier(s) if s == "in" => { + tokens.next(); // consume identifier 'in' + parse_asm_inout_clause(tokens, true, &mut inputs, &mut outputs)?; + } - if is_input { - inputs.push((reg, Variable(value))); - } else { - outputs.push((reg, Variable(value))); - } + TokenType::Identifier(s) if s == "out" => { + tokens.next(); // consume identifier 'out' + parse_asm_inout_clause(tokens, false, &mut inputs, &mut outputs)?; } TokenType::String(s) => { @@ -728,7 +676,6 @@ where return None; }; - // 기존 expr(Option)에서 실제 Expression 꺼내기 let base_expr = match expr.take() { Some(e) => e, None => { @@ -737,7 +684,6 @@ where } }; - // 다음 토큰이 '(' 이면 메서드 호출, 아니면 필드 접근 if let Some(Token { token_type: TokenType::Lparen, .. @@ -901,34 +847,3 @@ pub fn parse_expression_from_token( _ => None, } } - -fn parse_asm_value<'a, T>(tokens: &mut Peekable) -> Option -where - T: Iterator, -{ - let token = tokens.next()?; - match &token.token_type { - TokenType::Identifier(s) => Some(s.clone()), - TokenType::Number(n) => Some(n.to_string()), - TokenType::String(s) => Some(s.clone()), - TokenType::AddressOf => { - if let Some(Token { - token_type: TokenType::Identifier(s), - .. 
- }) = tokens.next() - { - Some(format!("&{}", s)) - } else { - println!("Expected identifier after '&' in in/out(...)"); - None - } - } - other => { - println!( - "Expected identifier or number after in/out(...), got {:?}", - other - ); - None - } - } -} \ No newline at end of file diff --git a/front/parser/src/parser/asm.rs b/front/parser/src/parser/asm.rs index 8e20c7e5..e86a4d5e 100644 --- a/front/parser/src/parser/asm.rs +++ b/front/parser/src/parser/asm.rs @@ -2,140 +2,181 @@ use std::iter::Peekable; use std::slice::Iter; use lexer::Token; use lexer::token::TokenType; -use crate::ast::{ASTNode, StatementNode}; +use crate::ast::{ASTNode, Expression, Literal, StatementNode}; +use crate::format::is_assignable; -pub fn parse_asm_block(tokens: &mut Peekable>) -> Option { +pub fn parse_asm_block(tokens: &mut Peekable>) -> Option { if tokens.peek()?.token_type != TokenType::Lbrace { println!("Expected '{{' after 'asm'"); return None; } - tokens.next(); + tokens.next(); // consume '{' let mut instructions = vec![]; - let mut inputs = vec![]; - let mut outputs = vec![]; + let mut inputs: Vec<(String, Expression)> = vec![]; + let mut outputs: Vec<(String, Expression)> = vec![]; - while let Some(token) = tokens.next() { - match &token.token_type { - TokenType::Rbrace => break, + let mut closed = false; - TokenType::In | TokenType::Out => { - let is_input = matches!(token.token_type, TokenType::In); - - if tokens.next().map(|t| t.token_type.clone()) != Some(TokenType::Lparen) { - println!("Expected '(' after in/out"); - return None; - } - - let reg_token = tokens.next(); - let reg = match reg_token { - Some(Token { - token_type: TokenType::String(s), - .. - }) => s.clone(), - Some(Token { - token_type: TokenType::Identifier(s), - .. 
- }) => s.clone(), - Some(other) => { - println!( - "Expected register string or identifier, got {:?}", - other.token_type - ); - return None; - } - None => { - println!("Expected register in in/out(...)"); - return None; - } - }; - - if tokens.next().map(|t| t.token_type.clone()) != Some(TokenType::Rparen) { - println!("Expected ')' after in/out"); - return None; - } + while let Some(tok) = tokens.peek() { + match &tok.token_type { + TokenType::Rbrace => { + tokens.next(); // consume '}' + closed = true; + break; + } - let value_token = tokens.next(); - let value = match value_token { - Some(Token { - token_type: TokenType::Minus, - .. - }) => match tokens.next() { - Some(Token { - token_type: TokenType::Number(n), - .. - }) => format!("-{}", n), - Some(other) => { - println!("Expected number after '-', got {:?}", other.token_type); - return None; - } - None => { - println!("Expected number after '-'"); - return None; - } - }, - Some(Token { - token_type: TokenType::AddressOf, - .. - }) => match tokens.next() { - Some(Token { - token_type: TokenType::Identifier(s), - .. - }) => format!("&{}", s), - Some(other) => { - println!("Expected identifier after '&', got {:?}", other.token_type); - return None; - } - None => { - println!("Expected identifier after '&'"); - return None; - } - }, - Some(Token { - token_type: TokenType::Identifier(s), - .. - }) => s.clone(), - Some(Token { - token_type: TokenType::Number(n), - .. - }) => n.to_string(), - Some(Token { - token_type: TokenType::String(n), - .. 
- }) => n.to_string(), - Some(other) => { - println!( - "Expected identifier or number after in/out(...), got {:?}", - other.token_type - ); - return None; - } - None => { - println!("Expected value after in/out(...)"); - return None; - } - }; - - if is_input { - inputs.push((reg.clone(), value)); - } else { - outputs.push((reg.clone(), value)); - } + TokenType::SemiColon | TokenType::Comma => { + tokens.next(); // 구분자 스킵 } TokenType::String(s) => { instructions.push(s.clone()); + tokens.next(); + } + + TokenType::In => { + tokens.next(); // consume 'in' + parse_asm_inout_clause(tokens, true, &mut inputs, &mut outputs)?; + } + + TokenType::Out => { + tokens.next(); // consume 'out' + parse_asm_inout_clause(tokens, false, &mut inputs, &mut outputs)?; + } + + // 렉서가 키워드로 안 뽑는 경우 대비 + TokenType::Identifier(s) if s == "in" => { + tokens.next(); + parse_asm_inout_clause(tokens, true, &mut inputs, &mut outputs)?; + } + TokenType::Identifier(s) if s == "out" => { + tokens.next(); + parse_asm_inout_clause(tokens, false, &mut inputs, &mut outputs)?; } other => { - println!("Unexpected token in asm expression {:?}", other); + println!("Unexpected token in asm block: {:?}", other); + tokens.next(); } } } + if !closed { + println!("Expected '}}' to close asm block"); + return None; + } + Some(ASTNode::Statement(StatementNode::AsmBlock { instructions, inputs, outputs, })) +} + +pub fn parse_asm_inout_clause<'a, T>( + tokens: &mut Peekable, + is_input: bool, + inputs: &mut Vec<(String, Expression)>, + outputs: &mut Vec<(String, Expression)>, +) -> Option<()> +where + T: Iterator, +{ + if tokens.peek().map(|t| &t.token_type) != Some(&TokenType::Lparen) { + println!("Expected '(' after in/out"); + return None; + } + tokens.next(); // '(' + + let reg = match tokens.next() { + Some(Token { token_type: TokenType::String(s), .. }) => s.clone(), + Some(Token { token_type: TokenType::Identifier(s), .. 
}) => s.clone(), + Some(other) => { + println!("Expected register string or identifier, got {:?}", other.token_type); + return None; + } + None => { + println!("Expected register in in/out(...)"); + return None; + } + }; + + if tokens.peek().map(|t| &t.token_type) != Some(&TokenType::Rparen) { + println!("Expected ')' after in/out(...)"); + return None; + } + tokens.next(); // ')' + + let value_expr = parse_asm_operand(tokens)?; + + if is_input { + inputs.push((reg, value_expr)); + } else { + if !is_assignable(&value_expr) { + println!("Error: out(...) target must be assignable"); + return None; + } + outputs.push((reg, value_expr)); + } + + Some(()) +} + +pub(crate) fn parse_asm_operand<'a, T>(tokens: &mut Peekable) -> Option +where + T: Iterator, +{ + let tok = tokens.next()?; + match &tok.token_type { + TokenType::Identifier(s) => Some(Expression::Variable(s.clone())), + TokenType::IntLiteral(n) => Some(Expression::Literal(Literal::Int(n.clone()))), + TokenType::String(s) => Some(Expression::Literal(Literal::String(s.clone()))), + + TokenType::AddressOf => { + // &x + let next = tokens.next()?; + match &next.token_type { + TokenType::Identifier(s) => Some(Expression::AddressOf(Box::new(Expression::Variable(s.clone())))), + _ => { + println!("Expected identifier after '&' in in/out(...)"); + None + } + } + } + + TokenType::Deref => { + let next = tokens.next()?; + match &next.token_type { + TokenType::Identifier(s) => Some(Expression::Deref(Box::new(Expression::Variable(s.clone())))), + _ => { + println!("Expected identifier after 'deref' in in/out(...)"); + None + } + } + } + + TokenType::Minus => { + match tokens.next()? { + Token { token_type: TokenType::IntLiteral(n), .. } => { + Some(Expression::Literal(Literal::Int(format!("-{}", n)))) + } + Token { token_type: TokenType::Float(f), .. 
} => { + Some(Expression::Literal(Literal::Float(-*f))) + } + other => { + println!( + "Expected int/float after '-' in asm operand, got {:?}", + other.token_type + ); + None + } + } + } + + other => { + println!("Expected asm operand, got {:?}", other); + None + } + } } \ No newline at end of file diff --git a/front/parser/src/parser/decl.rs b/front/parser/src/parser/decl.rs index dfead370..bdb85bc4 100644 --- a/front/parser/src/parser/decl.rs +++ b/front/parser/src/parser/decl.rs @@ -6,6 +6,61 @@ use crate::ast::{ASTNode, Expression, Mutability, VariableNode, WaveType}; use crate::format::parse_expression; use crate::parser::types::{parse_type, token_type_to_wave_type}; +fn collect_generic_inner(tokens: &mut Peekable>) -> Option { + let mut inner = String::new(); + let mut depth: i32 = 1; + + while let Some(t) = tokens.next() { + // ✅ 1) 토큰 타입이 chevr면 lexeme가 비어있어도 확실히 처리 + match &t.token_type { + TokenType::Lchevr => { + depth += 1; + inner.push('<'); + continue; + } + TokenType::Rchevr => { + depth -= 1; + if depth == 0 { + return Some(inner); + } + inner.push('>'); + continue; + } + _ => {} + } + + // ✅ 2) 그 외는 문자열(lexeme 또는 Identifier 이름)을 스캔해서 <, > 처리 + let text: &str = if !t.lexeme.is_empty() { + t.lexeme.as_str() + } else if let TokenType::Identifier(name) = &t.token_type { + name.as_str() + } else { + "" + }; + + for ch in text.chars() { + match ch { + '<' => { + depth += 1; + inner.push('<'); + } + '>' => { + depth -= 1; + if depth == 0 { + return Some(inner); + } + inner.push('>'); + } + _ => inner.push(ch), + } + } + } + + println!("Unclosed generic type: missing '>'"); + None +} + + pub fn parse_variable_decl(tokens: &mut Peekable>, is_const: bool) -> Option { let mut mutability = if is_const { Mutability::Const @@ -59,27 +114,7 @@ pub fn parse_variable_decl(tokens: &mut Peekable>, is_const: boo { tokens.next(); // consume '<' - let mut inner = String::new(); - let mut depth = 1; - - while let Some(t) = tokens.next() { - match &t.token_type { - 
TokenType::Lchevr => { - depth += 1; - inner.push('<'); - } - TokenType::Rchevr => { - depth -= 1; - if depth == 0 { - break; - } else { - inner.push('>'); - } - } - _ => inner.push_str(&t.lexeme), - } - } - + let inner = collect_generic_inner(tokens)?; let full_type_str = format!("{}<{}>", name, inner); let parsed_type = parse_type(&full_type_str); @@ -197,27 +232,7 @@ pub fn parse_var(tokens: &mut Peekable>) -> Option { { tokens.next(); // consume '<' - let mut inner = String::new(); - let mut depth = 1; - - while let Some(t) = tokens.next() { - match &t.token_type { - TokenType::Lchevr => { - depth += 1; - inner.push('<'); - } - TokenType::Rchevr => { - depth -= 1; - if depth == 0 { - break; - } else { - inner.push('>'); - } - } - _ => inner.push_str(&t.lexeme), - } - } - + let inner = collect_generic_inner(tokens)?; let full_type_str = format!("{}<{}>", name, inner); let parsed_type = parse_type(&full_type_str); diff --git a/front/parser/src/parser/functions.rs b/front/parser/src/parser/functions.rs index 94deb814..f22c04dd 100644 --- a/front/parser/src/parser/functions.rs +++ b/front/parser/src/parser/functions.rs @@ -53,9 +53,9 @@ pub fn parse_parameters(tokens: &mut Peekable>) -> Vec>) -> Option' - return Some(WaveType::Array(Box::new(inner_type), size)); + return Some(WaveType::Array(Box::new(inner_type), size.parse().unwrap())); } else { if tokens.peek()?.token_type == TokenType::Rchevr { tokens.next(); // consume '>' diff --git a/llvm_temporary/src/llvm_temporary/expression/rvalue/asm.rs b/llvm_temporary/src/llvm_temporary/expression/rvalue/asm.rs index 642a0cb5..590d05d5 100644 --- a/llvm_temporary/src/llvm_temporary/expression/rvalue/asm.rs +++ b/llvm_temporary/src/llvm_temporary/expression/rvalue/asm.rs @@ -1,7 +1,8 @@ use super::ExprGenEnv; -use inkwell::values::{BasicMetadataValueEnum, CallableValue}; +use inkwell::types::{BasicMetadataTypeEnum, StringRadix}; +use inkwell::values::{BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallableValue}; 
use inkwell::InlineAsmDialect; -use parser::ast::Expression; +use parser::ast::{Expression, Literal}; use std::collections::HashSet; pub(crate) fn gen<'ctx, 'a>( @@ -9,54 +10,69 @@ pub(crate) fn gen<'ctx, 'a>( instructions: &[String], inputs: &[(String, Expression)], outputs: &[(String, Expression)], -) -> inkwell::values::BasicValueEnum<'ctx> { +) -> BasicValueEnum<'ctx> { let asm_code: String = instructions.join("\n"); - let mut operand_vals: Vec = vec![]; - let mut constraint_parts: Vec = vec![]; + if outputs.len() > 1 { + panic!("asm expression supports at most 1 output for now (got {})", outputs.len()); + } - let input_regs: HashSet<_> = inputs.iter().map(|(reg, _)| reg.to_string()).collect(); + let mut constraint_parts: Vec = vec![]; let mut seen_regs: HashSet = HashSet::new(); - for (reg, var) in outputs { - if input_regs.contains(reg) { - panic!("Register '{}' used in both input and output in inline asm", reg); - } - - if !seen_regs.insert(reg.to_string()) { - panic!("Register '{}' duplicated in outputs", reg); - } - - if let Some(name) = var.as_identifier() { - let info = env - .variables - .get(name) - .unwrap_or_else(|| panic!("Output variable '{}' not found", name)); - let dummy_val = env.builder.build_load(info.ptr, name).unwrap().into(); - operand_vals.push(dummy_val); - constraint_parts.push(format!("={{{}}}", reg)); - } else { - panic!("Unsupported asm output: {:?}", var); + if let Some((out_reg, _out_expr)) = outputs.first() { + if !seen_regs.insert(out_reg.to_string()) { + panic!("Register '{}' duplicated in outputs", out_reg); } + // output constraint + constraint_parts.push(format!("={{{}}}", out_reg)); } + // inputs: operand + constraint + let mut operand_vals: Vec> = Vec::with_capacity(inputs.len()); + for (reg, var) in inputs { if !seen_regs.insert(reg.to_string()) { - panic!("Register '{}' duplicated in inputs", reg); + panic!("Register '{}' duplicated in asm operands", reg); } - let val: BasicMetadataValueEnum = - if let 
Expression::Literal(parser::ast::Literal::Number(n)) = var { - env.context.i64_type().const_int(*n as u64, true).into() - } else if let Some(name) = var.as_identifier() { - if let Some(info) = env.variables.get(name) { + let val: BasicMetadataValueEnum<'ctx> = match var { + Expression::Literal(Literal::Int(n)) => { + let s = n.as_str(); + + let (neg, digits) = if let Some(rest) = s.strip_prefix('-') { + (true, rest) + } else { + (false, s) + }; + + let ty = env.context.i64_type(); + + let mut iv = ty + .const_int_from_string(digits, StringRadix::Decimal) + .unwrap_or_else(|| panic!("invalid int literal: {}", s)); + + if neg { + iv = iv.const_neg(); + } + + iv.as_basic_value_enum().into() + } + Expression::Literal(Literal::Float(_)) => { + panic!("float literal in asm input not supported yet"); + } + _ => { + if let Some(name) = var.as_identifier() { + let info = env + .variables + .get(name) + .unwrap_or_else(|| panic!("Input variable '{}' not found", name)); env.builder.build_load(info.ptr, name).unwrap().into() } else { - panic!("Input variable '{}' not found", name); + panic!("Unsupported asm input expr: {:?}", var); } - } else { - panic!("Unsupported expression in variable context: {:?}", var); - }; + } + }; operand_vals.push(val); constraint_parts.push(format!("{{{}}}", reg)); @@ -64,36 +80,51 @@ pub(crate) fn gen<'ctx, 'a>( let constraints_str = constraint_parts.join(","); - for (reg, _) in outputs { - constraint_parts.push(format!("={}", reg)) - } - for (reg, _) in inputs { - constraint_parts.push(reg.to_string()); - } + let param_types: Vec> = + operand_vals.iter().map(meta_val_type).collect(); let fn_type = if outputs.is_empty() { - env.context.void_type().fn_type(&[], false) + env.context.void_type().fn_type(&param_types, false) } else { - env.context.i64_type().fn_type(&[], false) + env.context.i64_type().fn_type(&param_types, false) }; - let inline_asm_ptr = env.context.create_inline_asm( fn_type, asm_code, 
constraints_str, - true, - false, + true, // sideeffect + false, // alignstack Some(InlineAsmDialect::Intel), false, ); - let inline_asm_fn = - CallableValue::try_from(inline_asm_ptr).expect("Failed to convert inline asm to CallableValue"); + let callable = + CallableValue::try_from(inline_asm).expect("Failed to convert inline asm to CallableValue"); let call = env .builder - .build_call(inline_asm_fn, &operand_vals, "inline_asm_expr") + .build_call(callable, &operand_vals, "inline_asm_expr") .unwrap(); + if outputs.is_empty() { + panic!("asm expression must have an output"); + } + call.try_as_basic_value().left().unwrap() } + +fn meta_val_type<'ctx>(v: &BasicMetadataValueEnum<'ctx>) -> BasicMetadataTypeEnum<'ctx> { + match v { + BasicMetadataValueEnum::IntValue(iv) => iv.get_type().into(), + BasicMetadataValueEnum::FloatValue(fv) => fv.get_type().into(), + BasicMetadataValueEnum::PointerValue(pv) => pv.get_type().into(), + BasicMetadataValueEnum::StructValue(sv) => sv.get_type().into(), + BasicMetadataValueEnum::VectorValue(vv) => vv.get_type().into(), + BasicMetadataValueEnum::ArrayValue(av) => av.get_type().into(), + + BasicMetadataValueEnum::MetadataValue(_) => { + panic!("MetadataValue cannot be used as an inline asm operand"); + } + } +} \ No newline at end of file diff --git a/llvm_temporary/src/llvm_temporary/expression/rvalue/calls.rs b/llvm_temporary/src/llvm_temporary/expression/rvalue/calls.rs index d62a3c54..f570496c 100644 --- a/llvm_temporary/src/llvm_temporary/expression/rvalue/calls.rs +++ b/llvm_temporary/src/llvm_temporary/expression/rvalue/calls.rs @@ -1,6 +1,8 @@ +use inkwell::types::BasicTypeEnum; use super::ExprGenEnv; use inkwell::values::{BasicMetadataValueEnum, BasicValue, BasicValueEnum}; use parser::ast::{Expression, WaveType}; +use crate::llvm_temporary::statement::variable::{coerce_basic_value, CoercionMode}; pub(crate) fn gen_method_call<'ctx, 'a>( env: &mut ExprGenEnv<'ctx, 'a>, @@ -30,7 +32,13 @@ pub(crate) fn gen_method_call<'ctx, 
'a>( for (i, arg_expr) in args.iter().enumerate() { let expected_ty = param_types.get(i + 1).cloned(); - let arg_val = env.gen(arg_expr, expected_ty); + let mut arg_val = env.gen(arg_expr, expected_ty); + if let Some(et) = expected_ty { + arg_val = coerce_basic_value( + env.context, env.builder, arg_val, et, &format!("arg{}_cast", i), + CoercionMode::Implicit + ); + } call_args.push(arg_val.into()); } @@ -75,7 +83,13 @@ pub(crate) fn gen_method_call<'ctx, 'a>( for (i, arg_expr) in args.iter().enumerate() { let expected_ty = param_types.get(i + 1).cloned(); - let arg_val = env.gen(arg_expr, expected_ty); + let mut arg_val = env.gen(arg_expr, expected_ty); + if let Some(et) = expected_ty { + arg_val = coerce_basic_value( + env.context, env.builder, arg_val, et, &format!("arg{}_cast", i), + CoercionMode::Implicit + ); + } call_args.push(arg_val.into()); } @@ -119,17 +133,8 @@ pub(crate) fn gen_function_call<'ctx, 'a>( for (i, arg) in args.iter().enumerate() { let expected_param_ty = param_types[i]; - let val = env.gen(arg, Some(expected_param_ty)); - - if val.get_type() != expected_param_ty { - panic!( - "Type mismatch for arg {} of '{}': expected {:?}, got {:?}", - i, - name, - expected_param_ty, - val.get_type() - ); - } + let mut val = env.gen(arg, Some(expected_param_ty)); + val = coerce_to_expected(env, val, expected_param_ty, name, i); call_args.push(val.into()); } @@ -151,3 +156,52 @@ pub(crate) fn gen_function_call<'ctx, 'a>( } } } + +fn coerce_to_expected<'ctx, 'a>( + env: &ExprGenEnv<'ctx, 'a>, + val: BasicValueEnum<'ctx>, + expected: BasicTypeEnum<'ctx>, + name: &str, + arg_index: usize, +) -> BasicValueEnum<'ctx> { + let got = val.get_type(); + if got == expected { + return val; + } + + match (got, expected) { + (BasicTypeEnum::IntType(src), BasicTypeEnum::IntType(dst)) => { + let src_bw = src.get_bit_width(); + let dst_bw = dst.get_bit_width(); + let iv = val.into_int_value(); + + if src_bw < dst_bw { + env.builder + .build_int_s_extend(iv, dst, 
&format!("arg{}_sext", arg_index)) + .unwrap() + .as_basic_value_enum() + } else if src_bw > dst_bw { + env.builder + .build_int_truncate(iv, dst, &format!("arg{}_trunc", arg_index)) + .unwrap() + .as_basic_value_enum() + } else { + iv.as_basic_value_enum() + } + } + + (BasicTypeEnum::PointerType(_), BasicTypeEnum::PointerType(dst)) => { + env.builder + .build_bit_cast(val, dst, &format!("arg{}_ptrcast", arg_index)) + .unwrap() + .as_basic_value_enum() + } + + _ => { + panic!( + "Type mismatch for arg {} of '{}': expected {:?}, got {:?}", + arg_index, name, expected, got + ); + } + } +} diff --git a/llvm_temporary/src/llvm_temporary/expression/rvalue/literals.rs b/llvm_temporary/src/llvm_temporary/expression/rvalue/literals.rs index 99b5d22f..ee4e073a 100644 --- a/llvm_temporary/src/llvm_temporary/expression/rvalue/literals.rs +++ b/llvm_temporary/src/llvm_temporary/expression/rvalue/literals.rs @@ -1,25 +1,94 @@ use super::ExprGenEnv; -use inkwell::types::{BasicType, BasicTypeEnum}; +use inkwell::types::{BasicType, BasicTypeEnum, StringRadix}; use inkwell::values::{BasicValue, BasicValueEnum}; use parser::ast::Literal; +fn parse_signed_decimal<'a>(s: &'a str) -> (bool, &'a str) { + if let Some(rest) = s.strip_prefix('-') { + (true, rest) + } else { + (false, s) + } +} + +fn is_zero_decimal(s: &str) -> bool { + let s = s.trim(); + let s = s.strip_prefix('+').unwrap_or(s); + let s = s.strip_prefix('-').unwrap_or(s); + + !s.is_empty() && s.chars().all(|c| c == '0') +} + +fn parse_int_radix(s: &str) -> (StringRadix, &str) { + if let Some(rest) = s.strip_prefix("0b").or_else(|| s.strip_prefix("0B")) { + (StringRadix::Binary, rest) + } else if let Some(rest) = s.strip_prefix("0x").or_else(|| s.strip_prefix("0X")) { + (StringRadix::Hexadecimal, rest) + } else if let Some(rest) = s.strip_prefix("0o").or_else(|| s.strip_prefix("0O")) { + (StringRadix::Octal, rest) + } else { + (StringRadix::Decimal, s) + } +} + pub(crate) fn gen<'ctx, 'a>( env: &mut ExprGenEnv<'ctx, 'a>, 
lit: &Literal, expected_type: Option>, ) -> BasicValueEnum<'ctx> { match lit { - Literal::Number(v) => match expected_type { + Literal::Int(v) => match expected_type { Some(BasicTypeEnum::IntType(int_ty)) => { - int_ty.const_int(*v as u64, false).as_basic_value_enum() + let s = v.as_str(); + let (neg, raw) = parse_signed_decimal(s); + let (radix, digits) = parse_int_radix(raw); + + let mut iv = int_ty + .const_int_from_string(digits, radix) + .unwrap_or_else(|| panic!("invalid int literal: {}", s)); + + if neg { + iv = iv.const_neg(); + } + + iv.as_basic_value_enum() } - None => env - .context - .i64_type() - .const_int(*v as u64, false) - .as_basic_value_enum(), - _ => panic!("Expected integer type for numeric literal, got {:?}", expected_type), - }, + + Some(BasicTypeEnum::PointerType(ptr_ty)) => { + if is_zero_decimal(v.as_str()) { + ptr_ty.const_null().as_basic_value_enum() + } else { + panic!("Only 0 can be used as null pointer literal"); + } + } + + Some(BasicTypeEnum::FloatType(ft)) => { + let f = v + .parse::() + .unwrap_or_else(|_| panic!("invalid float literal from int token: {}", v)); + ft.const_float(f).as_basic_value_enum() + } + + None => { + let s = v.as_str(); + let (neg, raw) = parse_signed_decimal(s); + let (radix, digits) = parse_int_radix(raw); + + let mut iv = env + .context + .i64_type() + .const_int_from_string(digits, radix) + .unwrap_or_else(|| panic!("invalid int literal: {}", s)); + + if neg { + iv = iv.const_neg(); + } + + iv.as_basic_value_enum() + }, + + _ => panic!("Unsupported expected_type for int literal: {:?}", expected_type), + } Literal::Float(value) => match expected_type { Some(BasicTypeEnum::FloatType(float_ty)) => float_ty.const_float(*value).as_basic_value_enum(), diff --git a/llvm_temporary/src/llvm_temporary/expression/rvalue/mod.rs b/llvm_temporary/src/llvm_temporary/expression/rvalue/mod.rs index 73a46c33..84b24129 100644 --- a/llvm_temporary/src/llvm_temporary/expression/rvalue/mod.rs +++ 
b/llvm_temporary/src/llvm_temporary/expression/rvalue/mod.rs @@ -7,20 +7,20 @@ use inkwell::values::{BasicValueEnum}; use parser::ast::Expression; use std::collections::HashMap; -mod dispatch; -mod utils; +pub mod dispatch; +pub mod utils; -mod literals; -mod variables; -mod pointers; -mod calls; -mod assign; -mod binary; -mod index; -mod asm; -mod structs; -mod unary; -mod incdec; +pub mod literals; +pub mod variables; +pub mod pointers; +pub mod calls; +pub mod assign; +pub mod binary; +pub mod index; +pub mod asm; +pub mod structs; +pub mod unary; +pub mod incdec; pub struct ProtoInfo<'ctx> { pub vtable_ty: StructType<'ctx>, diff --git a/llvm_temporary/src/llvm_temporary/llvm_codegen/consts.rs b/llvm_temporary/src/llvm_temporary/llvm_codegen/consts.rs index 49108c76..e7254c91 100644 --- a/llvm_temporary/src/llvm_temporary/llvm_codegen/consts.rs +++ b/llvm_temporary/src/llvm_temporary/llvm_codegen/consts.rs @@ -1,5 +1,5 @@ use inkwell::context::Context; -use inkwell::types::BasicTypeEnum; +use inkwell::types::{BasicTypeEnum, StringRadix}; use inkwell::values::{BasicValue, BasicValueEnum}; use parser::ast::{Expression, Literal, WaveType}; @@ -7,6 +7,21 @@ use std::collections::HashMap; use super::types::wave_type_to_llvm_type; +fn parse_signed_decimal<'a>(s: &'a str) -> (bool, &'a str) { + if let Some(rest) = s.strip_prefix('-') { + (true, rest) + } else { + (false, s) + } +} + +fn is_zero_decimal(s: &str) -> bool { + let s = s.trim(); + let s = s.strip_prefix('+').unwrap_or(s); + let s = s.strip_prefix('-').unwrap_or(s); + !s.is_empty() && s.chars().all(|c| c == '0') +} + pub(super) fn create_llvm_const_value<'ctx>( context: &'ctx Context, ty: &WaveType, @@ -14,13 +29,36 @@ pub(super) fn create_llvm_const_value<'ctx>( ) -> BasicValueEnum<'ctx> { let struct_types = HashMap::new(); let llvm_type = wave_type_to_llvm_type(context, ty, &struct_types); + match (expr, llvm_type) { - (Expression::Literal(Literal::Number(n)), BasicTypeEnum::IntType(int_ty)) => { - 
int_ty.const_int(*n as u64, true).as_basic_value_enum() + // new: int literal is string-based + (Expression::Literal(Literal::Int(s)), BasicTypeEnum::IntType(int_ty)) => { + let (neg, digits) = parse_signed_decimal(s.as_str()); + + let mut iv = int_ty + .const_int_from_string(digits, StringRadix::Decimal) + .unwrap_or_else(|| panic!("invalid int literal: {}", s)); + + if neg { + iv = iv.const_neg(); + } + + iv.as_basic_value_enum() } + (Expression::Literal(Literal::Float(f)), BasicTypeEnum::FloatType(float_ty)) => { float_ty.const_float(*f).as_basic_value_enum() } + + // allow const null pointer only via 0 + (Expression::Literal(Literal::Int(s)), BasicTypeEnum::PointerType(ptr_ty)) => { + if is_zero_decimal(s) { + ptr_ty.const_null().as_basic_value_enum() + } else { + panic!("Only 0 can be used as a const null pointer literal"); + } + } + _ => panic!("Constant expression must be a literal of a compatible type."), } -} +} \ No newline at end of file diff --git a/llvm_temporary/src/llvm_temporary/statement/asm.rs b/llvm_temporary/src/llvm_temporary/statement/asm.rs index 3e53fe59..acbf2d78 100644 --- a/llvm_temporary/src/llvm_temporary/statement/asm.rs +++ b/llvm_temporary/src/llvm_temporary/statement/asm.rs @@ -2,16 +2,17 @@ use crate::llvm_temporary::llvm_codegen::VariableInfo; use inkwell::module::Module; use inkwell::values::{BasicMetadataValueEnum, BasicValueEnum, CallableValue}; use inkwell::{AddressSpace, InlineAsmDialect}; -use parser::ast::WaveType; +use parser::ast::{Expression, Literal, WaveType}; use std::collections::{HashMap, HashSet}; +use inkwell::types::{BasicType, StringRadix}; pub(super) fn gen_asm_stmt_ir<'ctx>( context: &'ctx inkwell::context::Context, builder: &'ctx inkwell::builder::Builder<'ctx>, module: &'ctx Module<'ctx>, instructions: &[String], - inputs: &[(String, String)], - outputs: &[(String, String)], + inputs: &[(String, Expression)], + outputs: &[(String, Expression)], variables: &mut HashMap>, global_consts: &HashMap>, ) { @@ 
-32,48 +33,51 @@ pub(super) fn gen_asm_stmt_ir<'ctx>( constraint_parts.push(format!("={{{}}}", reg)); } - for (reg, var) in inputs { + for (reg, expr) in inputs { if !seen_regs.insert(reg.to_string()) { if reg != "rax" { panic!("Register '{}' duplicated in inputs", reg); } } - let clean_var = if var.starts_with('&') { &var[1..] } else { var.as_str() }; - - let val: BasicMetadataValueEnum = if let Ok(value) = var.parse::() { - context.i64_type().const_int(value as u64, value < 0).into() - } else if let Some(const_val) = global_consts.get(var) { - (*const_val).into() - } else { - let info = variables - .get(clean_var) - .unwrap_or_else(|| panic!("Input variable '{}' not found", clean_var)); - - if var.starts_with('&') { - builder - .build_bit_cast( - info.ptr, - context.i8_type().ptr_type(AddressSpace::from(0)), - "addr_ptr", - ) - .unwrap() - .into() - } else { - builder.build_load(info.ptr, var).unwrap().into() - } - }; - + let val = asm_operand_to_value(context, builder, variables, global_consts, expr); operand_vals.push(val); constraint_parts.push(format!("{{{}}}", reg)); } let constraints_str = constraint_parts.join(","); - let (fn_type, expects_return) = if !outputs.is_empty() { - (context.i64_type().fn_type(&[], false), true) + let (fn_type, out_kinds): (inkwell::types::FunctionType<'ctx>, Vec) = if outputs.is_empty() { + (context.void_type().fn_type(&[], false), vec![]) } else { - (context.void_type().fn_type(&[], false), false) + let mut tys = Vec::new(); + let mut wave_tys = Vec::new(); + + for (_, target_expr) in outputs { + let var_name = asm_output_target(target_expr); + let info = variables + .get(var_name) + .unwrap_or_else(|| panic!("Output variable '{}' not found", var_name)); + + wave_tys.push(info.ty.clone()); + + let bt = match &info.ty { + WaveType::Int(64) => context.i64_type().as_basic_type_enum(), + WaveType::Pointer(inner) => match **inner { + WaveType::Int(8) => context.i8_type().ptr_type(AddressSpace::from(0)).as_basic_type_enum(), + _ => 
panic!("Unsupported pointer inner type in asm output"), + }, + _ => panic!("Unsupported asm output type: {:?}", info.ty), + }; + tys.push(bt); + } + + if tys.len() == 1 { + (tys[0].fn_type(&[], false), wave_tys) + } else { + let st = context.struct_type(&tys, false); + (st.fn_type(&[], false), wave_tys) + } }; let inline_asm_ptr = context.create_inline_asm( @@ -93,31 +97,137 @@ pub(super) fn gen_asm_stmt_ir<'ctx>( .build_call(inline_asm_fn, &operand_vals, "inline_asm") .unwrap(); - if expects_return { - let ret_val = call.try_as_basic_value().left().unwrap(); - let (_, var) = outputs.iter().next().unwrap(); - let info = variables - .get(var) - .unwrap_or_else(|| panic!("Output variable '{}' not found", var)); + if outputs.is_empty() { + return; + } + + let ret_val = call.try_as_basic_value().left().unwrap(); + + if outputs.len() == 1 { + let (_, target_expr) = &outputs[0]; + let var_name = asm_output_target(target_expr); + let info = variables.get(var_name).unwrap(); + + store_asm_output(context, builder, info, ret_val, var_name); + return; + } + + let struct_val = ret_val.into_struct_value(); + for (idx, (_, target_expr)) in outputs.iter().enumerate() { + let out_elem = builder + .build_extract_value(struct_val, idx as u32, "asm_out") + .unwrap(); + + let var_name = asm_output_target(target_expr); + let info = variables.get(var_name).unwrap(); + + store_asm_output(context, builder, info, out_elem, var_name); + } +} + +fn asm_output_target<'a>(expr: &'a Expression) -> &'a str { + match expr { + Expression::Variable(name) => name.as_str(), + _ => panic!("out(...) 
target must be a variable for now: {:?}", expr), + } +} + +fn store_asm_output<'ctx>( + context: &'ctx inkwell::context::Context, + builder: &'ctx inkwell::builder::Builder<'ctx>, + info: &VariableInfo<'ctx>, + value: BasicValueEnum<'ctx>, + var_name: &str, +) { + match &info.ty { + WaveType::Int(64) => { + builder.build_store(info.ptr, value).unwrap(); + } + WaveType::Pointer(inner) => match **inner { + WaveType::Int(8) => { + if value.is_pointer_value() { + builder.build_store(info.ptr, value.into_pointer_value()).unwrap(); + return; + } + + let casted_ptr = builder + .build_int_to_ptr( + value.into_int_value(), + context.i8_type().ptr_type(AddressSpace::from(0)), + "casted_ptr", + ) + .unwrap(); + builder.build_store(info.ptr, casted_ptr).unwrap(); + } + _ => panic!("Unsupported pointer inner type in inline asm output"), + }, + _ => panic!("Unsupported return type from inline asm output var '{}': {:?}", var_name, info.ty), + } +} + +fn asm_operand_to_value<'ctx>( + context: &'ctx inkwell::context::Context, + builder: &'ctx inkwell::builder::Builder<'ctx>, + variables: &HashMap>, + global_consts: &HashMap>, + expr: &Expression, +) -> BasicMetadataValueEnum<'ctx> { + match expr { + Expression::Literal(Literal::Int(n)) => { + let s = n.as_str(); + let (neg, digits) = if let Some(rest) = s.strip_prefix('-') { + (true, rest) + } else { + (false, s) + }; + + let mut iv = context + .i64_type() + .const_int_from_string(digits, StringRadix::Decimal) + .unwrap_or_else(|| panic!("invalid int literal: {}", s)); - match &info.ty { - WaveType::Int(64) => { - builder.build_store(info.ptr, ret_val).unwrap(); + if neg { + iv = iv.const_neg(); } - WaveType::Pointer(inner) => match **inner { - WaveType::Int(8) => { - let casted_ptr = builder - .build_int_to_ptr( - ret_val.into_int_value(), + + iv.into() + } + + + Expression::Variable(name) => { + if let Some(const_val) = global_consts.get(name) { + (*const_val).into() + } else { + let info = variables + .get(name) + 
.unwrap_or_else(|| panic!("Input variable '{}' not found", name)); + builder.build_load(info.ptr, name).unwrap().into() + } + } + + Expression::AddressOf(inner) => { + match inner.as_ref() { + Expression::Variable(name) => { + let info = variables + .get(name) + .unwrap_or_else(|| panic!("Input variable '{}' not found", name)); + builder + .build_bit_cast( + info.ptr, context.i8_type().ptr_type(AddressSpace::from(0)), - "casted_ptr", + "addr_ptr", ) - .unwrap(); - builder.build_store(info.ptr, casted_ptr).unwrap(); + .unwrap() + .into() } - _ => panic!("Unsupported pointer inner type in inline asm output"), - }, - _ => panic!("Unsupported return type from inline asm: {:?}", info.ty), + _ => panic!("Unsupported asm address-of operand: {:?}", inner), + } } + + Expression::Grouped(inner) => { + asm_operand_to_value(context, builder, variables, global_consts, inner) + } + + _ => panic!("Unsupported asm operand expression: {:?}", expr), } } diff --git a/llvm_temporary/src/llvm_temporary/statement/mod.rs b/llvm_temporary/src/llvm_temporary/statement/mod.rs index 4909929c..baabbe8b 100644 --- a/llvm_temporary/src/llvm_temporary/statement/mod.rs +++ b/llvm_temporary/src/llvm_temporary/statement/mod.rs @@ -1,9 +1,9 @@ -mod assign; -mod asm; -mod control; -mod expr_stmt; -mod io; -mod variable; +pub mod assign; +pub mod asm; +pub mod control; +pub mod expr_stmt; +pub mod io; +pub mod variable; use crate::llvm_temporary::llvm_codegen::VariableInfo; use inkwell::basic_block::BasicBlock; diff --git a/llvm_temporary/src/llvm_temporary/statement/variable.rs b/llvm_temporary/src/llvm_temporary/statement/variable.rs index e40f3626..ba36b76f 100644 --- a/llvm_temporary/src/llvm_temporary/statement/variable.rs +++ b/llvm_temporary/src/llvm_temporary/statement/variable.rs @@ -7,6 +7,93 @@ use inkwell::{AddressSpace}; use parser::ast::{Expression, Literal, VariableNode, WaveType}; use std::collections::HashMap; +#[derive(Copy, Clone, Debug)] +pub enum CoercionMode { + Implicit, + 
Explicit, + Asm, +} + + +pub fn coerce_basic_value<'ctx>( + context: &'ctx inkwell::context::Context, + builder: &'ctx inkwell::builder::Builder<'ctx>, + val: BasicValueEnum<'ctx>, + expected: BasicTypeEnum<'ctx>, + tag: &str, + mode: CoercionMode, +) -> BasicValueEnum<'ctx> { + if val.get_type() == expected { + return val; + } + + match (val, expected) { + // int <-> int + (BasicValueEnum::IntValue(iv), BasicTypeEnum::IntType(dst)) => { + let src_bw = iv.get_type().get_bit_width(); + let dst_bw = dst.get_bit_width(); + + if src_bw == dst_bw { + iv.as_basic_value_enum() + } else if src_bw > dst_bw { + builder.build_int_truncate(iv, dst, tag).unwrap().as_basic_value_enum() + } else { + builder.build_int_s_extend(iv, dst, tag).unwrap().as_basic_value_enum() + } + } + + // float -> int + (BasicValueEnum::FloatValue(fv), BasicTypeEnum::IntType(dst)) => builder + .build_float_to_signed_int(fv, dst, tag) + .unwrap() + .as_basic_value_enum(), + + // int -> float + (BasicValueEnum::IntValue(iv), BasicTypeEnum::FloatType(dst)) => builder + .build_signed_int_to_float(iv, dst, tag) + .unwrap() + .as_basic_value_enum(), + + // ptr -> ptr + (BasicValueEnum::PointerValue(pv), BasicTypeEnum::PointerType(dst)) => builder + .build_bit_cast(pv, dst, tag) + .unwrap() + .as_basic_value_enum(), + + (BasicValueEnum::IntValue(iv), BasicTypeEnum::PointerType(dst)) => { + match mode { + CoercionMode::Implicit => { + if iv.is_const() && iv.get_zero_extended_constant() == Some(0) { + dst.const_null().as_basic_value_enum() + } else { + panic!("Implicit int->ptr is not allowed (use explicit cast)."); + } + } + CoercionMode::Asm | CoercionMode::Explicit => builder + .build_int_to_ptr(iv, dst, tag) + .unwrap() + .as_basic_value_enum(), + } + } + + (BasicValueEnum::PointerValue(pv), BasicTypeEnum::IntType(dst)) => { + match mode { + CoercionMode::Implicit => { + panic!("Implicit ptr->int is not allowed (use explicit cast)."); + } + CoercionMode::Asm | CoercionMode::Explicit => builder + 
.build_ptr_to_int(pv, dst, tag) + .unwrap() + .as_basic_value_enum(), + } + } + + _ => { + panic!("Type mismatch: expected {:?}, got {:?}", expected, val.get_type()); + } + } +} + pub(super) fn gen_variable_ir<'ctx>( context: &'ctx inkwell::context::Context, builder: &'ctx inkwell::builder::Builder<'ctx>, @@ -90,356 +177,17 @@ pub(super) fn gen_variable_ir<'ctx>( ); if let Some(init) = initial_value { - match (init, llvm_type) { - ( - Expression::Literal(Literal::Number(value)), - BasicTypeEnum::IntType(int_type), - ) => { - let init_value = int_type.const_int(*value as u64, false); - let _ = builder.build_store(alloca, init_value); - } - - ( - Expression::Literal(Literal::Float(value)), - BasicTypeEnum::FloatType(float_type), - ) => { - let init_value = float_type.const_float(*value); - builder.build_store(alloca, init_value).unwrap(); - } - - ( - Expression::Literal(Literal::Bool(v)), - BasicTypeEnum::IntType(int_ty), - ) => { - let val = int_ty.const_int(if *v { 1 } else { 0 }, false); - builder.build_store(alloca, val).unwrap(); - } - - ( - Expression::Literal(Literal::Char(c)), - BasicTypeEnum::IntType(int_ty), - ) => { - let val = int_ty.const_int(*c as u64, false); - builder.build_store(alloca, val).unwrap(); - } - - ( - Expression::Literal(Literal::Byte(b)), - BasicTypeEnum::IntType(int_ty), - ) => { - let val = int_ty.const_int(*b as u64, false); - builder.build_store(alloca, val).unwrap(); - } - - (Expression::Literal(Literal::Float(value)), _) => { - let float_value = context.f32_type().const_float(*value); - - let casted_value = match llvm_type { - BasicTypeEnum::IntType(int_ty) => builder - .build_float_to_signed_int(float_value, int_ty, "float_to_int") - .unwrap() - .as_basic_value_enum(), - BasicTypeEnum::FloatType(_) => float_value.as_basic_value_enum(), - _ => panic!("Unsupported type for float literal initialization"), - }; - - builder.build_store(alloca, casted_value).unwrap(); - } - - ( - Expression::Literal(Literal::String(value)), - 
BasicTypeEnum::PointerType(_), - ) => unsafe { - let string_name = format!("str_init_{}", name); - let mut bytes = value.as_bytes().to_vec(); - bytes.push(0); - - let const_str = context.const_string(&bytes, false); - let global = module.add_global( - context.i8_type().array_type(bytes.len() as u32), - None, - &string_name, - ); - global.set_initializer(&const_str); - global.set_linkage(Linkage::Private); - global.set_constant(true); - - let zero = context.i32_type().const_zero(); - let indices = [zero, zero]; - let gep = builder - .build_gep(global.as_pointer_value(), &indices, "str_gep") - .unwrap(); - - let _ = builder.build_store(alloca, gep); - }, - - (Expression::AddressOf(inner_expr), BasicTypeEnum::PointerType(_)) => { - match &**inner_expr { - Expression::Variable(var_name) => { - let ptr = variables - .get(var_name) - .unwrap_or_else(|| panic!("Variable {} not found", var_name)); - builder.build_store(alloca, ptr.ptr).unwrap(); - } - - Expression::ArrayLiteral(elements) => { - let elem_type = match llvm_type { - BasicTypeEnum::PointerType(ptr_ty) => match ptr_ty.get_element_type() { - AnyTypeEnum::ArrayType(arr_ty) => arr_ty.get_element_type(), - _ => panic!("Expected pointer to array type"), - }, - _ => panic!("Expected pointer-to-array type for array literal"), - }; - - let array_type = elem_type.array_type(elements.len() as u32); - let tmp_alloca = - builder.build_alloca(array_type, "tmp_array").unwrap(); - - for (i, expr) in elements.iter().enumerate() { - let val = generate_expression_ir( - context, - builder, - expr, - variables, - module, - Some(elem_type), - global_consts, - struct_types, - struct_field_indices, - ); - - let gep = builder - .build_in_bounds_gep( - tmp_alloca, - &[ - context.i32_type().const_zero(), - context.i32_type().const_int(i as u64, false), - ], - &format!("array_idx_{}", i), - ) - .unwrap(); - - builder.build_store(gep, val).unwrap(); - } - - builder.build_store(alloca, tmp_alloca).unwrap(); - } - - _ => panic!("& operator 
must be used on variable name or array literal"), - } - } - - (Expression::Deref(inner_expr), BasicTypeEnum::IntType(_)) => { - let target_ptr = match &**inner_expr { - Expression::Variable(var_name) => { - let ptr_to_value = variables.get(var_name).unwrap().ptr; - builder - .build_load(ptr_to_value, "load_ptr") - .unwrap() - .into_pointer_value() - } - _ => panic!("Invalid deref in variable init"), - }; - - let val = builder.build_load(target_ptr, "deref_value").unwrap(); - let _ = builder.build_store(alloca, val); - } - - (Expression::IndexAccess { .. }, _) => { - let val = generate_expression_ir( - context, - builder, - init, - variables, - module, - Some(llvm_type), - global_consts, - struct_types, - struct_field_indices, - ); - builder.build_store(alloca, val).unwrap(); - } - - (Expression::FunctionCall { .. } | Expression::MethodCall { .. }, _) => { - let val = generate_expression_ir( - context, - builder, - init, - variables, - module, - Some(llvm_type), - global_consts, - struct_types, - struct_field_indices, - ); - - if val.get_type() != llvm_type { - panic!( - "Initializer type mismatch: expected {:?}, got {:?}", - llvm_type, - val.get_type() - ); - } - - builder.build_store(alloca, val).unwrap(); - } - - (Expression::BinaryExpression { .. 
}, _) => { - let val = generate_expression_ir( - context, - builder, - init, - variables, - module, - Some(llvm_type), - global_consts, - struct_types, - struct_field_indices, - ); - - let casted_val = match (val, llvm_type) { - (BasicValueEnum::FloatValue(v), BasicTypeEnum::IntType(t)) => builder - .build_float_to_signed_int(v, t, "float_to_int") - .unwrap() - .as_basic_value_enum(), - (BasicValueEnum::IntValue(v), BasicTypeEnum::FloatType(t)) => builder - .build_signed_int_to_float(v, t, "int_to_float") - .unwrap() - .as_basic_value_enum(), - _ => val, - }; - - builder.build_store(alloca, casted_val).unwrap(); - } - - (Expression::Variable(var_name), _) => { - let source_var = variables - .get(var_name) - .unwrap_or_else(|| panic!("Variable {} not found", var_name)); - - let loaded_value = builder - .build_load(source_var.ptr, &format!("load_{}", var_name)) - .unwrap(); - - let loaded_type = loaded_value.get_type(); - - let casted_value = match (loaded_type, llvm_type) { - (BasicTypeEnum::IntType(_), BasicTypeEnum::FloatType(float_ty)) => builder - .build_signed_int_to_float( - loaded_value.into_int_value(), - float_ty, - "int_to_float", - ) - .unwrap() - .as_basic_value_enum(), - (BasicTypeEnum::FloatType(_), BasicTypeEnum::IntType(int_ty)) => builder - .build_float_to_signed_int( - loaded_value.into_float_value(), - int_ty, - "float_to_int", - ) - .unwrap() - .as_basic_value_enum(), - _ => loaded_value, - }; - - builder.build_store(alloca, casted_value).unwrap(); - } - - ( - Expression::AsmBlock { - instructions, - inputs, - outputs, - }, - BasicTypeEnum::IntType(_), - ) => { - use inkwell::values::{BasicMetadataValueEnum, CallableValue}; - use inkwell::InlineAsmDialect; - - let asm_code: String = instructions.join("\n"); - let mut operand_vals: Vec = vec![]; - let mut constraint_parts = vec![]; - - for (reg, var) in inputs { - let val = if let Expression::Literal(Literal::Number(n)) = var { - context.i64_type().const_int(*n as u64, true).into() - } else if let 
Some(name) = var.as_identifier() { - if let Some(info) = variables.get(name) { - builder.build_load(info.ptr, name).unwrap().into() - } else { - panic!("Variable '{}' not found", name); - } - } else { - panic!("Unsupported expression in statement: {:?}", var); - }; - - operand_vals.push(val); - constraint_parts.push(format!("{{{}}}", reg)); - } - - for (reg, _) in outputs { - constraint_parts.insert(0, format!("={{{}}}", reg)); - } - - let constraint_str = constraint_parts.join(","); - - let (fn_type, expects_return) = if outputs.is_empty() { - (context.void_type().fn_type(&[], false), false) - } else { - (context.i64_type().fn_type(&[], false), true) - }; - - let inline_asm_ptr = context.create_inline_asm( - fn_type, - asm_code, - constraint_str, - true, - false, - Some(InlineAsmDialect::Intel), - false, - ); - - let inline_asm_fn = CallableValue::try_from(inline_asm_ptr) - .expect("Failed to cast inline asm to CallableValue"); - - let call = builder - .build_call(inline_asm_fn, &operand_vals, "inline_asm") - .unwrap(); - - if expects_return { - let result = call - .try_as_basic_value() - .left() - .expect("Expected return value from inline asm but got none"); - - builder.build_store(alloca, result).unwrap(); - } - } - - (init_expr @ Expression::StructLiteral { .. 
}, _) => { - let val = generate_expression_ir( - context, - builder, - init_expr, - variables, - module, - Some(llvm_type), - global_consts, - struct_types, - struct_field_indices, - ); + let raw = generate_expression_ir( + context, builder, init, variables, module, + Some(llvm_type), + global_consts, struct_types, struct_field_indices, + ); - builder.build_store(alloca, val).unwrap(); - } + let casted = coerce_basic_value( + context, builder, raw, llvm_type, "init_cast", CoercionMode::Implicit + ); - _ => { - panic!( - "Unsupported type/value combination for initialization: {:?}", - init - ); - } - } + builder.build_store(alloca, casted).unwrap(); } } } diff --git a/test/test21.wave b/test/test21.wave index a1dc022e..20342c5e 100644 --- a/test/test21.wave +++ b/test/test21.wave @@ -1,5 +1,3 @@ -import("std::iosys"); - fun main() { println("{}", ((0b110111010) & 0x1FF)); } \ No newline at end of file diff --git a/test/test56.wave b/test/test56.wave index 3fc10d5b..0f3effa7 100644 --- a/test/test56.wave +++ b/test/test56.wave @@ -31,24 +31,38 @@ fun syscall3(id: i64, arg1: i64, arg2: i64) -> i64 { return ret_val; } -fun syscall4(id: i64, arg1: i64, arg2: i64, arg3: i64) -> i64 { - var ret_val: i64; +fun syscall4i(id: i64, a1: i64, a2: i64, a3: i64) -> i64 { + var ret: i64; asm { "syscall" in("rax") id - in("rdi") arg1 - in("rsi") arg2 - in("rdx") arg3 - out("rax") ret_val + in("rdi") a1 + in("rsi") a2 + in("rdx") a3 + out("rax") ret } - return ret_val; + return ret; +} + +fun syscall4p(id: i64, a1: i64, a2: ptr, a3: i64) -> i64 { + var ret: i64; + asm { + "syscall" + in("rax") id + in("rdi") a1 + in("rsi") a2 + in("rdx") a3 + out("rax") ret + } + return ret; } fun _socket_create() -> i32 { - return syscall4(41, 2, 1, 0); + return syscall3(41, 2, 1); } fun _socket_bind(sockfd: i32, ip_addr: i32, port: i16) -> i32 { + var result: i64; asm { out("rax") result in("rdi") sockfd @@ -64,17 +78,13 @@ fun _socket_close(sockfd: i32) { syscall2(3, sockfd); } -fun 
new_server(ip_str: str, port: i16) -> ptr { +fun new_server(ip_str: str, port: i16) -> i32 { var sockfd: i32 = _socket_create(); - if (sockfd < 0) { - println("Error: Failed to create socket."); - return 0; - } + if (sockfd < 0) { return -1; } if (_socket_bind(sockfd, 0, port) < 0) { - println("Error: Failed to bind socket."); _socket_close(sockfd); - return 0; + return -1; } return sockfd; @@ -104,7 +114,7 @@ fun start(server_fd: i32) { var response: str = "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nWelcome to the Wave HTTP Server!"; - syscall4(1, client_fd, response, 82); + syscall4p(1, client_fd, response, 82); _socket_close(client_fd); println("Client disconnected."); diff --git a/test/test66.wave b/test/test66.wave index f7f5af03..28b12660 100644 --- a/test/test66.wave +++ b/test/test66.wave @@ -50,7 +50,7 @@ fun test_nested_struct() { struct Inventory { - items: array; + items: array; } proto Inventory { diff --git a/test/test69.wave b/test/test69.wave index fb5d89e8..ff58f9d5 100644 --- a/test/test69.wave +++ b/test/test69.wave @@ -9,6 +9,6 @@ fun main() { println("before overflow: {}", max); - var overflow = max + 1; + var overflow: i32 = max + 1; println("after overflow: {}", overflow); } diff --git a/tools/run_tests.py b/tools/run_tests.py index 2da710f6..9de6bb60 100644 --- a/tools/run_tests.py +++ b/tools/run_tests.py @@ -21,6 +21,7 @@ KNOWN_TIMEOUT = { # "test22.wave", + "test56.wave", } FAIL_PATTERNS = [