diff --git a/Cargo.toml b/Cargo.toml index 59f89ad..fa9e64b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "kronec" -version = "0.1.0" +version = "0.0.1" edition = "2021" [dependencies] diff --git a/src/ast/expr.rs b/src/ast/expr.rs index e031f83..1ea672e 100644 --- a/src/ast/expr.rs +++ b/src/ast/expr.rs @@ -4,10 +4,15 @@ use crate::ast::*; pub enum Expr { BinaryExpression(Box, BinaryOperator, Box), Identifier(Identifier), + Call(Box), + // Literals + BooleanLiteral(bool), IntegerLiteral(i64), FloatLiteral(f64), StringLiteral(String), - Call(Box), + Block(Box), + /// Last field is either Expr::Block or Expr::IfExpr + IfExpr(Box, Box, Box), } #[derive(Debug, PartialEq, Clone)] @@ -16,4 +21,7 @@ pub enum BinaryOperator { Sub, Mul, Div, + Modulo, + Equal, + NotEqual, } diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 3e63c19..7592f59 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,17 +1,29 @@ pub mod expr; -pub mod typ; +pub mod module; + +use std::path::Path; pub use crate::ast::expr::{BinaryOperator, Expr}; -pub use crate::ast::typ::*; +use crate::ast::module::*; +use crate::typing::Type; + +pub type Identifier = String; // XXX: Is this enum actually useful? Is 3:30 AM btw #[derive(Debug, PartialEq)] pub enum Ast { + Module(Module), +} + +#[derive(Debug, PartialEq)] +pub enum Definition { FunctionDefinition(FunctionDefinition), - Expr(Expr), - Module(Vec), - Block(Block), - Statement(Statement), + //StructDefinition(StructDefinition), +} + +#[derive(Debug, PartialEq)] +pub struct Location { + pub file: Box, } #[derive(Debug, PartialEq)] @@ -20,6 +32,7 @@ pub struct FunctionDefinition { pub parameters: Vec, pub return_type: Option, pub body: Box, + pub line_col: (usize, usize), } #[derive(Debug, PartialEq)] @@ -30,9 +43,13 @@ pub struct Block { #[derive(Debug, PartialEq)] pub enum Statement { + DeclareStatement(Identifier, Expr), AssignStatement(Identifier, Expr), ReturnStatement(Option), CallStatement(Call), + UseStatement(ModulePath), + IfStatement(Expr, Block), + WhileStatement(Box, Box), } #[derive(Debug, PartialEq)] @@ -41,31 +58,9 @@ pub struct Call { pub args: Vec, } -pub type Identifier = String; - #[derive(Debug, PartialEq)] pub struct Parameter { pub name: Identifier, pub typ: Type, } -impl Ast { - /// Type checks the AST and add missing return types. - pub fn check_return_types(&mut self) -> Result<(), TypeError> { - match self { - Ast::Module(defs) => { - for def in defs { - if let Ast::FunctionDefinition { .. } = def { - def.check_return_types()?; - } - } - } - Ast::FunctionDefinition(func) => { - let typ = func.typ(&mut TypeContext::default())?; - func.return_type = Some(typ.clone()); - } - _ => unreachable!(), - } - Ok(()) - } -} diff --git a/src/ast/module.rs b/src/ast/module.rs new file mode 100644 index 0000000..6159c21 --- /dev/null +++ b/src/ast/module.rs @@ -0,0 +1,66 @@ +use std::path::Path; +use super::Definition; + +#[derive(Debug, PartialEq, Clone)] +pub struct ModulePath { + components: Vec, +} + +impl std::fmt::Display for ModulePath { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{}", self.components.join("::"))) + } +} + +impl From<&Path> for ModulePath { + fn from(path: &Path) -> Self { + let meta = std::fs::metadata(path).unwrap(); + ModulePath { + components: path + .components() + .map(|component| match component { + std::path::Component::Normal(n) => { + if meta.is_file() { + n.to_str().unwrap().split(".").nth(0).unwrap().to_string() + } else if meta.is_dir() { + n.to_str().unwrap().to_string() + } else { + // XXX: symlinks? + unreachable!() + } + } + _ => unreachable!(), + }) + .collect(), + } + } +} + +impl From<&str> for ModulePath { + fn from(string: &str) -> Self { + ModulePath { + components: string.split("::").map(|c| c.to_string()).collect(), + } + } +} + +type ImportPath = ModulePath; + +#[derive(Debug, PartialEq)] +pub struct Module { + pub file: Option, + pub path: ModulePath, + pub definitions: Vec, + pub imports: Vec, +} + +impl Module { + pub fn new(path: ModulePath) -> Self { + Module { + path, + file: None, + definitions: vec![], + imports: vec![], + } + } +} diff --git a/src/ast/typ.rs b/src/ast/typ.rs deleted file mode 100644 index 5f102a1..0000000 --- a/src/ast/typ.rs +++ /dev/null @@ -1,158 +0,0 @@ -use std::collections::HashMap; - -use crate::ast::*; - -#[derive(Debug, PartialEq, Clone)] -pub enum Type { - Int, - Float, - Unit, - Str, - Custom(Identifier), -} - -impl From<&str> for Type { - fn from(value: &str) -> Self { - match value { - "int" => Type::Int, - "float" => Type::Float, - _ => Type::Custom(Identifier::from(value)), - } - } -} - -#[derive(Debug)] -pub enum TypeError { - InvalidBinaryOperator { - operator: BinaryOperator, - lht: Type, - rht: Type, - }, - BlockTypeDoesNotMatchFunctionType { - function_name: String, - function_type: Type, - block_type: Type, - }, - ReturnTypeDoesNotMatchFunctionType { - function_name: String, - function_type: Type, - ret_type: Type, - }, - UnknownIdentifier { - identifier: String, - }, -} - -#[derive(Default)] -pub struct TypeContext { - pub function: Option, - pub variables: HashMap, -} - -/// Trait for nodes which have a deducible type. -pub trait Typ { - /// Try to resolve the type of the node. - fn typ(&self, ctx: &mut TypeContext) -> Result; -} - -impl Typ for FunctionDefinition { - fn typ(&self, ctx: &mut TypeContext) -> Result { - let func = self; - - let mut ctx = TypeContext { - function: Some(func.name.clone()), - ..Default::default() - }; - for param in &func.parameters { - ctx.variables.insert(param.name.clone(), param.typ.clone()); - } - - let body_type = &func.body.typ(&mut ctx)?; - - // If the return type is not specified, it is unit. - let func_return_type = match &func.return_type { - Some(typ) => typ, - None => &Type::Unit, - }; - - // Check coherence with the body's type. - if *func_return_type != *body_type { - return Err(TypeError::BlockTypeDoesNotMatchFunctionType { - function_name: func.name.clone(), - function_type: func_return_type.clone(), - block_type: body_type.clone(), - }) - } - - // Check coherence with return statements. - for statement in &func.body.statements { - if let Statement::ReturnStatement(value) = statement { - let ret_type = match value { - Some(expr) => expr.typ(&mut ctx)?, - None => Type::Unit, - }; - if ret_type != *func_return_type { - return Err(TypeError::ReturnTypeDoesNotMatchFunctionType { - function_name: func.name.clone(), - function_type: func_return_type.clone(), - ret_type, - }) - } - } - } - - Ok(func_return_type.clone()) - } -} - -impl Typ for Block { - fn typ(&self, ctx: &mut TypeContext) -> Result { - // Check if there is an expression at the end of the block. - if let Some(expr) = &self.value { - expr.typ(ctx) - } else { - Ok(Type::Unit) - } - } -} - -impl Typ for Expr { - fn typ(&self, ctx: &mut TypeContext) -> Result { - match self { - Expr::Identifier(identifier) => { - if let Some(typ) = ctx.variables.get(identifier) { - Ok(typ.clone()) - } else { - Err(TypeError::UnknownIdentifier { - identifier: identifier.clone(), - }) - } - } - Expr::IntegerLiteral(_) => Ok(Type::Int), - Expr::FloatLiteral(_) => Ok(Type::Float), - Expr::BinaryExpression(lhs, op, rhs) => match op { - BinaryOperator::Add - | BinaryOperator::Sub - | BinaryOperator::Mul - | BinaryOperator::Div => { - let left_type = &lhs.typ(ctx)?; - let right_type = &rhs.typ(ctx)?; - match (left_type, right_type) { - (Type::Int, Type::Int) => Ok(Type::Int), - (Type::Float, Type::Int | Type::Float) => Ok(Type::Float), - (Type::Int, Type::Float) => Ok(Type::Float), - (_, _) => Err(TypeError::InvalidBinaryOperator { - operator: op.clone(), - lht: left_type.clone(), - rht: right_type.clone(), - }), - } - } - }, - Expr::StringLiteral(_) => Ok(Type::Str), - Expr::Call(call) => { - todo!("resolve call type using ctx"); - } - } - } -} diff --git a/src/main.rs b/src/main.rs index 6820f10..2a8e0da 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,10 @@ mod ast; mod parsing; +mod typing; use clap::{Parser, Subcommand}; -use std::fs; + +use crate::ast::module::Module; /// Experimental compiler for krone #[derive(Parser, Debug)] @@ -16,8 +18,8 @@ struct Cli { #[derive(Subcommand, Debug)] enum Commands { Parse { - /// Path to the source file - file: String, + /// Path to the source files + files: Vec, /// Dump the AST to stdout #[arg(long)] @@ -25,7 +27,7 @@ enum Commands { /// Add missing return types in the AST #[arg(long)] - complete_ast: bool, + type_check: bool, }, } @@ -34,25 +36,31 @@ fn main() { match &cli.command { Commands::Parse { - file, + files, dump_ast, - complete_ast, + type_check, } => { - let source = fs::read_to_string(&file).expect("could not read the source file"); - let mut ast = match parsing::parse(&source) { - Ok(ast) => ast, - Err(e) => panic!("Parsing error: {:#?}", e), - }; + let paths = files.iter().map(std::path::Path::new); + let modules: Vec = paths + .map(|path| match parsing::parse_file(&path) { + Ok(module) => module, + Err(e) => panic!("Parsing error: {:#?}", e), + }) + .collect(); - if *complete_ast { - if let Err(e) = ast.check_return_types() { - eprintln!("{:#?}", e); - return; + if *type_check { + for module in &modules { + if let Err(e) = module.type_check() { + eprintln!("{}", e); + return; + } } } if *dump_ast { - println!("{:#?}", &ast); + for module in &modules { + println!("{:#?}", &module); + } return; } diff --git a/src/parsing/backend/handmade/lex.rs b/src/parsing/backend/handmade/lex.rs new file mode 100644 index 0000000..b3e00e4 --- /dev/null +++ b/src/parsing/backend/handmade/lex.rs @@ -0,0 +1,157 @@ +use std::iter::Peekable; +use std::str::Chars; + +#[derive(Debug)] +pub enum Token { + LeftBracket, + RightBracket, + If, + Else, + Identifier(String), + LeftParenthesis, + RightParenthesis, + Func, + Colon, + While, + Set, + LineComment, + Mul, + Sub, + Add, + Slash, + Modulo, + NotEqual, + Equal, + DoubleEquals, + Exclamation, + NumberLiteral, +} + +#[derive(Debug)] +pub enum TokenError { + InvalidToken, +} + +pub struct Lexer { + line: usize, + column: usize, +} + +impl Lexer { + pub fn new() -> Self { + Self { line: 1, column: 1 } + } + + pub fn tokenize(&mut self, input: String) -> Result, TokenError> { + let mut tokens: Vec = Vec::new(); + let mut chars = input.chars().peekable(); + while let Some(tok_or_err) = self.get_next_token(&mut chars) { + match tok_or_err { + Ok(token) => tokens.push(token), + Err(err) => return Err(err), + }; + } + Ok(tokens) + } + + fn get_next_token(&mut self, chars: &mut Peekable) -> Option> { + if let Some(ch) = chars.next() { + let tok_or_err = match ch { + '(' => Ok(Token::LeftParenthesis), + ')' => Ok(Token::RightParenthesis), + '{' => Ok(Token::LeftBracket), + '}' => Ok(Token::RightBracket), + '+' => Ok(Token::Add), + '-' => Ok(Token::Sub), + '*' => Ok(Token::Mul), + '%' => Ok(Token::Modulo), + '/' => { + if let Some('/') = chars.peek() { + chars.next(); + let comment = chars.take_while(|c| c != &'\n'); + self.column += comment.count() + 1; + Ok(Token::LineComment) + } else { + Ok(Token::Slash) + } + } + '=' => { + if let Some(ch2) = chars.peek() { + match ch2 { + '=' => { + chars.next(); + self.column += 1; + Ok(Token::DoubleEquals) + } + ' ' => Ok(Token::Equal), + _ => Err(TokenError::InvalidToken), + } + } else { + Ok(Token::Equal) + } + } + '!' => { + if let Some(ch2) = chars.next() { + match ch2 { + '=' => { + self.column += 1; + Ok(Token::NotEqual) + } + _ => Err(TokenError::InvalidToken), + } + } else { + Ok(Token::Exclamation) + } + } + 'a'..='z' | 'A'..='Z' => { + let mut word = String::from(ch); + while let Some(ch2) = chars.peek() { + if ch2.is_alphanumeric() { + if let Some(ch2) = chars.next() { + word.push(ch2); + } + } else { + break; + } + } + self.column += word.len(); + match word.as_str() { + "func" => Ok(Token::Func), + "if" => Ok(Token::If), + "else" => Ok(Token::Else), + "set" => Ok(Token::Set), + "while" => Ok(Token::While), + _ => Ok(Token::Identifier(word)), + } + } + '0'..='9' | '.' => { + let word = chars + .take_while(|c| c.is_numeric() || c == &'.') + .collect::(); + self.column += word.len(); + // XXX: handle syntax error in number literals + Ok(Token::NumberLiteral) + } + ':' => Ok(Token::Colon), + '\n' => { + self.line += 1; + self.column = 1; + return self.get_next_token(chars); + } + ' ' => { + self.column += 1; + return self.get_next_token(chars); + } + '\t' => { + self.column += 8; + return self.get_next_token(chars); + } + _ => Err(TokenError::InvalidToken), + }; + self.column += 1; + Some(tok_or_err) + } else { + None + } + } +} diff --git a/src/parsing/backend/handmade/mod.rs b/src/parsing/backend/handmade/mod.rs new file mode 100644 index 0000000..0233a21 --- /dev/null +++ b/src/parsing/backend/handmade/mod.rs @@ -0,0 +1,149 @@ +// In progress parser from scratch + +use crate::lex::Token; +use std::cell::RefCell; +use std::rc::Rc; + +#[derive(Debug)] +pub enum NodeType { + Document, // This is the root node + LineComment, + FunctionDefinition, + FunctionParam, + VariableName(String), + Type(String), +} + +use NodeType::*; + +pub struct Node { + kind: NodeType, + parent: Option>>, + children: Vec>, +} + +impl Node { + fn new() -> Self { + Node::default() + } + + fn with_kind(&mut self, kind: NodeType) -> &mut Self { + self.kind = kind; + self + } + + fn with_children(&mut self, children: Vec) -> &mut Self { + for child in children { + self.push_child(child); + } + self + } + + fn push_child(&mut self, mut child: Node) { + child.parent = Some(Rc::new(RefCell::new(*self))); + self.children.push(Box::new(child)); + } + + pub fn print_tree(&self) { + self.print_tree_rec(0); + } + + fn print_tree_rec(&self, indent: u8) { + for _ in 1..=indent { + print!(" "); + } + println!("{:?}", self.kind); + for child in &self.children { + child.print_tree_rec(indent + 2); + } + } +} + +impl Default for Node { + fn default() -> Self { + Node { + kind: Document, + parent: None, + children: Vec::new(), + } + } +} + +impl From for Node { + fn from(value: NodeType) -> Self { + Node { + kind: value, + ..Node::default() + } + } +} + +#[derive(Debug)] +pub enum SyntaxError { + FuncExpectedIdentifier, + FuncExpectedLeftParenthesisAfterIdentifier, + UnexpectedToken, +} + +pub struct Parser {} + +impl Parser { + pub fn new() -> Self { + Parser {} + } + + pub fn parse_tokens(&mut self, tokens: Vec) -> Result { + let mut tokens = tokens.iter().peekable(); + let mut root_node = Node::new(); + + while let Some(token) = tokens.next() { + let node_or_err = match token { + Token::LineComment => Ok(Node { + kind: LineComment, + ..Node::default() + }), + + Token::Func => { + let identifier = if let Some(ident) = tokens.next() { + match ident { + Token::Identifier(id) => Some(id), + _ => return Err(SyntaxError::FuncExpectedIdentifier), + } + } else { + None + }; + + if let Some(Token::LeftParenthesis) = tokens.next() { + } else { + return Err(SyntaxError::FuncExpectedLeftParenthesisAfterIdentifier); + }; + + let mut params: Vec = Vec::new(); + while let Some(Token::Identifier(_)) = tokens.peek() { + if let Some(Token::Identifier(param_name)) = tokens.next() { + if let Some(Token::Colon) = tokens.next() { + if let Some(Token::Identifier(type_name)) = tokens.next() { + let mut node = + Node::new().with_kind(FunctionParam).with_children(vec![ + VariableName(param_name.into()).into(), + Type(type_name.into()).into(), + ]); + params.push(*node); + } + } + } + } + let node = Node::from(NodeType::FunctionDefinition).with_children(params); + Ok(*node) + } + + _ => Err(SyntaxError::UnexpectedToken), + }; + if let Ok(node) = node_or_err { + root_node.push_child(node); + } else { + }; + } + Ok(root_node) + } +} diff --git a/src/parsing/backend/mod.rs b/src/parsing/backend/mod.rs new file mode 100644 index 0000000..b63ce4a --- /dev/null +++ b/src/parsing/backend/mod.rs @@ -0,0 +1 @@ +pub mod pest; diff --git a/src/parsing/backend/pest/grammar.pest b/src/parsing/backend/pest/grammar.pest new file mode 100644 index 0000000..1a6aeec --- /dev/null +++ b/src/parsing/backend/pest/grammar.pest @@ -0,0 +1,70 @@ +// This file is just a little test of pest.rs + +source_file = { SOI ~ module_items ~ EOI } +module_items = { (use_statement | definition)* } + +// Statements +statement = { assign_statement | declare_statement | return_statement | call_statement | use_statement | while_statement | if_statement } +declare_statement = { ident ~ "=" ~ expr ~ ";" } +assign_statement = { "set" ~ ident ~ "=" ~ expr ~ ";" } +return_statement = { "return" ~ expr? ~ ";" } +call_statement = { call ~ ";" } +use_statement = { "use" ~ import_path ~ ";" } +while_statement = { "while" ~ expr ~ block ~ ";" } +if_statement = { if_branch ~ ("else" ~ (if_branch | block))? ~ ";" } + +if_branch = _{ "if" ~ expr ~ block } + +// Module paths +import_path = { ident ~ ("::" ~ ident)* } + +// Function call +call = { ident ~ "(" ~ args ~ ")" } +args = { (expr ~ ",")* ~ expr? } + +definition = { func_def } + +// Function definition +func_def = { "fn" ~ ident ~ "(" ~ parameters ~ ")" ~ typ? ~ block } +parameters = { + (parameter ~ ",")* ~ (parameter)? +} +parameter = { ident ~ ":" ~ typ } + +// Operators +infix = _{ add | subtract | multiply | divide | not_equal | equal | modulo } +add = { "+" } +subtract = { "-" } +multiply = { "*" } +divide = { "/" } +modulo = { "%" } +equal = { "==" } +not_equal = { "!=" } + +prefix = _{ not } +not = { "!" } + +// Expressions +expr = { prefix? ~ atom ~ (infix ~ prefix? ~ atom)* } +atom = _{ call | if_expr | block | literal | ident | "(" ~ expr ~ ")" } +block = { "{" ~ statement* ~ expr? ~ "}" } +if_expr = { "if" ~ expr ~ block ~ "else" ~ (block | if_expr) } + +ident = @{ (ASCII_ALPHANUMERIC | "_")+ } +typ = _{ ident } + +// Literals +literal = _{ boolean_literal | float_literal | integer_literal | string_literal } +boolean_literal = @{ "true" | "false" } +string_literal = ${ "\"" ~ string_content ~ "\"" } +string_content = @{ char* } +char = { + !("\"" | "\\") ~ ANY + | "\\" ~ ("\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t") + | "\\" ~ ("u" ~ ASCII_HEX_DIGIT{4}) +} +integer_literal = @{ ASCII_DIGIT+ } +float_literal = @{ ("0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT*) ~ "." ~ ASCII_DIGIT* } + +WHITESPACE = _{ " " | "\n" | "\t" } +COMMENT = _{ "//" ~ (!NEWLINE ~ ANY)* } diff --git a/src/parsing/pest.rs b/src/parsing/backend/pest/mod.rs similarity index 57% rename from src/parsing/pest.rs rename to src/parsing/backend/pest/mod.rs index e56eb78..b7d7237 100644 --- a/src/parsing/pest.rs +++ b/src/parsing/backend/pest/mod.rs @@ -1,15 +1,20 @@ -use lazy_static; +use std::fs; +use std::path::Path; + use pest::error::Error; use pest::iterators::Pair; use pest::pratt_parser::PrattParser; use pest::Parser; +use crate::ast::module::{Module, ModulePath}; use crate::ast::*; +use crate::typing::Type; #[derive(pest_derive::Parser)] -#[grammar = "parsing/grammar.pest"] +#[grammar = "parsing/backend/pest/grammar.pest"] struct KrParser; +use lazy_static; lazy_static::lazy_static! { static ref PRATT_PARSER: PrattParser = { use pest::pratt_parser::{Assoc::*, Op}; @@ -17,36 +22,51 @@ lazy_static::lazy_static! { // Precedence is defined lowest to highest PrattParser::new() - // Addition and subtract have equal precedence + .op(Op::infix(equal, Left) | Op::infix(not_equal, Left)) .op(Op::infix(add, Left) | Op::infix(subtract, Left)) + .op(Op::infix(modulo, Left)) .op(Op::infix(multiply, Left) | Op::infix(divide, Left)) }; } -pub fn parse(source: &str) -> Result> { - let mut definitions: Vec = vec![]; +pub fn parse_file(path: &Path) -> Result> { + let source = fs::read_to_string(&path).expect("could not read source file"); + let module_path = ModulePath::from(path); + let mut module = parse_as_module(&source, module_path)?; + module.file = Some(path.to_owned()); + Ok(module) +} - let pairs = KrParser::parse(Rule::source_file, source)?; +pub fn parse_as_module(source: &str, path: ModulePath) -> Result> { + let mut pairs = KrParser::parse(Rule::source_file, &source)?; + + assert!(pairs.len() == 1); + let module = parse_module(pairs.next().unwrap().into_inner().next().unwrap(), path); + + Ok(module) +} + +pub fn parse_module(pair: Pair, path: ModulePath) -> Module { + assert!(pair.as_rule() == Rule::module_items); + + let mut module = Module::new(path); + + let pairs = pair.into_inner(); for pair in pairs { match pair.as_rule() { - Rule::source_file => { - let pairs = pair.into_inner(); - for pair in pairs { - match pair.as_rule() { - Rule::definition => { - let definition = parse_definition(pair.into_inner().next().unwrap()); - definitions.push(definition); - } - Rule::EOI => {} - _ => panic!("unexpected rule in source_file: {:?}", pair.as_rule()), - } - } + Rule::definition => { + let def = parse_definition(pair.into_inner().next().unwrap()); + module.definitions.push(def); } - _ => eprintln!("unexpected top-level rule {:?}", pair.as_rule()), + Rule::use_statement => { + let path = parse_import_path(pair.into_inner().next().unwrap()); + module.imports.push(path); + } + _ => panic!("unexpected rule in source_file: {:?}", pair.as_rule()), } } - Ok(Ast::Module(definitions)) + module } fn parse_block(pair: Pair) -> Block { @@ -73,6 +93,12 @@ fn parse_statement(pair: Pair) -> Statement { let expr = parse_expression(pairs.next().unwrap()); Statement::AssignStatement(identifier, expr) } + Rule::declare_statement => { + let mut pairs = pair.into_inner(); + let identifier = pairs.next().unwrap().as_str().to_string(); + let expr = parse_expression(pairs.next().unwrap()); + Statement::DeclareStatement(identifier, expr) + } Rule::return_statement => { let expr = if let Some(pair) = pair.into_inner().next() { Some(parse_expression(pair)) @@ -85,10 +111,32 @@ fn parse_statement(pair: Pair) -> Statement { let call = parse_call(pair.into_inner().next().unwrap()); Statement::CallStatement(call) } + Rule::use_statement => { + let path = parse_import_path(pair.into_inner().next().unwrap()); + Statement::UseStatement(path) + } + Rule::if_statement => { + let mut pairs = pair.into_inner(); + let condition = parse_expression(pairs.next().unwrap()); + let block = parse_block(pairs.next().unwrap()); + Statement::IfStatement(condition, block) + } + Rule::while_statement => { + let mut pairs = pair.into_inner(); + let condition = parse_expression(pairs.next().unwrap()); + let block = parse_block(pairs.next().unwrap()); + Statement::WhileStatement(Box::new(condition), Box::new(block)) + } _ => unreachable!("unexpected rule '{:?}' in parse_statement", pair.as_rule()), } } +type ImportPath = ModulePath; + +fn parse_import_path(pair: Pair) -> ImportPath { + ModulePath::from(pair.as_str()) +} + fn parse_call(pair: Pair) -> Call { let mut pairs = pair.into_inner(); // TODO: support calls on more than identifiers (needs grammar change) @@ -117,9 +165,26 @@ fn parse_expression(pair: Pair) -> Expr { .parse() .unwrap(), ), - Rule::ident => Expr::Identifier(primary.as_str().to_string()), Rule::expr => parse_expression(primary), + Rule::ident => Expr::Identifier(primary.as_str().to_string()), Rule::call => Expr::Call(Box::new(parse_call(primary))), + Rule::block => Expr::Block(Box::new(parse_block(primary))), + Rule::if_expr => { + let mut pairs = primary.into_inner(); + let condition = parse_expression(pairs.next().unwrap()); + let true_block = parse_block(pairs.next().unwrap()); + let else_value = parse_expression(pairs.next().unwrap()); + Expr::IfExpr( + Box::new(condition), + Box::new(true_block), + Box::new(else_value), + ) + } + Rule::boolean_literal => Expr::BooleanLiteral(match primary.as_str() { + "true" => true, + "false" => false, + _ => unreachable!(), + }), _ => unreachable!( "Unexpected rule '{:?}' in primary expression", primary.as_rule() @@ -131,6 +196,9 @@ fn parse_expression(pair: Pair) -> Expr { Rule::subtract => BinaryOperator::Sub, Rule::multiply => BinaryOperator::Mul, Rule::divide => BinaryOperator::Div, + Rule::modulo => BinaryOperator::Modulo, + Rule::equal => BinaryOperator::Equal, + Rule::not_equal => BinaryOperator::NotEqual, _ => unreachable!(), }; Expr::BinaryExpression(Box::new(lhs), operator, Box::new(rhs)) @@ -141,14 +209,15 @@ fn parse_expression(pair: Pair) -> Expr { fn parse_parameter(pair: Pair) -> Parameter { assert!(pair.as_rule() == Rule::parameter); let mut pair = pair.into_inner(); - let name: String = pair.next().unwrap().as_str().to_string(); + let name = pair.next().unwrap().as_str().to_string(); let typ = Type::from(pair.next().unwrap().as_str()); Parameter { name, typ } } -fn parse_definition(pair: Pair) -> Ast { +fn parse_definition(pair: Pair) -> Definition { match pair.as_rule() { Rule::func_def => { + let line_col = pair.line_col(); let mut pairs = pair.into_inner(); let name = pairs.next().unwrap().as_str().to_string(); let parameters: Vec = pairs @@ -169,11 +238,12 @@ fn parse_definition(pair: Pair) -> Ast { }; let body = parse_block(pair); let body = Box::new(body); - Ast::FunctionDefinition(FunctionDefinition { + Definition::FunctionDefinition(FunctionDefinition { name, parameters, return_type, body, + line_col, }) } _ => panic!("unexpected node for definition: {:?}", pair.as_rule()), diff --git a/src/parsing/backend/tree_sitter/lib.rs b/src/parsing/backend/tree_sitter/lib.rs new file mode 100644 index 0000000..775fdf2 --- /dev/null +++ b/src/parsing/backend/tree_sitter/lib.rs @@ -0,0 +1,222 @@ +use tree_sitter::{self, Language, Parser, TreeCursor}; + +enum Ast { + FuncDef(FuncDef), + Expr(Expr), + Module(Vec), + Block(Vec, Option), + Statement(Statement), +} + +enum BinaryOperator { + Add, + Sub, + Mul, + Div, +} + +enum Expr { + BinaryExpression(Box, BinaryOperator, Box), +} + +enum Statement { + AssignStatement(Identifier, Expr), +} + +type Identifier = String; +type Type = String; + +struct Parameter { + name: Identifier, + typ: Type, +} + +struct FuncDef { + name: Identifier, + parameters: Vec, + return_type: Option, + body: Box, +} + +#[derive(Debug)] +struct AstError { + message: String, +} + +impl AstError { + fn new(message: &str) -> Self { + AstError { + message: message.into(), + } + } +} + +extern "C" { + fn tree_sitter_krone() -> Language; +} + +struct TreeCursorChildrenIter<'a, A: AsRef<[u8]>> { + source: A, + cursor: &'a mut TreeCursor<'a>, + on_child: bool, +} + +impl<'a, A: AsRef<[u8]>> Iterator for TreeCursorChildrenIter<'a, A> { + type Item = Result; + + fn next(&mut self) -> Option { + if self.on_child { + if self.cursor.goto_next_sibling() { + Some(parse_from_cursor(&self.source, self.cursor)) + } else { + self.cursor.goto_parent(); + None + } + } else { + if self.cursor.goto_first_child() { + self.on_child = true; + Some(parse_from_cursor(&self.source, self.cursor)) + } else { + None + } + } + } +} + +fn iter_children<'a, A: AsRef<[u8]>>( + source: A, + cursor: &'a mut TreeCursor<'a>, +) -> TreeCursorChildrenIter<'a, A> { + TreeCursorChildrenIter { + source, + cursor, + on_child: false, + } +} + +fn parse_from_cursor<'a>( + source: impl AsRef<[u8]>, + cursor: &'a mut TreeCursor<'a>, +) -> Result { + match cursor.node().kind() { + "block" => { + let mut statements = Vec::new(); + let mut value = None; + + for child in iter_children(source, cursor) { + match child.unwrap() { + Ast::Statement(statement) => { + if value.is_none() { + statements.push(statement); + } else { + return Err(AstError::new( + "cannot have a statement after an expression in a block", + )); + // perhaps there is a missing semicolon ; + } + } + Ast::Expr(expr) => value = Some(expr), + _ => return Err(AstError::new("invalid node type")), + }; + } + + let block = Ast::Block(statements, value); + Ok(block) + } + + "function_definition" => { + // 1: name + assert!(cursor.goto_first_child()); + assert!(cursor.field_name() == Some("name")); + let name: String = cursor + .node() + .utf8_text(source.as_ref()) + .expect("utf8 error") + .into(); + + // 2: parameters + assert!(cursor.goto_next_sibling()); + assert!(cursor.field_name() == Some("parameters")); + let mut parameters = Vec::new(); + + if cursor.goto_first_child() { + loop { + let param = cursor.node(); + + assert!(cursor.goto_first_child()); + let name = cursor + .node() + .utf8_text(source.as_ref()) + .expect("utf8 error") + .into(); + + assert!(cursor.goto_next_sibling()); + let typ = cursor + .node() + .utf8_text(source.as_ref()) + .expect("utf8 error") + .into(); + + cursor.goto_parent(); + + parameters.push(Parameter { name, typ }); + + if !cursor.goto_next_sibling() { + break; + } + } + + cursor.goto_parent(); + } + + // 3: return type + assert!(cursor.goto_next_sibling()); + assert!(cursor.field_name() == Some("return_type")); + let return_type = Some( + cursor + .node() + .utf8_text(source.as_ref()) + .expect("utf8 error") + .into(), + ); + + // 4: body + assert!(cursor.goto_next_sibling()); + assert!(cursor.field_name() == Some("body")); + let body = parse_from_cursor(source, cursor).unwrap(); + let body = Box::new(body); + + Ok(Ast::FuncDef(FuncDef { + name, + parameters, + return_type, + body, + })) + } + + _ => panic!("unexpected node kind: {}", cursor.node().kind()), + } +} + +fn parse_with_tree_sitter(source: impl AsRef<[u8]>) -> Result { + let mut parser = Parser::new(); + let language = unsafe { tree_sitter_krone() }; + parser.set_language(language).unwrap(); + + let tree = parser.parse(&source, None).unwrap(); + + let mut cursor = tree.walk(); + let node = cursor.node(); + assert!(node.kind() == "source_file"); + let mut top_level_nodes = Vec::new(); + + for node in iter_children(source, &mut cursor) { + let node = node.unwrap(); + match node { + Ast::FuncDef(_) => top_level_nodes.push(node), + _ => panic!("unexpected top-level node type"), + }; + } + + Ok(Ast::Module(top_level_nodes)) +} diff --git a/src/parsing/grammar.pest b/src/parsing/grammar.pest deleted file mode 100644 index bf80948..0000000 --- a/src/parsing/grammar.pest +++ /dev/null @@ -1,52 +0,0 @@ -// This file is just a little test of pest.rs - -source_file = { SOI ~ definition* ~ EOI } - -statement = { assign_statement | return_statement | call_statement } -assign_statement = { "set" ~ ident ~ "=" ~ expr ~ ";" } -return_statement = { "return" ~ expr? ~ ";" } -call_statement = { call ~ ";" } - -// Function calls -call = { ident ~ "(" ~ args ~ ")" } -args = { (expr ~ ",")* ~ expr? } - -definition = { func_def } - -func_def = { "fn" ~ ident ~ "(" ~ parameters ~ ")" ~ typ? ~ block } -parameters = { - (parameter ~ ",")* ~ (parameter)? -} -parameter = { ident ~ ":" ~ typ } - -block = { "{" ~ statement* ~ expr? ~ "}" } - -// Operators -infix = _{ add | subtract | multiply | divide } -add = { "+" } -subtract = { "-" } -multiply = { "*" } -divide = { "/" } - -prefix = _{ not } -not = { "!" } - -expr = { prefix? ~ atom ~ (infix ~ prefix? ~ atom)* } -atom = _{ call | ident | literal | "(" ~ expr ~ ")" } - -ident = @{ (ASCII_ALPHA | "_")+ } -typ = _{ ident } - -// Literals -literal = _{ float_literal | integer_literal | string_literal } -string_literal = ${ "\"" ~ string_content ~ "\"" } -string_content = @{ char* } -char = { - !("\"" | "\\") ~ ANY - | "\\" ~ ("\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t") - | "\\" ~ ("u" ~ ASCII_HEX_DIGIT{4}) -} -integer_literal = @{ ASCII_DIGIT+ } -float_literal = @{ ("0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT*) ~ "." ~ ASCII_DIGIT* } - -WHITESPACE = _{ " " | "\n" | "\t" } diff --git a/src/parsing/mod.rs b/src/parsing/mod.rs index 6d354dc..56e5061 100644 --- a/src/parsing/mod.rs +++ b/src/parsing/mod.rs @@ -1,38 +1,5 @@ -pub mod pest; +mod backend; -pub use self::pest::parse; +pub use self::backend::pest::{parse_file, parse_module}; -mod tests { - #[test] - fn test_addition_function() { - use crate::ast::*; - use crate::parsing::pest::parse; - - let source = "fn add(a: int, b: int) int { a + b }"; - - let ast = Ast::FunctionDefinition(FunctionDefinition { - name: Identifier::from("add"), - parameters: vec![ - Parameter { - name: Identifier::from("a"), - typ: Type::Int, - }, - Parameter { - name: Identifier::from("b"), - typ: Type::Int, - }, - ], - return_type: Some(Type::Int), - body: Box::new(Block { - statements: vec![], - value: Some(Expr::BinaryExpression( - Box::new(Expr::Identifier(Identifier::from("a"))), - BinaryOperator::Add, - Box::new(Expr::Identifier(Identifier::from("b"))), - )), - }), - }); - - assert_eq!(parse(source).unwrap(), Ast::Module(vec![ast])); - } -} +mod tests; diff --git a/src/parsing/tests.rs b/src/parsing/tests.rs new file mode 100644 index 0000000..bc655a3 --- /dev/null +++ b/src/parsing/tests.rs @@ -0,0 +1,44 @@ +#[test] +fn test_addition_function() { + use crate::parsing::backend::pest::parse_as_module; + use crate::{ + ast::module::{Module, ModulePath}, + ast::*, + typing::Type, + }; + + let source = "fn add(a: int, b: int) int { a + b }"; + let path = ModulePath::from("test"); + let module = parse_as_module(&source, path.clone()).expect("parsing error"); + + let expected_module = Module { + file: None, + imports: vec![], + definitions: vec![Definition::FunctionDefinition(FunctionDefinition { + name: Identifier::from("add"), + parameters: vec![ + Parameter { + name: Identifier::from("a"), + typ: Type::Int, + }, + Parameter { + name: Identifier::from("b"), + typ: Type::Int, + }, + ], + return_type: Some(Type::Int), + body: Box::new(Block { + statements: vec![], + value: Some(Expr::BinaryExpression( + Box::new(Expr::Identifier(Identifier::from("a"))), + BinaryOperator::Add, + Box::new(Expr::Identifier(Identifier::from("b"))), + )), + }), + line_col: (1, 1), + })], + path, + }; + + assert_eq!(module, expected_module); +} diff --git a/src/typing/mod.rs b/src/typing/mod.rs new file mode 100644 index 0000000..ce39328 --- /dev/null +++ b/src/typing/mod.rs @@ -0,0 +1,452 @@ +use std::collections::HashMap; + +use crate::ast::{ + module::{Module, ModulePath}, + *, +}; + +#[derive(Debug, PartialEq, Clone)] +pub enum Type { + Bool, + Int, + Float, + Unit, + Str, + Custom(Identifier), +} + +impl From<&str> for Type { + fn from(value: &str) -> Self { + match value { + "int" => Type::Int, + "float" => Type::Float, + _ => Type::Custom(Identifier::from(value)), + } + } +} + +impl FunctionDefinition { + fn signature(&self) -> (Vec, Type) { + let return_type = self.return_type.unwrap_or(Type::Unit); + let params_types = self.parameters.iter().map(|p| p.typ).collect(); + (params_types, return_type) + } +} + +impl Module { + pub fn type_check(&self) -> Result<(), TypeError> { + let mut ctx = TypeContext::new(self.path); + ctx.file = self.file.clone(); + + // Register all function signatures + for Definition::FunctionDefinition(func) in &self.definitions { + if let Some(previous) = ctx.functions.insert(func.name.clone(), func.signature()) { + todo!("handle redefinition of function or identical function names across different files"); + } + } + + // TODO: add signatures of imported functions (even if they have not been checked) + + // Type-check the function bodies + for Definition::FunctionDefinition(func) in &self.definitions { + func.typ(&mut ctx)?; + } + + Ok(()) + } +} + +#[derive(Debug)] +pub struct TypeError { + file: Option, + module: ModulePath, + function: Option, + kind: TypeErrorKind, +} + +impl std::fmt::Display for TypeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("Error\n")?; + if let Some(path) = &self.file { + f.write_fmt(format_args!(" in file {}\n", path.display()))?; + } + f.write_fmt(format_args!(" in module {}\n", self.module))?; + if let Some(name) = &self.function { + f.write_fmt(format_args!(" in function {}\n", name))?; + } + f.write_fmt(format_args!("{:#?}", self.kind))?; + Ok(()) + } +} + +#[derive(Default)] +struct TypeErrorBuilder { + file: Option, + module: Option, + function: Option, + kind: Option, +} + +impl TypeError { + fn builder() -> TypeErrorBuilder { + TypeErrorBuilder::default() + } +} + +impl TypeErrorBuilder { + fn context(mut self, ctx: &TypeContext) -> Self { + self.file = ctx.file.clone(); + self.module = Some(ctx.module.clone()); + self.function = ctx.function.clone(); + self + } + + fn kind(mut self, kind: TypeErrorKind) -> Self { + self.kind = Some(kind); + self + } + + fn build(self) -> TypeError { + TypeError { + file: self.file, + module: self.module.expect("TypeError builder is missing module"), + function: self.function, + kind: self.kind.expect("TypeError builder is missing kind"), + } + } +} + +#[derive(Debug)] +pub enum TypeErrorKind { + InvalidBinaryOperator { + operator: BinaryOperator, + lht: Type, + rht: Type, + }, + BlockTypeDoesNotMatchFunctionType { + block_type: Type, + function_type: Type, + }, + ReturnTypeDoesNotMatchFunctionType { + function_type: Type, + return_type: Type, + }, + UnknownIdentifier { + identifier: String, + }, + AssignmentMismatch { + lht: Type, + rht: Type, + }, + AssignUndeclared, + VariableRedeclaration, + ReturnStatementsMismatch, + UnknownFunctionCalled(Identifier), + WrongFunctionArguments, + ConditionIsNotBool, + IfElseMismatch, +} + +pub struct TypeContext { + pub file: Option, + pub module: ModulePath, + pub function: Option, + pub functions: HashMap, Type)>, + pub variables: HashMap, +} + +impl TypeContext { + pub fn new(path: ModulePath) -> Self { + TypeContext { + file: None, + module: path, + function: None, + functions: Default::default(), + variables: Default::default(), + } + } +} + +/// Trait for nodes which have a deducible type. +pub trait Typ { + /// Try to resolve the type of the node. + fn typ(&self, ctx: &mut TypeContext) -> Result; +} + +impl Typ for FunctionDefinition { + fn typ(&self, ctx: &mut TypeContext) -> Result { + let func = self; + + ctx.function = Some(func.name.clone()); + + for param in &func.parameters { + ctx.variables.insert(param.name.clone(), param.typ.clone()); + } + + let body_type = &func.body.typ(ctx)?; + + // If the return type is not specified, it is unit. + let func_return_type = match &func.return_type { + Some(typ) => typ, + None => &Type::Unit, + }; + + // Check coherence with the body's type. + if *func_return_type != *body_type { + return Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::BlockTypeDoesNotMatchFunctionType { + block_type: body_type.clone(), + function_type: func_return_type.clone(), + }) + .build()); + } + + // Check coherence with return statements. + for statement in &func.body.statements { + if let Statement::ReturnStatement(value) = statement { + let ret_type = match value { + Some(expr) => expr.typ(ctx)?, + None => Type::Unit, + }; + if ret_type != *func_return_type { + return Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::ReturnTypeDoesNotMatchFunctionType { + function_type: func_return_type.clone(), + return_type: ret_type, + }) + .build()); + } + } + } + + Ok(func_return_type.clone()) + } +} + +impl Typ for Block { + fn typ(&self, ctx: &mut TypeContext) -> Result { + let mut return_typ: Option = None; + + // Check declarations and assignments. + for statement in &self.statements { + match statement { + Statement::DeclareStatement(ident, expr) => { + let typ = expr.typ(ctx)?; + if let Some(_typ) = ctx.variables.insert(ident.clone(), typ) { + // TODO: Shadowing? (illegal for now) + return Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::VariableRedeclaration) + .build()); + } + } + Statement::AssignStatement(ident, expr) => { + let rhs_typ = expr.typ(ctx)?; + let Some(lhs_typ) = ctx.variables.get(ident) else { + return Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::AssignUndeclared) + .build()); + }; + + // Ensure same type on both sides. + if rhs_typ != *lhs_typ { + return Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::AssignmentMismatch { + lht: lhs_typ.clone(), + rht: rhs_typ.clone(), + }) + .build()); + } + } + Statement::ReturnStatement(maybe_expr) => { + let expr_typ = if let Some(expr) = maybe_expr { + expr.typ(ctx)? + } else { + Type::Unit + }; + if let Some(typ) = &return_typ { + if expr_typ != *typ { + return Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::ReturnStatementsMismatch) + .build()); + } + } else { + return_typ = Some(expr_typ); + } + } + Statement::CallStatement(call) => { + call.typ(ctx)?; + } + Statement::UseStatement(_path) => { + // TODO: import the signatures (and types) + } + Statement::IfStatement(cond, block) => { + if cond.typ(ctx)? != Type::Bool { + return Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::ConditionIsNotBool) + .build()); + } + block.typ(ctx)?; + } + Statement::WhileStatement(cond, block) => { + if cond.typ(ctx)? != Type::Bool { + return Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::ConditionIsNotBool) + .build()); + } + block.typ(ctx)?; + } + } + } + + // Check if there is an expression at the end of the block. + if let Some(expr) = &self.value { + expr.typ(ctx) + } else { + Ok(Type::Unit) + } + + // TODO/FIXME: find a way to return `return_typ` so that the + // top-level block (the function) can check if this return type + // (and eventually those from other block) matches the type of + // the function. + } +} + +impl Typ for Call { + fn typ(&self, ctx: &mut TypeContext) -> Result { + match &self.callee { + Expr::Identifier(ident) => { + let signature = match ctx.functions.get(ident) { + Some(sgn) => sgn.clone(), + None => { + return Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::UnknownFunctionCalled(ident.clone())) + .build()) + } + }; + let (params_types, func_type) = signature; + + // Collect arg types. + let mut args_types: Vec = vec![]; + for arg in &self.args { + let typ = arg.typ(ctx)?; + args_types.push(typ.clone()); + } + + if args_types == *params_types { + Ok(func_type.clone()) + } else { + Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::WrongFunctionArguments) + .build()) + } + } + _ => unimplemented!("cannot call on expression other than identifier"), + } + } +} + +impl Typ for Expr { + fn typ(&self, ctx: &mut TypeContext) -> Result { + match self { + Expr::Identifier(identifier) => { + if let Some(typ) = ctx.variables.get(identifier) { + Ok(typ.clone()) + } else { + Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::UnknownIdentifier { + identifier: identifier.clone(), + }) + .build()) + } + } + Expr::BooleanLiteral(_) => Ok(Type::Bool), + Expr::IntegerLiteral(_) => Ok(Type::Int), + Expr::FloatLiteral(_) => Ok(Type::Float), + Expr::BinaryExpression(lhs, op, rhs) => match op { + BinaryOperator::Add + | BinaryOperator::Sub + | BinaryOperator::Mul + | BinaryOperator::Div => { + let left_type = &lhs.typ(ctx)?; + let right_type = &rhs.typ(ctx)?; + match (left_type, right_type) { + (Type::Int, Type::Int) => Ok(Type::Int), + (Type::Float, Type::Float) => Ok(Type::Float), + (_, _) => Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::InvalidBinaryOperator { + operator: op.clone(), + lht: left_type.clone(), + rht: right_type.clone(), + }) + .build()), + } + } + BinaryOperator::Equal | BinaryOperator::NotEqual => { + let lhs_type = lhs.typ(ctx)?; + let rhs_type = rhs.typ(ctx)?; + if lhs_type != rhs_type { + return Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::InvalidBinaryOperator { + operator: op.clone(), + lht: lhs_type.clone(), + rht: rhs_type.clone(), + }) + .build()); + } + Ok(Type::Bool) + } + BinaryOperator::Modulo => { + let lhs_type = lhs.typ(ctx)?; + let rhs_type = lhs.typ(ctx)?; + match (&lhs_type, &rhs_type) { + (Type::Int, Type::Int) => Ok(Type::Int), + _ => Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::InvalidBinaryOperator { + operator: op.clone(), + lht: lhs_type.clone(), + rht: rhs_type.clone(), + }) + .build()), + } + } + }, + Expr::StringLiteral(_) => Ok(Type::Str), + Expr::Call(call) => call.typ(ctx), + Expr::Block(block) => block.typ(ctx), + Expr::IfExpr(cond, true_block, else_value) => { + if cond.typ(ctx)? != Type::Bool { + Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::ConditionIsNotBool) + .build()) + } else { + let true_block_type = true_block.typ(ctx)?; + let else_type = else_value.typ(ctx)?; + if true_block_type != else_type { + Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::IfElseMismatch) + .build()) + } else { + Ok(true_block_type.clone()) + } + } + } + } + } +}