diff --git a/Cargo.toml b/Cargo.toml index 3b50703..5387a84 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,10 +5,10 @@ edition = "2021" [dependencies] clap = { version = "4.3.0", features = ["derive"] } -cranelift = "0.100.0" -cranelift-jit = "0.96.3" -cranelift-module = "0.96.3" -cranelift-native = "0.96.3" +cranelift = "0.105.1" +cranelift-jit = "0.105.1" +cranelift-module = "0.105.1" +cranelift-native = "0.105.1" lazy_static = "1.4.0" -pest = "2.6.0" -pest_derive = "2.6.0" +pest = "2.7.4" +pest_derive = "2.7.4" diff --git a/src/ast/expr.rs b/src/ast/expr.rs new file mode 100644 index 0000000..5f776cb --- /dev/null +++ b/src/ast/expr.rs @@ -0,0 +1,34 @@ +use crate::ast::*; + +#[derive(Debug, PartialEq)] +pub enum Expr { + BinaryExpression { + lhs: Box, + op: BinaryOperator, + rhs: Box, + typ: Type, + }, + UnaryExpression { + op: UnaryOperator, + inner: Box, + }, + Identifier { + name: String, + typ: Type, + }, + Call(Box), + Block(Box), + /// Last field is either Expr::Block or Expr::IfExpr + IfExpr { + cond: Box, + then_body: Box, + else_body: Box, + typ: Type, + }, + // Literals + UnitLiteral, + BooleanLiteral(bool), + IntegerLiteral(i64), + FloatLiteral(f64), + StringLiteral(String), +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 05e708e..8348a29 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,10 +1,17 @@ +use crate::typing::Type; use std::path::Path; +pub mod expr; pub mod typed; -pub mod untyped; + +pub use expr::Expr; #[derive(Debug, PartialEq, Clone)] pub enum BinaryOperator { + // Logic + And, + Or, + // Arithmetic Add, Sub, Mul, @@ -14,13 +21,19 @@ pub enum BinaryOperator { NotEqual, } -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Copy, Clone)] pub enum UnaryOperator { + Not, } pub type Identifier = String; -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq)] +pub struct Location { + pub line_col: (usize, usize), +} + +#[derive(Debug, PartialEq, Clone, Default)] pub struct ModulePath { components: Vec, } @@ -65,3 +78,75 @@ impl From<&str> for ModulePath { #[derive(Eq, PartialEq, Debug)] pub struct Import(pub String); + +#[derive(Debug, PartialEq)] +pub enum Statement { + DeclareStatement(Identifier, Box), + AssignStatement(Identifier, Box), + ReturnStatement(Option), + CallStatement(Box), + UseStatement(Box), + IfStatement(Box, Box), + WhileStatement(Box, Box), +} + +#[derive(Debug, PartialEq)] +pub struct Block { + pub statements: Vec, + pub value: Option, + pub typ: Type, +} + +impl Block { + pub fn empty() -> Block { + Block { + typ: Type::Unit, + statements: Vec::with_capacity(0), + value: None, + } + } +} + +#[derive(Debug, PartialEq)] +pub enum Definition { + FunctionDefinition(FunctionDefinition), +} + +#[derive(Debug, PartialEq)] +pub struct FunctionDefinition { + pub name: Identifier, + pub parameters: Vec, + pub return_type: Option, + pub body: Box, + pub location: Location, +} + +#[derive(Debug, PartialEq, Default)] +pub struct Module { + pub file: Option, + pub path: ModulePath, + pub functions: Vec, + pub imports: Vec, +} + +impl Module { + pub fn new(path: ModulePath) -> Self { + Self { + path, + ..Default::default() + } + } +} + +#[derive(Debug, PartialEq)] +pub struct Call { + pub callee: Box, + pub args: Vec, + pub typ: Type, +} + +#[derive(Debug, PartialEq)] +pub struct Parameter { + pub name: Identifier, + pub typ: Type, +} diff --git a/src/ast/typed/expr.rs b/src/ast/typed/expr.rs index 7588907..736bc22 100644 --- a/src/ast/typed/expr.rs +++ b/src/ast/typed/expr.rs @@ -3,60 +3,20 @@ use crate::ast::{ BinaryOperator, UnaryOperator, }; use crate::typing::Type; - -#[derive(Debug, PartialEq)] -pub enum Expr { - BinaryExpression { - lhs: Box, - op: BinaryOperator, - rhs: Box, - typ: Type, - }, - UnaryExpression { - op: UnaryOperator, - inner: Box, - }, - Variable { - name: String, - typ: Type, - }, - Call { - call: Box, - typ: Type, - }, - Block { - block: Box, - typ: Type, - }, - /// Last field is either Expr::Block or Expr::IfExpr - IfExpr { - cond: Box, - then_body: Box, - else_body: Box, - typ: Type, - }, - // Literals - UnitLiteral, - BooleanLiteral(bool), - IntegerLiteral(i64), - FloatLiteral(f64), - StringLiteral(String), -} - impl Expr { pub fn typ(&self) -> &Type { match self { - Expr::BinaryExpression { lhs, op, rhs, typ } => typ, + Expr::BinaryExpression { lhs, op, rhs, typ } => &typ.unwrap(), Expr::UnaryExpression { op, inner } => inner.typ(), // XXX: problems will arise here - Expr::Variable { name, typ } => typ, - Expr::Call { call, typ } => typ, - Expr::Block { block, typ } => typ, + Expr::Variable { name, typ } => &typ.unwrap(), + Expr::Call { call, typ } => &typ.unwrap(), + Expr::Block { block, typ } => &typ.unwrap(), Expr::IfExpr { cond, then_body, else_body, typ, - } => typ, + } => &typ.unwrap(), Expr::UnitLiteral => &Type::Unit, Expr::BooleanLiteral(_) => &Type::Bool, Expr::IntegerLiteral(_) => &Type::Int, diff --git a/src/ast/typed/mod.rs b/src/ast/typed/mod.rs index 505ed3a..033129a 100644 --- a/src/ast/typed/mod.rs +++ b/src/ast/typed/mod.rs @@ -1,47 +1,37 @@ -pub mod expr; - -use crate::typing::Type; -use super::{untyped::Parameter, Identifier, Import}; -use expr::Expr; - -#[derive(Debug, PartialEq)] -pub enum Statement { - DeclareStatement(Identifier, Box), - AssignStatement(Identifier, Box), - ReturnStatement(Option), - CallStatement(Box), - UseStatement(Box), - IfStatement(Box, Box), - WhileStatement(Box, Box), -} - -#[derive(Debug, PartialEq)] -pub struct Block { - pub statements: Vec, - pub value: Option, - typ: Type, -} +use crate::ast::*; impl Block { #[inline] - pub fn typ(&self) -> Type { + pub fn ty(&self) -> Type { + // XXX: Cloning may be expensive -> TypeId? self.typ.clone() } } -#[derive(Debug, PartialEq)] -pub struct FunctionDefinition { - pub name: Identifier, - pub parameters: Vec, - pub return_type: Option, - pub body: Box, - pub line_col: (usize, usize), +impl Expr { + pub fn ty(&self) -> Type { + match self { + Expr::BinaryExpression { + lhs: _, + op: _, + rhs: _, + typ, + } => typ.clone(), + Expr::UnaryExpression { op: _, inner } => inner.ty(), // XXX: problems will arise here + Expr::Identifier { name: _, typ } => typ.clone(), + Expr::Call(call) => call.typ.clone(), + Expr::Block(block) => block.typ.clone(), + Expr::IfExpr { + cond: _, + then_body: _, + else_body: _, + typ, + } => typ.clone(), + Expr::UnitLiteral => Type::Unit, + Expr::BooleanLiteral(_) => Type::Bool, + Expr::IntegerLiteral(_) => Type::Int, + Expr::FloatLiteral(_) => Type::Float, + Expr::StringLiteral(_) => Type::Str, + } + } } - -#[derive(Debug, PartialEq)] -pub struct Call { - pub callee: Box, - pub args: Vec, - pub typ: Type, -} - diff --git a/src/ast/untyped/expr.rs b/src/ast/untyped/expr.rs deleted file mode 100644 index ff8d1f8..0000000 --- a/src/ast/untyped/expr.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::ast::{ - untyped::{Block, Call}, - Identifier, -}; - -use crate::ast::*; - -#[derive(Debug, PartialEq)] -pub enum Expr { - UnitLiteral, - BinaryExpression(Box, BinaryOperator, Box), - Identifier(Identifier), - Call(Box), - // Literals - BooleanLiteral(bool), - IntegerLiteral(i64), - FloatLiteral(f64), - StringLiteral(String), - Block(Box), - /// Last field is either Expr::Block or Expr::IfExpr - IfExpr(Box, Box, Box), -} diff --git a/src/ast/untyped/mod.rs b/src/ast/untyped/mod.rs deleted file mode 100644 index 9138679..0000000 --- a/src/ast/untyped/mod.rs +++ /dev/null @@ -1,61 +0,0 @@ -pub mod expr; -pub mod module; - -use std::path::Path; - -pub use crate::ast::untyped::expr::Expr; -pub use crate::ast::*; -// TODO: remove all usage of 'Type' in the untyped ast -// (for now it is assumed that anything that parses -// is a Type, but the checking should be done in the typing -// phase) -use crate::typing::Type; - -#[derive(Debug, PartialEq)] -pub enum Definition { - FunctionDefinition(FunctionDefinition), - //StructDefinition(StructDefinition), -} - -#[derive(Debug, PartialEq)] -pub struct Location { - pub file: Box, -} - -#[derive(Debug, PartialEq)] -pub struct FunctionDefinition { - pub name: Identifier, - pub parameters: Vec, - pub return_type: Option, - pub body: Box, - pub line_col: (usize, usize), -} - -#[derive(Debug, PartialEq)] -pub struct Block { - pub statements: Vec, - pub value: Option, -} - -#[derive(Debug, PartialEq)] -pub enum Statement { - DeclareStatement(Identifier, Expr), - AssignStatement(Identifier, Expr), - ReturnStatement(Option), - CallStatement(Call), - UseStatement(Import), - IfStatement(Expr, Block), - WhileStatement(Box, Box), -} - -#[derive(Debug, PartialEq)] -pub struct Call { - pub callee: Box, - pub args: Vec, -} - -#[derive(Debug, PartialEq)] -pub struct Parameter { - pub name: Identifier, - pub typ: Type, -} diff --git a/src/ast/untyped/module.rs b/src/ast/untyped/module.rs deleted file mode 100644 index bf4b54c..0000000 --- a/src/ast/untyped/module.rs +++ /dev/null @@ -1,20 +0,0 @@ -use super::{Definition, ModulePath, Import}; - -#[derive(Debug, PartialEq)] -pub struct Module { - pub file: Option, - pub path: ModulePath, - pub definitions: Vec, - pub imports: Vec, -} - -impl Module { - pub fn new(path: ModulePath) -> Self { - Module { - path, - file: None, - definitions: vec![], - imports: vec![], - } - } -} diff --git a/src/jit/mod.rs b/src/jit/mod.rs new file mode 100644 index 0000000..0a9a9b4 --- /dev/null +++ b/src/jit/mod.rs @@ -0,0 +1,465 @@ +use crate::{ + ast::{ + self, BinaryOperator, ModulePath, UnaryOperator, + {expr::Expr, FunctionDefinition, Statement}, + }, + parsing, + typing::Type, +}; +use cranelift::{codegen::ir::UserFuncName, prelude::*}; +use cranelift_jit::{JITBuilder, JITModule}; +use cranelift_module::{DataDescription, FuncId, FuncOrDataId, Linkage, Module}; +use std::{collections::HashMap, ops::Deref}; + +/// The basic JIT class. +pub struct JIT { + /// The function builder context, which is reused across multiple + /// FunctionBuilder instances. + builder_context: FunctionBuilderContext, + + /// The main Cranelift context, which holds the state for codegen. Cranelift + /// separates this from `Module` to allow for parallel compilation, with a + /// context per thread, though this isn't in the simple demo here. + ctx: codegen::Context, + + /// The data description, which is to data objects what `ctx` is to functions. + data_desc: DataDescription, + + /// The module, with the jit backend, which manages the JIT'd functions. + module: JITModule, +} + +impl Default for JIT { + fn default() -> Self { + let mut flag_builder = codegen::settings::builder(); + flag_builder.set("use_colocated_libcalls", "false").unwrap(); + flag_builder.set("is_pic", "false").unwrap(); + + let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| { + panic!("host machine is not supported: {}", msg); + }); + + let isa = isa_builder + .finish(settings::Flags::new(flag_builder)) + .unwrap(); + + let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names()); + + let module = JITModule::new(builder); + + Self { + builder_context: FunctionBuilderContext::new(), + ctx: module.make_context(), + data_desc: DataDescription::new(), + module, + } + } +} + +impl JIT { + /// Compile source code into machine code. + pub fn compile(&mut self, input: &str, dump_clir: bool) -> Result<*const u8, String> { + // Parse the source code into an AST + let Ok(mut ast) = parsing::parse_as_module(input, ModulePath::from("globalmodule")) else { + return Err("Parsing error".to_string()); + }; + + if let Err(e) = ast.type_check() { + return Err(e.to_string()); + }; + + // Translate the AST into Cranelift IR + self.translate(&ast, dump_clir)?; + + // Finalize the functions which we just defined, which resolves any + // outstanding relocations (patching in addresses, now that they're + // available). + self.module.finalize_definitions().unwrap(); + + // We can now retrieve a pointer to the machine code. + if let Some(FuncOrDataId::Func(main_id)) = self.module.get_name("main") { + let code = self.module.get_finalized_function(main_id); + Ok(code) + } else { + Err("no main function".into()) + } + } + + /// Translate language AST into Cranelift IR. + fn translate(&mut self, ast: &ast::Module, dump_clir: bool) -> Result<(), String> { + let mut signatures: Vec = Vec::with_capacity(ast.functions.len()); + let mut func_ids: Vec = Vec::with_capacity(ast.functions.len()); + + // Declare functions + for func in &ast.functions { + // Create the signature + let mut sig = self.module.make_signature(); + + for param in &func.parameters { + assert!(param.typ != Type::Unit); + sig.params.append(&mut Vec::from(¶m.typ)); + } + + if let Some(return_type) = &func.return_type { + if *return_type != Type::Unit { + sig.returns = return_type.into(); + } + }; + + let id: FuncId = self + .module + .declare_function(&func.name, Linkage::Export, &sig) + .map_err(|e| e.to_string())?; + + signatures.push(sig); + func_ids.push(id); + } + + // Translate functions + for (i, func) in ast.functions.iter().enumerate() { + self.ctx.func.signature = signatures[i].clone(); + self.ctx.func.name = UserFuncName::user(0, func_ids[i].as_u32()); + self.translate_function(func)?; + self.module + .define_function(func_ids[i], &mut self.ctx) + .unwrap(); + + if dump_clir { + println!("// {}", func.name); + println!("{}", self.ctx.func.display()); + } + + self.module.clear_context(&mut self.ctx); + } + + Ok(()) + } + + fn translate_function(&mut self, function: &FunctionDefinition) -> Result<(), String> { + // Create the builder to build a function. + let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context); + + // Create the entry block, to start emitting code in. + let entry_block = builder.create_block(); + + // Since this is the entry block, add block parameters corresponding to + // the function's parameters. + builder.append_block_params_for_function_params(entry_block); + + // Tell the builder to emit code in this block. + builder.switch_to_block(entry_block); + + // And, tell the builder that this block will have no further + // predecessors. Since it's the entry block, it won't have any + // predecessors. + builder.seal_block(entry_block); + + // Walk the AST and declare all implicitly-declared variables. + let mut variables = HashMap::::default(); // TODO: actually do this + + // Add a variable for each parameter. + let param_values: Box<[Value]> = builder.block_params(entry_block).into(); + assert!(param_values.len() == function.parameters.len()); + for (i, param) in function.parameters.iter().enumerate() { + let var = Variable::from_u32(variables.len() as u32); + variables.insert(param.name.clone(), var); + let value = param_values[i]; + builder.declare_var(var, param.typ.clone().into()); + builder.def_var(var, value); + } + + // Now translate the statements of the function body. + let mut translator = FunctionTranslator { + builder, + variables, + module: &mut self.module, + }; + for stmt in &function.body.statements { + translator.translate_statement(stmt); + } + + // Emit the final return instruction. + if let Some(return_expr) = &function.body.value { + let return_value = translator.translate_expr(&return_expr); + translator.builder.ins().return_(&[return_value]); + } else { + translator.builder.ins().return_(&[]); + } + + // Tell the builder we're done with this function. + translator.builder.finalize(); + Ok(()) + } +} + +impl From for types::Type { + fn from(value: crate::typing::Type) -> Self { + match value { + Type::Bool => types::I8, + Type::Int => types::I32, + Type::Float => types::F32, + Type::Unit => unreachable!(), + Type::Str => todo!(), + Type::Custom(_) => todo!(), + Type::Function { + params: _, + returns: _, + } => todo!(), + Type::Undefined => unreachable!(), + } + } +} + +impl From<&Type> for Vec { + fn from(value: &Type) -> Self { + match value { + Type::Bool => vec![AbiParam::new(types::I8)], + Type::Int => vec![AbiParam::new(types::I32)], + Type::Float => vec![AbiParam::new(types::F32)], + _ => unimplemented!(), + } + } +} + +/// A collection of state used for translating from AST nodes +/// into Cranelift IR. +struct FunctionTranslator<'a> { + builder: FunctionBuilder<'a>, + variables: HashMap, + module: &'a mut JITModule, +} + +impl<'a> FunctionTranslator<'a> { + fn translate_statement(&mut self, stmt: &Statement) -> Option { + match stmt { + Statement::AssignStatement(name, expr) => { + // `def_var` is used to write the value of a variable. Note that + // variables can have multiple definitions. Cranelift will + // convert them into SSA form for itself automatically. + let new_value = self.translate_expr(expr); + let variable = self.variables.get(name).unwrap(); + self.builder.def_var(*variable, new_value); + Some(new_value) + } + + Statement::DeclareStatement(name, expr) => { + let value = self.translate_expr(expr); + let variable = Variable::from_u32(self.variables.len() as u32); + self.builder.declare_var(variable, expr.ty().into()); + self.builder.def_var(variable, value); + self.variables.insert(name.clone(), variable); + Some(value) + } + + Statement::ReturnStatement(maybe_expr) => { + // TODO: investigate tail call + let values = if let Some(expr) = maybe_expr { + vec![self.translate_expr(expr)] + } else { + // XXX: urgh + Vec::with_capacity(0) + }; + + // XXX: Should we pass multiple values ? + self.builder.ins().return_(&values); + + None + } + + Statement::CallStatement(call) => self.translate_call(call), + + Statement::UseStatement(_) => todo!(), + + Statement::IfStatement(cond, then_body) => { + let condition_value = self.translate_expr(cond); + + let then_block = self.builder.create_block(); + let merge_block = self.builder.create_block(); + + self.builder + .ins() + .brif(condition_value, then_block, &[], merge_block, &[]); + + self.builder.switch_to_block(then_block); + self.builder.seal_block(then_block); + self.translate_block(then_body); + self.builder.ins().jump(merge_block, &[]); + + self.builder.switch_to_block(merge_block); + self.builder.seal_block(merge_block); + + None + } + + Statement::WhileStatement(_, _) => todo!(), + } + } + + fn translate_expr(&mut self, expr: &Expr) -> Value { + match expr { + Expr::UnitLiteral => unreachable!(), + + Expr::BooleanLiteral(imm) => self.builder.ins().iconst(types::I8, i64::from(*imm)), + Expr::IntegerLiteral(imm) => self.builder.ins().iconst(types::I32, *imm), + Expr::FloatLiteral(imm) => self.builder.ins().f64const(*imm), + + Expr::StringLiteral(_) => todo!(), + + Expr::BinaryExpression { + lhs, + op, + rhs, + typ: _, + } => { + let lhs_value = self.translate_expr(lhs); + let rhs_value = self.translate_expr(rhs); + + match (lhs.ty(), lhs.ty()) { + (Type::Int, Type::Int) => match op { + BinaryOperator::Add => self.builder.ins().iadd(lhs_value, rhs_value), + BinaryOperator::Sub => self.builder.ins().isub(lhs_value, rhs_value), + BinaryOperator::Mul => self.builder.ins().imul(lhs_value, rhs_value), + // TODO: investigate division (case rhs <= 0) + BinaryOperator::Div => self.builder.ins().udiv(lhs_value, rhs_value), + BinaryOperator::Modulo => todo!(), + + BinaryOperator::Equal => { + self.builder.ins().icmp(IntCC::Equal, lhs_value, rhs_value) + } + BinaryOperator::NotEqual => { + self.builder + .ins() + .icmp(IntCC::NotEqual, lhs_value, rhs_value) + } + _ => unreachable!(), + }, + (Type::Bool, Type::Bool) => match op { + // XXX: Is min and max ok or should it be something else? + BinaryOperator::And => self.builder.ins().umin(lhs_value, rhs_value), + BinaryOperator::Or => self.builder.ins().umax(lhs_value, rhs_value), + _ => unreachable!(), + }, + _ => unimplemented!(), + } + } + + Expr::IfExpr { + cond, + then_body, + else_body, + typ, + } => { + let condition_value = self.translate_expr(cond); + + let then_block = self.builder.create_block(); + let else_block = self.builder.create_block(); + let merge_block = self.builder.create_block(); + + // If-else constructs in the language have a return value. + // In traditional SSA form, this would produce a PHI between + // the then and else bodies. Cranelift uses block parameters, + // so set up a parameter in the merge block, and we'll pass + // the return values to it from the branches. + self.builder + .append_block_param(merge_block, typ.clone().into()); + + // Test the if condition and conditionally branch. + self.builder + .ins() + .brif(condition_value, then_block, &[], else_block, &[]); + + self.builder.switch_to_block(then_block); + self.builder.seal_block(then_block); + for stmt in &then_body.statements { + self.translate_statement(&stmt); + } + + let then_return_value = match &then_body.value { + Some(val) => vec![self.translate_expr(val)], + None => Vec::with_capacity(0), + }; + + // Jump to the merge block, passing it the block return value. + self.builder.ins().jump(merge_block, &then_return_value); + + self.builder.switch_to_block(else_block); + self.builder.seal_block(else_block); + // XXX: the else can be just an expression: do we always need to + // make a second branch in that case? Or leave it to cranelift? + let else_return_value = match **else_body { + Expr::UnitLiteral => Vec::with_capacity(0), + _ => vec![self.translate_expr(else_body)], + }; + + // Jump to the merge block, passing it the block return value. + self.builder.ins().jump(merge_block, &else_return_value); + + // Switch to the merge block for subsequent statements. + self.builder.switch_to_block(merge_block); + + // We've now seen all the predecessors of the merge block. + self.builder.seal_block(merge_block); + + // Read the value of the if-else by reading the merge block + // parameter. + let phi = self.builder.block_params(merge_block)[0]; + + phi + } + + Expr::UnaryExpression { op, inner } => { + let inner_value = self.translate_expr(inner); + match op { + // XXX: This should not be a literal translation + UnaryOperator::Not => { + let one = self.builder.ins().iconst(types::I8, 1); + self.builder.ins().isub(one, inner_value) + } + } + } + + Expr::Identifier { name, typ: _ } => { + self.builder.use_var(*self.variables.get(name).unwrap()) + } + + Expr::Call(call) => self.translate_call(call).unwrap(), + + Expr::Block(block) => self.translate_block(block).unwrap(), + } + } + + fn translate_block(&mut self, block: &ast::Block) -> Option { + for stmt in &block.statements { + self.translate_statement(stmt); + } + if let Some(block_value) = &block.value { + Some(self.translate_expr(block_value)) + } else { + None + } + } + + fn translate_call(&mut self, call: &ast::Call) -> Option { + match call.callee.deref() { + Expr::Identifier { name, typ: _ } => { + let func_ref = if let Some(func_or_data_id) = self.module.get_name(name.as_ref()) { + if let FuncOrDataId::Func(func_id) = func_or_data_id { + self.module.declare_func_in_func(func_id, self.builder.func) + } else { + panic!() + } + } else { + todo!() + }; + + let args: Vec = call.args.iter().map(|a| self.translate_expr(a)).collect(); + + // TODO: handle the return value of the function + let call_inst = self.builder.ins().call(func_ref, &args); + let results = self.builder.inst_results(call_inst); + Some(results[0]) + } + _ => unimplemented!(), + } + } +} diff --git a/src/lib.rs b/src/lib.rs deleted file mode 100644 index 89cfa94..0000000 --- a/src/lib.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod ast; -mod typing; -mod jit; -mod parsing; diff --git a/src/main.rs b/src/main.rs index bfe1518..ad868fc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,11 @@ -mod ast; -mod parsing; -mod typing; +pub mod ast; +pub mod jit; +pub mod parsing; +pub mod typing; use clap::{Parser, Subcommand}; -use crate::ast::untyped::module::Module; +use crate::ast::Module; /// Experimental compiler for lila #[derive(Parser, Debug)] @@ -18,53 +19,80 @@ struct Cli { #[derive(Subcommand, Debug)] enum Commands { Parse { - /// Path to the source files + /// Paths to the source files files: Vec, /// Dump the AST to stdout #[arg(long)] dump_ast: bool, - - /// Add missing return types in the AST - #[arg(long)] - type_check: bool, }, + TypeCheck { + /// Paths to the source files + files: Vec, + + /// Dump the AST to stdout + #[arg(long)] + dump_ast: bool, + }, + Compile { + /// Paths to the source files + files: Vec, + + /// Dump the CLIR to stdout + #[arg(long)] + dump_clir: bool, + }, +} + +fn parse(files: &Vec) -> Vec { + let paths = files.iter().map(std::path::Path::new); + paths + .map(|path| match parsing::parse_file(&path) { + Ok(module) => module, + Err(e) => panic!("Parsing error: {:#?}", e), + }) + .collect() +} + +fn check(modules: &mut Vec) { + for module in modules { + if let Err(e) = module.type_check() { + eprintln!("{}", e); + return; + } + } } fn main() { let cli = Cli::parse(); match &cli.command { - Commands::Parse { - files, - dump_ast, - type_check, - } => { - let paths = files.iter().map(std::path::Path::new); - let modules: Vec = paths - .map(|path| match parsing::parse_file(&path) { - Ok(module) => module, - Err(e) => panic!("Parsing error: {:#?}", e), - }) - .collect(); - - if *type_check { - for module in &modules { - if let Err(e) = module.type_check() { - eprintln!("{}", e); - return; - } - } - } - + Commands::Parse { files, dump_ast } => { + let modules = parse(files); if *dump_ast { for module in &modules { println!("{:#?}", &module); } - return; } - println!("Parsing OK"); } + Commands::TypeCheck { files, dump_ast } => { + let mut modules = parse(files); + check(&mut modules); + if *dump_ast { + for module in &modules { + println!("{:#?}", &module); + } + } + } + Commands::Compile { files, dump_clir } => { + let mut jit = jit::JIT::default(); + for file in files { + match jit.compile(std::fs::read_to_string(file).unwrap().as_str(), *dump_clir) { + Err(e) => eprintln!("{}", e), + Ok(_code) => println!("Compiled {}", file), + } + } + } } } diff --git a/src/parsing/backend/pest/grammar.pest b/src/parsing/backend/pest/grammar.pest index 1a6aeec..e890180 100644 --- a/src/parsing/backend/pest/grammar.pest +++ b/src/parsing/backend/pest/grammar.pest @@ -32,14 +32,20 @@ parameters = { parameter = { ident ~ ":" ~ typ } // Operators -infix = _{ add | subtract | multiply | divide | not_equal | equal | modulo } -add = { "+" } -subtract = { "-" } -multiply = { "*" } -divide = { "/" } -modulo = { "%" } -equal = { "==" } -not_equal = { "!=" } +infix = _{ arithmetic_operator | logical_operator } + +arithmetic_operator = _{ add | subtract | multiply | divide | not_equal | equal | modulo } +add = { "+" } +subtract = { "-" } +multiply = { "*" } +divide = { "/" } +modulo = { "%" } +equal = { "==" } +not_equal = { "!=" } + +logical_operator = _{ and | or } +and = { "&&" } +or = { "||" } prefix = _{ not } not = { "!" } @@ -49,6 +55,7 @@ expr = { prefix? ~ atom ~ (infix ~ prefix? ~ atom)* } atom = _{ call | if_expr | block | literal | ident | "(" ~ expr ~ ")" } block = { "{" ~ statement* ~ expr? ~ "}" } if_expr = { "if" ~ expr ~ block ~ "else" ~ (block | if_expr) } +//tuple = { "(" ~ (expr ~ ",")+ ~ expr ~ ")" } ident = @{ (ASCII_ALPHANUMERIC | "_")+ } typ = _{ ident } diff --git a/src/parsing/backend/pest/mod.rs b/src/parsing/backend/pest/mod.rs index af66de3..5bed115 100644 --- a/src/parsing/backend/pest/mod.rs +++ b/src/parsing/backend/pest/mod.rs @@ -6,8 +6,8 @@ use pest::iterators::Pair; use pest::pratt_parser::PrattParser; use pest::Parser; -use crate::ast::untyped::module::Module; -use crate::ast::untyped::*; +use crate::ast::Module; +use crate::ast::*; use crate::ast::{Import, ModulePath}; use crate::typing::Type; @@ -23,6 +23,9 @@ lazy_static::lazy_static! { // Precedence is defined lowest to highest PrattParser::new() + .op(Op::infix(and, Left)) + .op(Op::infix(or, Left)) + .op(Op::prefix(not)) .op(Op::infix(equal, Left) | Op::infix(not_equal, Left)) .op(Op::infix(add, Left) | Op::infix(subtract, Left)) .op(Op::infix(modulo, Left)) @@ -57,7 +60,9 @@ pub fn parse_module(pair: Pair, path: ModulePath) -> Module { match pair.as_rule() { Rule::definition => { let def = parse_definition(pair.into_inner().next().unwrap()); - module.definitions.push(def); + match def { + Definition::FunctionDefinition(func) => module.functions.push(func), + } } Rule::use_statement => { let path = parse_import(pair.into_inner().next().unwrap()); @@ -82,7 +87,11 @@ fn parse_block(pair: Pair) -> Block { } } - Block { statements, value } + Block { + statements, + value, + typ: Type::Undefined, + } } fn parse_statement(pair: Pair) -> Statement { @@ -92,13 +101,13 @@ fn parse_statement(pair: Pair) -> Statement { let mut pairs = pair.into_inner(); let identifier = pairs.next().unwrap().as_str().to_string(); let expr = parse_expression(pairs.next().unwrap()); - Statement::AssignStatement(identifier, expr) + Statement::AssignStatement(identifier, Box::new(expr)) } Rule::declare_statement => { let mut pairs = pair.into_inner(); let identifier = pairs.next().unwrap().as_str().to_string(); let expr = parse_expression(pairs.next().unwrap()); - Statement::DeclareStatement(identifier, expr) + Statement::DeclareStatement(identifier, Box::new(expr)) } Rule::return_statement => { let expr = if let Some(pair) = pair.into_inner().next() { @@ -110,17 +119,20 @@ fn parse_statement(pair: Pair) -> Statement { } Rule::call_statement => { let call = parse_call(pair.into_inner().next().unwrap()); - Statement::CallStatement(call) + Statement::CallStatement(Box::new(call)) } Rule::use_statement => { let import = parse_import(pair.into_inner().next().unwrap()); - Statement::UseStatement(import) + Statement::UseStatement(Box::new(import)) } Rule::if_statement => { let mut pairs = pair.into_inner(); let condition = parse_expression(pairs.next().unwrap()); let block = parse_block(pairs.next().unwrap()); - Statement::IfStatement(condition, block) + if pairs.next().is_some() { + todo!("implement if-statements with else branch (and else if)") + } + Statement::IfStatement(Box::new(condition), Box::new(block)) } Rule::while_statement => { let mut pairs = pair.into_inner(); @@ -132,8 +144,6 @@ fn parse_statement(pair: Pair) -> Statement { } } -type ImportPath = ModulePath; - fn parse_import(pair: Pair) -> Import { Import(pair.as_str().to_string()) } @@ -141,16 +151,21 @@ fn parse_import(pair: Pair) -> Import { fn parse_call(pair: Pair) -> Call { let mut pairs = pair.into_inner(); // TODO: support calls on more than identifiers (needs grammar change) - let callee = Expr::Identifier(pairs.next().unwrap().as_str().to_string()); + let callee = Expr::Identifier { + name: pairs.next().unwrap().as_str().to_string(), + typ: Type::Undefined, + }; let args: Vec = pairs .next() .unwrap() .into_inner() .map(parse_expression) .collect(); + Call { callee: Box::new(callee), args, + typ: Type::Undefined, } } @@ -170,7 +185,10 @@ fn parse_expression(pair: Pair) -> Expr { .unwrap(), ), Rule::expr => parse_expression(primary), - Rule::ident => Expr::Identifier(primary.as_str().to_string()), + Rule::ident => Expr::Identifier { + name: primary.as_str().to_string(), + typ: Type::Undefined, + }, Rule::call => Expr::Call(Box::new(parse_call(primary))), Rule::block => Expr::Block(Box::new(parse_block(primary))), Rule::if_expr => { @@ -178,11 +196,12 @@ fn parse_expression(pair: Pair) -> Expr { let condition = parse_expression(pairs.next().unwrap()); let true_block = parse_block(pairs.next().unwrap()); let else_value = parse_expression(pairs.next().unwrap()); - Expr::IfExpr( - Box::new(condition), - Box::new(true_block), - Box::new(else_value), - ) + Expr::IfExpr { + cond: Box::new(condition), + then_body: Box::new(true_block), + else_body: Box::new(else_value), + typ: Type::Undefined, + } } Rule::boolean_literal => Expr::BooleanLiteral(match primary.as_str() { "true" => true, @@ -203,9 +222,26 @@ fn parse_expression(pair: Pair) -> Expr { Rule::modulo => BinaryOperator::Modulo, Rule::equal => BinaryOperator::Equal, Rule::not_equal => BinaryOperator::NotEqual, + Rule::and => BinaryOperator::And, + Rule::or => BinaryOperator::Or, _ => unreachable!(), }; - Expr::BinaryExpression(Box::new(lhs), operator, Box::new(rhs)) + Expr::BinaryExpression { + lhs: Box::new(lhs), + op: operator, + rhs: Box::new(rhs), + typ: Type::Undefined, + } + }) + .map_prefix(|op, inner| { + let operator = match op.as_rule() { + Rule::not => UnaryOperator::Not, + _ => unreachable!(), + }; + Expr::UnaryExpression { + op: operator, + inner: Box::new(inner), + } }) .parse(pairs) } @@ -247,7 +283,7 @@ fn parse_definition(pair: Pair) -> Definition { parameters, return_type, body, - line_col, + location: Location { line_col }, }) } _ => panic!("unexpected node for definition: {:?}", pair.as_rule()), diff --git a/src/parsing/tests.rs b/src/parsing/tests.rs index 7bac09b..634de2a 100644 --- a/src/parsing/tests.rs +++ b/src/parsing/tests.rs @@ -1,12 +1,8 @@ #[test] fn test_addition_function() { + use crate::ast::{expr::Expr, *}; use crate::parsing::backend::pest::parse_as_module; - use crate::{ - ast::untyped::module::Module, - ast::untyped::*, - ast::ModulePath, - typing::Type, - }; + use crate::typing::Type; let source = "fn add(a: int, b: int) int { a + b }"; let path = ModulePath::from("test"); @@ -15,7 +11,7 @@ fn test_addition_function() { let expected_module = Module { file: None, imports: vec![], - definitions: vec![Definition::FunctionDefinition(FunctionDefinition { + functions: vec![FunctionDefinition { name: Identifier::from("add"), parameters: vec![ Parameter { @@ -30,14 +26,22 @@ fn test_addition_function() { return_type: Some(Type::Int), body: Box::new(Block { statements: vec![], - value: Some(Expr::BinaryExpression( - Box::new(Expr::Identifier(Identifier::from("a"))), - BinaryOperator::Add, - Box::new(Expr::Identifier(Identifier::from("b"))), - )), + value: Some(Expr::BinaryExpression { + lhs: Box::new(Expr::Identifier { + name: Identifier::from("a"), + typ: Type::Undefined, + }), + op: BinaryOperator::Add, + rhs: Box::new(Expr::Identifier { + name: Identifier::from("b"), + typ: Type::Undefined, + }), + typ: Type::Undefined, + }), + typ: Type::Undefined, }), - line_col: (1, 1), - })], + location: Location { line_col: (1, 1) }, + }], path, }; diff --git a/src/typing/error.rs b/src/typing/error.rs index 04b98ae..13026c3 100644 --- a/src/typing/error.rs +++ b/src/typing/error.rs @@ -1,4 +1,6 @@ -use crate::typing::{BinaryOperator, Identifier, ModulePath, Type, TypeContext}; +use crate::typing::{BinaryOperator, Identifier, ModulePath, Type, TypingContext}; + +use super::UnaryOperator; #[derive(Debug)] pub struct TypeError { @@ -38,7 +40,7 @@ impl TypeError { } impl TypeErrorBuilder { - pub fn context(mut self, ctx: &TypeContext) -> Self { + pub fn context(mut self, ctx: &TypingContext) -> Self { self.file = ctx.file.clone(); self.module = Some(ctx.module.clone()); self.function = ctx.function.clone(); @@ -89,4 +91,8 @@ pub enum TypeErrorKind { WrongFunctionArguments, ConditionIsNotBool, IfElseMismatch, + InvalidUnaryOperator { + operator: UnaryOperator, + inner: Type, + }, } diff --git a/src/typing/mod.rs b/src/typing/mod.rs index 2c442a1..15785a0 100644 --- a/src/typing/mod.rs +++ b/src/typing/mod.rs @@ -1,47 +1,88 @@ use std::collections::HashMap; +use std::fmt::Display; -use crate::ast::untyped::*; -use crate::ast::untyped::module::Module; use crate::ast::ModulePath; +use crate::ast::*; mod error; use crate::typing::error::{TypeError, TypeErrorKind}; #[derive(Debug, PartialEq, Clone)] pub enum Type { + /// Not a real type, used for parsing pass + Undefined, Bool, Int, Float, Unit, Str, + Function { + params: Vec, + returns: Box, + }, Custom(Identifier), } +impl Display for Type { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Type::Undefined => f.write_str("UNDEFINED"), + Type::Bool => f.write_str("Bool"), + Type::Int => f.write_str("Int"), + Type::Float => f.write_str("Float"), + Type::Unit => f.write_str("Unit"), + Type::Str => f.write_str("Str"), + Type::Custom(identifier) => f.write_str(identifier), + Type::Function { params, returns } => { + f.write_str("Fn(")?; + for param in params { + f.write_fmt(format_args!("{}, ", param))?; + } + f.write_str(") -> ")?; + f.write_fmt(format_args!("{}", returns)) + } + } + } +} + impl From<&str> for Type { fn from(value: &str) -> Self { match value { "int" => Type::Int, "float" => Type::Float, + "bool" => Type::Bool, _ => Type::Custom(Identifier::from(value)), } } } -impl untyped::FunctionDefinition { - fn signature(&self) -> (Vec, Type) { +#[derive(Debug, PartialEq, Clone)] +pub struct Signature(Vec, Type); + +impl Into for Signature { + fn into(self) -> Type { + Type::Function { + params: self.0, + returns: Box::new(self.1), + } + } +} + +impl FunctionDefinition { + fn signature(&self) -> Signature { let return_type = self.return_type.clone().unwrap_or(Type::Unit); let params_types = self.parameters.iter().map(|p| p.typ.clone()).collect(); - (params_types, return_type) + Signature(params_types, return_type) } } impl Module { - pub fn type_check(&self) -> Result<(), TypeError> { - let mut ctx = TypeContext::new(self.path.clone()); + pub fn type_check(&mut self) -> Result<(), TypeError> { + let mut ctx = TypingContext::new(self.path.clone()); ctx.file = self.file.clone(); // Register all function signatures - for Definition::FunctionDefinition(func) in &self.definitions { + for func in &self.functions { if let Some(_previous) = ctx.functions.insert(func.name.clone(), func.signature()) { todo!("handle redefinition of function or identical function names across different files"); } @@ -49,8 +90,8 @@ impl Module { // TODO: add signatures of imported functions (even if they have not been checked) - // Type-check the function bodies - for Definition::FunctionDefinition(func) in &self.definitions { + // Type-check the function bodies and complete all type placeholders + for func in &mut self.functions { func.typ(&mut ctx)?; ctx.variables.clear(); } @@ -59,18 +100,20 @@ impl Module { } } -pub struct TypeContext { +pub struct TypingContext { pub file: Option, pub module: ModulePath, pub function: Option, - pub functions: HashMap, Type)>, + pub functions: HashMap, pub variables: HashMap, } -impl TypeContext { +impl TypingContext { pub fn new(path: ModulePath) -> Self { - let builtin_functions = - HashMap::from([(String::from("println"), (vec![Type::Str], Type::Unit))]); + let builtin_functions = HashMap::from([( + String::from("println"), + Signature(vec![Type::Str], Type::Unit), + )]); Self { file: None, @@ -84,70 +127,72 @@ impl TypeContext { /// Trait for nodes which have a deducible type. pub trait TypeCheck { - /// Try to resolve the type of the node. - fn typ(&self, ctx: &mut TypeContext) -> Result; + /// Try to resolve the type of the node and complete its type placeholders. + fn typ(&mut self, ctx: &mut TypingContext) -> Result; } impl TypeCheck for FunctionDefinition { - fn typ(&self, ctx: &mut TypeContext) -> Result { + fn typ(&mut self, ctx: &mut TypingContext) -> Result { ctx.function = Some(self.name.clone()); for param in &self.parameters { + // XXX: Parameter types should be checked + // when they are not builtin ctx.variables.insert(param.name.clone(), param.typ.clone()); } - let body_type = &self.body.typ(ctx)?; + let body_type = self.body.typ(ctx)?; // If the return type is not specified, it is unit. - let func_return_type = match &self.return_type { - Some(typ) => typ, - None => &Type::Unit, - }; + if self.return_type.is_none() { + self.return_type = Some(Type::Unit) + } // Check coherence with the body's type. - if *func_return_type != *body_type { + if *self.return_type.as_ref().unwrap() != body_type { return Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::BlockTypeDoesNotMatchFunctionType { block_type: body_type.clone(), - function_type: func_return_type.clone(), + function_type: self.return_type.as_ref().unwrap().clone(), }) .build()); } // Check coherence with return statements. - for statement in &self.body.statements { + + for statement in &mut self.body.statements { if let Statement::ReturnStatement(value) = statement { let ret_type = match value { Some(expr) => expr.typ(ctx)?, None => Type::Unit, }; - if ret_type != *func_return_type { + if ret_type != *self.return_type.as_ref().unwrap() { return Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::ReturnTypeDoesNotMatchFunctionType { - function_type: func_return_type.clone(), - return_type: ret_type, + function_type: self.return_type.as_ref().unwrap().clone(), + return_type: ret_type.clone(), }) .build()); } } } - Ok(func_return_type.clone()) + Ok(self.return_type.clone().unwrap()) } } impl TypeCheck for Block { - fn typ(&self, ctx: &mut TypeContext) -> Result { + fn typ(&mut self, ctx: &mut TypingContext) -> Result { let mut return_typ: Option = None; // Check declarations and assignments. - for statement in &self.statements { + for statement in &mut self.statements { match statement { Statement::DeclareStatement(ident, expr) => { let typ = expr.typ(ctx)?; - if let Some(_typ) = ctx.variables.insert(ident.clone(), typ) { + if let Some(_typ) = ctx.variables.insert(ident.clone(), typ.clone()) { // TODO: Shadowing? (illegal for now) return Err(TypeError::builder() .context(ctx) @@ -159,9 +204,9 @@ impl TypeCheck for Block { let rhs_typ = expr.typ(ctx)?; let Some(lhs_typ) = ctx.variables.get(ident) else { return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::AssignUndeclared) - .build()); + .context(ctx) + .kind(TypeErrorKind::AssignUndeclared) + .build()); }; // Ensure same type on both sides. @@ -189,7 +234,7 @@ impl TypeCheck for Block { .build()); } } else { - return_typ = Some(expr_typ); + return_typ = Some(expr_typ.clone()); } } Statement::CallStatement(call) => { @@ -220,9 +265,11 @@ impl TypeCheck for Block { } // Check if there is an expression at the end of the block. - if let Some(expr) = &self.value { - expr.typ(ctx) + if let Some(expr) = &mut self.value { + self.typ = expr.typ(ctx)?.clone(); + Ok(self.typ.clone()) } else { + self.typ = Type::Unit; Ok(Type::Unit) } @@ -234,29 +281,34 @@ impl TypeCheck for Block { } impl TypeCheck for Call { - fn typ(&self, ctx: &mut TypeContext) -> Result { - match &*self.callee { - Expr::Identifier(ident) => { - let signature = match ctx.functions.get(ident) { + fn typ(&mut self, ctx: &mut TypingContext) -> Result { + match &mut *self.callee { + Expr::Identifier { name, typ } => { + let signature = match ctx.functions.get(name) { Some(sgn) => sgn.clone(), None => { return Err(TypeError::builder() .context(ctx) - .kind(TypeErrorKind::UnknownFunctionCalled(ident.clone())) + .kind(TypeErrorKind::UnknownFunctionCalled(name.clone())) .build()) } }; - let (params_types, func_type) = signature; + + *typ = signature.clone().into(); + + let Signature(params_types, func_type) = signature; + + self.typ = func_type.clone(); // Collect arg types. let mut args_types: Vec = vec![]; - for arg in &self.args { - let typ = arg.typ(ctx)?; - args_types.push(typ.clone()); + for arg in &mut self.args { + let arg_typ = arg.typ(ctx)?; + args_types.push(arg_typ.clone()); } if args_types == *params_types { - Ok(func_type.clone()) + Ok(self.typ.clone()) } else { Err(TypeError::builder() .context(ctx) @@ -270,16 +322,17 @@ impl TypeCheck for Call { } impl TypeCheck for Expr { - fn typ(&self, ctx: &mut TypeContext) -> Result { + fn typ(&mut self, ctx: &mut TypingContext) -> Result { match self { - Expr::Identifier(identifier) => { - if let Some(typ) = ctx.variables.get(identifier) { + Expr::Identifier { name, typ } => { + if let Some(ty) = ctx.variables.get(name) { + *typ = ty.clone(); Ok(typ.clone()) } else { Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::UnknownIdentifier { - identifier: identifier.clone(), + identifier: name.clone(), }) .build()) } @@ -287,192 +340,107 @@ impl TypeCheck for Expr { Expr::BooleanLiteral(_) => Ok(Type::Bool), Expr::IntegerLiteral(_) => Ok(Type::Int), Expr::FloatLiteral(_) => Ok(Type::Float), - Expr::BinaryExpression(lhs, op, rhs) => match op { - BinaryOperator::Add - | BinaryOperator::Sub - | BinaryOperator::Mul - | BinaryOperator::Div => { - let left_type = &lhs.typ(ctx)?; - let right_type = &rhs.typ(ctx)?; - match (left_type, right_type) { - (Type::Int, Type::Int) => Ok(Type::Int), - (Type::Float, Type::Float) => Ok(Type::Float), - (_, _) => Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::InvalidBinaryOperator { - operator: op.clone(), - lht: left_type.clone(), - rht: right_type.clone(), - }) - .build()), - } + Expr::UnaryExpression { op, inner } => { + let inner_type = &inner.typ(ctx)?; + match (&op, inner_type) { + (UnaryOperator::Not, Type::Bool) => Ok(Type::Bool), + _ => Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::InvalidUnaryOperator { + operator: *op, + inner: inner_type.clone(), + }) + .build()), } - BinaryOperator::Equal | BinaryOperator::NotEqual => { - let lhs_type = lhs.typ(ctx)?; - let rhs_type = rhs.typ(ctx)?; - if lhs_type != rhs_type { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::InvalidBinaryOperator { - operator: op.clone(), - lht: lhs_type.clone(), - rht: rhs_type.clone(), - }) - .build()); + } + Expr::BinaryExpression { lhs, op, rhs, typ } => { + let ty = match op { + BinaryOperator::Add + | BinaryOperator::Sub + | BinaryOperator::Mul + | BinaryOperator::Div + | BinaryOperator::And + | BinaryOperator::Or => { + let left_type = &lhs.typ(ctx)?; + let right_type = &rhs.typ(ctx)?; + match (left_type, right_type) { + (Type::Int, Type::Int) => Ok(Type::Int), + (Type::Float, Type::Float) => Ok(Type::Float), + (Type::Bool, Type::Bool) => Ok(Type::Bool), + (_, _) => Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::InvalidBinaryOperator { + operator: op.clone(), + lht: left_type.clone(), + rht: right_type.clone(), + }) + .build()), + } } - Ok(Type::Bool) - } - BinaryOperator::Modulo => { - let lhs_type = lhs.typ(ctx)?; - let rhs_type = lhs.typ(ctx)?; - match (&lhs_type, &rhs_type) { - (Type::Int, Type::Int) => Ok(Type::Int), - _ => Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::InvalidBinaryOperator { - operator: op.clone(), - lht: lhs_type.clone(), - rht: rhs_type.clone(), - }) - .build()), + BinaryOperator::Equal | BinaryOperator::NotEqual => { + let lhs_type = lhs.typ(ctx)?; + let rhs_type = rhs.typ(ctx)?; + if lhs_type != rhs_type { + return Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::InvalidBinaryOperator { + operator: op.clone(), + lht: lhs_type.clone(), + rht: rhs_type.clone(), + }) + .build()); + } + Ok(Type::Bool) } - } - }, + BinaryOperator::Modulo => { + let lhs_type = lhs.typ(ctx)?; + let rhs_type = lhs.typ(ctx)?; + match (&lhs_type, &rhs_type) { + (Type::Int, Type::Int) => Ok(Type::Int), + _ => Err(TypeError::builder() + .context(ctx) + .kind(TypeErrorKind::InvalidBinaryOperator { + operator: op.clone(), + lht: lhs_type.clone(), + rht: rhs_type.clone(), + }) + .build()), + } + } + }; + *typ = ty?; + Ok(typ.clone()) + } Expr::StringLiteral(_) => Ok(Type::Str), Expr::UnitLiteral => Ok(Type::Unit), Expr::Call(call) => call.typ(ctx), Expr::Block(block) => block.typ(ctx), - Expr::IfExpr(cond, true_block, else_value) => { + Expr::IfExpr { + cond, + then_body, + else_body, + typ, + } => { if cond.typ(ctx)? != Type::Bool { Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::ConditionIsNotBool) .build()) } else { - let true_block_type = true_block.typ(ctx)?; - let else_type = else_value.typ(ctx)?; - if true_block_type != else_type { + let then_body_type = then_body.typ(ctx)?; + let else_type = else_body.typ(ctx)?; + if then_body_type != else_type { Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::IfElseMismatch) .build()) } else { - Ok(true_block_type.clone()) + // XXX: opt: return ref to avoid cloning + *typ = then_body_type.clone(); + Ok(then_body_type) } } } } } } - -struct Typed { - inner: T, - typ: Type, -} - -trait IntoTyped { - fn into_typed(self: Self, ctx: &mut TypeContext) -> Result, TypeError>; -} - -impl IntoTyped for Block { - fn into_typed(self: Block, ctx: &mut TypeContext) -> Result, TypeError> { - let mut return_typ: Option = None; - - // Check declarations and assignments. - for statement in &self.statements { - match statement { - Statement::DeclareStatement(ident, expr) => { - let typ = expr.typ(ctx)?; - if let Some(_typ) = ctx.variables.insert(ident.clone(), typ) { - // XXX: Shadowing? (illegal for now) - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::VariableRedeclaration) - .build()); - } - } - - Statement::AssignStatement(ident, expr) => { - let rhs_typ = expr.typ(ctx)?; - let Some(lhs_typ) = ctx.variables.get(ident) else { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::AssignUndeclared) - .build()); - }; - - // Ensure same type on both sides. - if rhs_typ != *lhs_typ { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::AssignmentMismatch { - lht: lhs_typ.clone(), - rht: rhs_typ.clone(), - }) - .build()); - } - } - - Statement::ReturnStatement(maybe_expr) => { - let expr_typ = if let Some(expr) = maybe_expr { - expr.typ(ctx)? - } else { - Type::Unit - }; - if let Some(typ) = &return_typ { - if expr_typ != *typ { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::ReturnStatementsMismatch) - .build()); - } - } else { - return_typ = Some(expr_typ); - } - } - - Statement::CallStatement(call) => { - call.typ(ctx)?; - } - - Statement::UseStatement(_path) => { - // TODO: import the signatures (and types) - } - - Statement::IfStatement(cond, block) => { - if cond.typ(ctx)? != Type::Bool { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::ConditionIsNotBool) - .build()); - } - block.typ(ctx)?; - } - - Statement::WhileStatement(cond, block) => { - if cond.typ(ctx)? != Type::Bool { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::ConditionIsNotBool) - .build()); - } - block.typ(ctx)?; - } - } - } - - // Check if there is an expression at the end of the block. - let typ = if let Some(expr) = &self.value { - expr.typ(ctx)? - } else { - Type::Unit - }; - - Ok(Typed { inner: self, typ }) - - // TODO/FIXME: find a way to return `return_typ` so that the - // top-level block (the function) can check if this return type - // (and eventually those from other block) matches the type of - // the function. - } -}