use std::collections::HashMap; use std::fmt::Display; use crate::ast::ModulePath; use crate::ast::*; mod error; use crate::typing::error::{TypeError, TypeErrorKind}; #[derive(Debug, PartialEq, Clone)] pub enum Type { /// Not a real type, used for parsing pass Undefined, Bool, Int, Float, Unit, Str, Function { params: Vec, returns: Box, }, Custom(Identifier), } impl Display for Type { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Type::Undefined => f.write_str("UNDEFINED"), Type::Bool => f.write_str("Bool"), Type::Int => f.write_str("Int"), Type::Float => f.write_str("Float"), Type::Unit => f.write_str("Unit"), Type::Str => f.write_str("Str"), Type::Custom(identifier) => f.write_str(identifier), Type::Function { params, returns } => { f.write_str("Fn(")?; for param in params { f.write_fmt(format_args!("{}, ", param))?; } f.write_str(") -> ")?; f.write_fmt(format_args!("{}", returns)) } } } } impl From<&str> for Type { fn from(value: &str) -> Self { match value { "int" => Type::Int, "float" => Type::Float, "bool" => Type::Bool, _ => Type::Custom(Identifier::from(value)), } } } #[derive(Debug, PartialEq, Clone)] pub struct Signature(Vec, Type); impl Into for Signature { fn into(self) -> Type { Type::Function { params: self.0, returns: Box::new(self.1), } } } impl FunctionDefinition { fn signature(&self) -> Signature { let return_type = self.return_type.clone().unwrap_or(Type::Unit); let params_types = self.parameters.iter().map(|p| p.typ.clone()).collect(); Signature(params_types, return_type) } } impl Module { pub fn type_check(&mut self) -> Result<(), TypeError> { let mut ctx = TypingContext::new(self.path.clone()); ctx.file = self.file.clone(); // Register all function signatures for func in &self.functions { if let Some(_previous) = ctx.functions.insert(func.name.clone(), func.signature()) { todo!("handle redefinition of function or identical function names across different files"); } } // TODO: add signatures of imported functions (even if they have not been checked) // Type-check the function bodies and complete all type placeholders for func in &mut self.functions { func.typ(&mut ctx)?; ctx.variables.clear(); } Ok(()) } } pub struct TypingContext { pub file: Option, pub module: ModulePath, pub function: Option, pub functions: HashMap, pub variables: HashMap, } impl TypingContext { pub fn new(path: ModulePath) -> Self { let builtin_functions = HashMap::from([( String::from("println"), Signature(vec![Type::Str], Type::Unit), )]); Self { file: None, module: path, function: None, functions: builtin_functions, variables: Default::default(), } } } /// Trait for nodes which have a deducible type. pub trait TypeCheck { /// Try to resolve the type of the node and complete its type placeholders. fn typ(&mut self, ctx: &mut TypingContext) -> Result; } impl TypeCheck for FunctionDefinition { fn typ(&mut self, ctx: &mut TypingContext) -> Result { ctx.function = Some(self.name.clone()); for param in &self.parameters { // XXX: Parameter types should be checked // when they are not builtin ctx.variables.insert(param.name.clone(), param.typ.clone()); } let body_type = self.body.typ(ctx)?; // If the return type is not specified, it is unit. if self.return_type.is_none() { self.return_type = Some(Type::Unit) } // Check coherence with the body's type. if *self.return_type.as_ref().unwrap() != body_type { return Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::BlockTypeDoesNotMatchFunctionType { block_type: body_type.clone(), function_type: self.return_type.as_ref().unwrap().clone(), }) .build()); } // Check coherence with return statements. for statement in &mut self.body.statements { if let Statement::ReturnStatement(value) = statement { let ret_type = match value { Some(expr) => expr.typ(ctx)?, None => Type::Unit, }; if ret_type != *self.return_type.as_ref().unwrap() { return Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::ReturnTypeDoesNotMatchFunctionType { function_type: self.return_type.as_ref().unwrap().clone(), return_type: ret_type.clone(), }) .build()); } } } Ok(self.return_type.clone().unwrap()) } } impl TypeCheck for Block { fn typ(&mut self, ctx: &mut TypingContext) -> Result { let mut return_typ: Option = None; // Check declarations and assignments. for statement in &mut self.statements { match statement { Statement::DeclareStatement(ident, expr) => { let typ = expr.typ(ctx)?; if let Some(_typ) = ctx.variables.insert(ident.clone(), typ.clone()) { // TODO: Shadowing? (illegal for now) return Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::VariableRedeclaration) .build()); } } Statement::AssignStatement(ident, expr) => { let rhs_typ = expr.typ(ctx)?; let Some(lhs_typ) = ctx.variables.get(ident) else { return Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::AssignUndeclared) .build()); }; // Ensure same type on both sides. if rhs_typ != *lhs_typ { return Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::AssignmentMismatch { lht: lhs_typ.clone(), rht: rhs_typ.clone(), }) .build()); } } Statement::ReturnStatement(maybe_expr) => { let expr_typ = if let Some(expr) = maybe_expr { expr.typ(ctx)? } else { Type::Unit }; if let Some(typ) = &return_typ { if expr_typ != *typ { return Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::ReturnStatementsMismatch) .build()); } } else { return_typ = Some(expr_typ.clone()); } } Statement::CallStatement(call) => { call.typ(ctx)?; } Statement::UseStatement(_path) => { // TODO: import the signatures (and types) } Statement::IfStatement(cond, block) => { if cond.typ(ctx)? != Type::Bool { return Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::ConditionIsNotBool) .build()); } block.typ(ctx)?; } Statement::WhileStatement(cond, block) => { if cond.typ(ctx)? != Type::Bool { return Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::ConditionIsNotBool) .build()); } block.typ(ctx)?; } } } // Check if there is an expression at the end of the block. if let Some(expr) = &mut self.value { self.typ = expr.typ(ctx)?.clone(); Ok(self.typ.clone()) } else { self.typ = Type::Unit; Ok(Type::Unit) } // TODO/FIXME: find a way to return `return_typ` so that the // top-level block (the function) can check if this return type // (and eventually those from other block) matches the type of // the function. } } impl TypeCheck for Call { fn typ(&mut self, ctx: &mut TypingContext) -> Result { match &mut *self.callee { Expr::Identifier { name, typ } => { let signature = match ctx.functions.get(name) { Some(sgn) => sgn.clone(), None => { return Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::UnknownFunctionCalled(name.clone())) .build()) } }; *typ = signature.clone().into(); let Signature(params_types, func_type) = signature; self.typ = func_type.clone(); // Collect arg types. let mut args_types: Vec = vec![]; for arg in &mut self.args { let arg_typ = arg.typ(ctx)?; args_types.push(arg_typ.clone()); } if args_types == *params_types { Ok(self.typ.clone()) } else { Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::WrongFunctionArguments) .build()) } } _ => unimplemented!("cannot call on expression other than identifier"), } } } impl TypeCheck for Expr { fn typ(&mut self, ctx: &mut TypingContext) -> Result { match self { Expr::Identifier { name, typ } => { if let Some(ty) = ctx.variables.get(name) { *typ = ty.clone(); Ok(typ.clone()) } else { Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::UnknownIdentifier { identifier: name.clone(), }) .build()) } } Expr::BooleanLiteral(_) => Ok(Type::Bool), Expr::IntegerLiteral(_) => Ok(Type::Int), Expr::FloatLiteral(_) => Ok(Type::Float), Expr::UnaryExpression { op, inner } => { let inner_type = &inner.typ(ctx)?; match (&op, inner_type) { (UnaryOperator::Not, Type::Bool) => Ok(Type::Bool), _ => Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::InvalidUnaryOperator { operator: *op, inner: inner_type.clone(), }) .build()), } } Expr::BinaryExpression { lhs, op, rhs, typ } => { let ty = match op { BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul | BinaryOperator::Div | BinaryOperator::And | BinaryOperator::Or => { let left_type = &lhs.typ(ctx)?; let right_type = &rhs.typ(ctx)?; match (left_type, right_type) { (Type::Int, Type::Int) => Ok(Type::Int), (Type::Float, Type::Float) => Ok(Type::Float), (Type::Bool, Type::Bool) => Ok(Type::Bool), (_, _) => Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::InvalidBinaryOperator { operator: op.clone(), lht: left_type.clone(), rht: right_type.clone(), }) .build()), } } BinaryOperator::Equal | BinaryOperator::NotEqual => { let lhs_type = lhs.typ(ctx)?; let rhs_type = rhs.typ(ctx)?; if lhs_type != rhs_type { return Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::InvalidBinaryOperator { operator: op.clone(), lht: lhs_type.clone(), rht: rhs_type.clone(), }) .build()); } Ok(Type::Bool) } BinaryOperator::Modulo => { let lhs_type = lhs.typ(ctx)?; let rhs_type = lhs.typ(ctx)?; match (&lhs_type, &rhs_type) { (Type::Int, Type::Int) => Ok(Type::Int), _ => Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::InvalidBinaryOperator { operator: op.clone(), lht: lhs_type.clone(), rht: rhs_type.clone(), }) .build()), } } }; *typ = ty?; Ok(typ.clone()) } Expr::StringLiteral(_) => Ok(Type::Str), Expr::UnitLiteral => Ok(Type::Unit), Expr::Call(call) => call.typ(ctx), Expr::Block(block) => block.typ(ctx), Expr::IfExpr { cond, then_body, else_body, typ, } => { if cond.typ(ctx)? != Type::Bool { Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::ConditionIsNotBool) .build()) } else { let then_body_type = then_body.typ(ctx)?; let else_type = else_body.typ(ctx)?; if then_body_type != else_type { Err(TypeError::builder() .context(ctx) .kind(TypeErrorKind::IfElseMismatch) .build()) } else { // XXX: opt: return ref to avoid cloning *typ = then_body_type.clone(); Ok(then_body_type) } } } } } }