From f415c4abbe4a1aec23fbf02d012166026031a437 Mon Sep 17 00:00:00 2001 From: Romain Paquet Date: Wed, 3 Jul 2024 19:59:12 +0200 Subject: [PATCH] add pretty diagnostics --- Cargo.toml | 13 +- src/ast/expr.rs | 53 +-- src/ast/mod.rs | 105 +++++- src/jit/mod.rs | 159 +++++---- src/main.rs | 48 ++- src/parsing/backend/pest/mod.rs | 580 +++++++++++++++++++------------- src/parsing/mod.rs | 16 +- src/parsing/tests.rs | 73 +++- src/source.rs | 22 ++ src/typing/error.rs | 205 +++++++---- src/typing/mod.rs | 323 +++++++++--------- src/typing/tests.rs | 43 ++- 12 files changed, 1037 insertions(+), 603 deletions(-) create mode 100644 src/source.rs diff --git a/Cargo.toml b/Cargo.toml index 9c62ddd..438ac82 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,10 +5,15 @@ edition = "2021" [dependencies] clap = { version = "4.5.7", features = ["derive"] } -cranelift = "0.108.1" -cranelift-jit = "0.108.1" -cranelift-module = "0.108.1" -cranelift-native = "0.108.1" +cranelift = "0.109.0" +cranelift-jit = "0.109.0" +cranelift-module = "0.109.0" +cranelift-native = "0.109.0" lazy_static = "1.4.0" pest = "2.7.4" pest_derive = "2.7.4" +ariadne = "0.4.1" +anyhow = "1.0.86" + +[dev-dependencies] +pretty_assertions = "1.4.0" diff --git a/src/ast/expr.rs b/src/ast/expr.rs index 90ffc48..71ec649 100644 --- a/src/ast/expr.rs +++ b/src/ast/expr.rs @@ -1,17 +1,27 @@ use crate::ast::*; use crate::typing::Type; +#[derive(Debug, PartialEq)] +pub struct SExpr { + pub expr: Expr, + pub span: Span, +} + +#[derive(Debug, PartialEq)] +pub struct BinaryExpression { + pub lhs: Box, + pub op: BinaryOperator, + pub op_span: Span, + pub rhs: Box, + pub typ: Type, +} + #[derive(Debug, PartialEq)] pub enum Expr { - BinaryExpression { - lhs: Box, - op: BinaryOperator, - rhs: Box, - typ: Type, - }, + BinaryExpression(BinaryExpression), UnaryExpression { op: UnaryOperator, - inner: Box, + inner: Box, }, Identifier { name: String, @@ -21,9 +31,9 @@ pub enum Expr { Block(Box), /// Last field is either Expr::Block or Expr::IfExpr IfExpr { - cond: Box, + cond: Box, then_body: Box, - else_body: Box, + else_body: Box, typ: Type, }, // Literals @@ -45,22 +55,12 @@ impl Block { impl Expr { pub fn ty(&self) -> Type { match self { - Expr::BinaryExpression { - lhs: _, - op: _, - rhs: _, - typ, - } => typ.clone(), - Expr::UnaryExpression { op: _, inner } => inner.ty(), // XXX: problems will arise here - Expr::Identifier { name: _, typ } => typ.clone(), + Expr::BinaryExpression(BinaryExpression { typ, .. }) => typ.clone(), + Expr::UnaryExpression { inner, .. } => inner.ty(), // XXX: problems will arise here + Expr::Identifier { typ, .. } => typ.clone(), Expr::Call(call) => call.typ.clone(), Expr::Block(block) => block.typ.clone(), - Expr::IfExpr { - cond: _, - then_body: _, - else_body: _, - typ, - } => typ.clone(), + Expr::IfExpr { typ, .. } => typ.clone(), Expr::UnitLiteral => Type::Unit, Expr::BooleanLiteral(_) => Type::Bool, Expr::IntegerLiteral(_) => Type::Int, @@ -69,3 +69,10 @@ impl Expr { } } } + +impl SExpr { + #[inline] + pub fn ty(&self) -> Type { + self.expr.ty() + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 3355c10..6da7ed3 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,9 +1,11 @@ pub mod expr; -pub use expr::Expr; +pub use expr::{BinaryExpression, Expr, SExpr}; use crate::typing::Type; -use std::path::Path; + +use ariadne; +use std::{fmt::Display, path::Path}; #[derive(Debug, PartialEq, Clone)] pub enum BinaryOperator { @@ -20,6 +22,22 @@ pub enum BinaryOperator { NotEqual, } +impl Display for BinaryOperator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + BinaryOperator::And => "&&", + BinaryOperator::Or => "||", + BinaryOperator::Add => "+", + BinaryOperator::Sub => "-", + BinaryOperator::Mul => "*", + BinaryOperator::Div => "/", + BinaryOperator::Modulo => "%", + BinaryOperator::Equal => "==", + BinaryOperator::NotEqual => "!=", + }) + } +} + #[derive(Debug, PartialEq, Copy, Clone)] pub enum UnaryOperator { Not, @@ -27,12 +45,32 @@ pub enum UnaryOperator { pub type Identifier = String; -#[derive(Debug, PartialEq)] -pub struct Location { - pub line_col: (usize, usize), +pub type SourceId = u32; + +#[derive(Debug, PartialEq, Clone, Copy)] +pub struct Span { + pub source: SourceId, + pub start: usize, + pub end: usize, } -#[derive(Debug, PartialEq, Clone, Default)] +impl ariadne::Span for Span { + type SourceId = SourceId; + + fn source(&self) -> &Self::SourceId { + &self.source + } + + fn start(&self) -> usize { + self.start + } + + fn end(&self) -> usize { + self.end + } +} + +#[derive(Debug, PartialEq, Clone, Default, Eq, Hash)] pub struct ModulePath { components: Vec, } @@ -65,7 +103,7 @@ impl From<&Path> for ModulePath { .map(|component| match component { std::path::Component::Normal(n) => { if meta.is_file() { - n.to_str().unwrap().split(".").nth(0).unwrap().to_string() + n.to_str().unwrap().split('.').nth(0).unwrap().to_string() } else if meta.is_dir() { n.to_str().unwrap().to_string() } else { @@ -91,22 +129,51 @@ impl From<&str> for ModulePath { #[derive(Eq, PartialEq, Debug)] pub struct Import(pub String); +#[derive(Debug, PartialEq)] +pub struct ReturnStatement { + pub expr: Option, + pub span: Span, +} + #[derive(Debug, PartialEq)] pub enum Statement { - DeclareStatement(Identifier, Box), - AssignStatement(Identifier, Box), - ReturnStatement(Option), - CallStatement(Box), - UseStatement(Box), - IfStatement(Box, Box), - WhileStatement(Box, Box), + DeclareStatement { + lhs: Identifier, + rhs: Box, + span: Span, + }, + AssignStatement { + lhs: Identifier, + rhs: Box, + span: Span, + }, + ReturnStatement(ReturnStatement), + CallStatement { + call: Box, + span: Span, + }, + UseStatement { + import: Box, + span: Span, + }, + IfStatement { + condition: Box, + then_block: Box, + span: Span, + }, + WhileStatement { + condition: Box, + loop_block: Box, + span: Span, + }, } #[derive(Debug, PartialEq)] pub struct Block { pub statements: Vec, - pub value: Option, + pub value: Option, pub typ: Type, + pub span: Option, } impl Block { @@ -115,6 +182,7 @@ impl Block { typ: Type::Unit, statements: Vec::with_capacity(0), value: None, + span: None, } } } @@ -129,8 +197,9 @@ pub struct FunctionDefinition { pub name: Identifier, pub parameters: Vec, pub return_type: Option, + pub return_type_span: Option, pub body: Box, - pub location: Location, + pub span: Span, } #[derive(Debug, PartialEq, Default)] @@ -159,8 +228,8 @@ impl Module { #[derive(Debug, PartialEq)] pub struct Call { - pub callee: Box, - pub args: Vec, + pub callee: Box, + pub args: Vec, pub typ: Type, } diff --git a/src/jit/mod.rs b/src/jit/mod.rs index eff74bc..bf86ccc 100644 --- a/src/jit/mod.rs +++ b/src/jit/mod.rs @@ -1,15 +1,17 @@ use crate::{ ast::{ - self, BinaryOperator, ModulePath, UnaryOperator, - {expr::Expr, FunctionDefinition, Statement}, + self, expr::BinaryExpression, BinaryOperator, Expr, FunctionDefinition, ModulePath, + ReturnStatement, SourceId, Statement, UnaryOperator, }, - parsing, - typing::{CheckedModule, Type}, + parsing::{DefaultParser, Parser}, + typing::Type, + SourceCache, }; +use ariadne::Cache as _; use cranelift::{codegen::ir::UserFuncName, prelude::*}; use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{DataDescription, FuncId, FuncOrDataId, Linkage, Module}; -use std::{collections::HashMap, fs, ops::Deref}; +use std::collections::HashMap; /// The basic JIT class. pub struct JIT { @@ -30,6 +32,9 @@ pub struct JIT { /// Whether to print CLIR during compilation pub dump_clir: bool, + + /// Parser used to build the AST + pub parser: DefaultParser, } impl Default for JIT { @@ -59,18 +64,34 @@ impl Default for JIT { data_desc: DataDescription::new(), module, dump_clir: false, + parser: DefaultParser::default(), } } } impl JIT { /// Compile source code into machine code. - pub fn compile(&mut self, input: &str, namespace: ModulePath) -> Result<*const u8, String> { + pub fn compile( + &mut self, + input: &str, + namespace: ModulePath, + id: SourceId, + ) -> Result<*const u8, String> { + let mut source_cache = (0u32, ariadne::Source::from(input)); + // Parse the source code into an AST - let ast = parsing::parse_as_module(input, namespace) - .map_err(|x| format!("Parsing error: {x}"))? - .type_check() - .map_err(|x| format!("Typing error: {x}"))?; + let mut ast = self + .parser + .parse_as_module(input, namespace, id) + .map_err(|x| format!("Parsing error: {x}"))?; + + ast.type_check() + .map_err(|errors| { + errors + .iter() + .for_each(|e| e.to_report(&ast).eprint(&mut source_cache).unwrap()); + }) + .unwrap(); // Translate the AST into Cranelift IR self.translate(&ast)?; @@ -89,20 +110,24 @@ impl JIT { } } - pub fn compile_file(&mut self, path: &str) -> Result<*const u8, String> { + pub fn compile_file( + &mut self, + path: &str, + id: SourceId, + source_cache: &mut SourceCache, + ) -> Result<*const u8, String> { self.compile( - fs::read_to_string(path) - .map_err(|x| format!("Cannot open {}: {}", path, x))? - .as_str(), + source_cache + .fetch(&id) + .map(|s| s.text()) + .map_err(|e| format!("{:?}", e))?, AsRef::::as_ref(path).into(), + id, ) } /// Translate language AST into Cranelift IR. - fn translate(&mut self, ast: &CheckedModule) -> Result<(), String> { - // Dump contract-holding wrapper type - let ast = &ast.0; - + fn translate(&mut self, ast: &ast::Module) -> Result<(), String> { let mut signatures: Vec = Vec::with_capacity(ast.functions.len()); let mut func_ids: Vec = Vec::with_capacity(ast.functions.len()); @@ -199,7 +224,7 @@ impl JIT { // Emit the final return instruction. if let Some(return_expr) = &function.body.value { - let return_value = translator.translate_expr(&return_expr); + let return_value = translator.translate_expr(&return_expr.expr); translator.builder.ins().return_(&[return_value]); } else { translator.builder.ins().return_(&[]); @@ -234,18 +259,26 @@ struct FunctionTranslator<'a> { impl<'a> FunctionTranslator<'a> { fn translate_statement(&mut self, stmt: &Statement) -> Option { match stmt { - Statement::AssignStatement(name, expr) => { + Statement::AssignStatement { + lhs: name, + rhs: expr, + .. + } => { // `def_var` is used to write the value of a variable. Note that // variables can have multiple definitions. Cranelift will // convert them into SSA form for itself automatically. - let new_value = self.translate_expr(expr); + let new_value = self.translate_expr(&expr.expr); let variable = self.variables.get(name).unwrap(); self.builder.def_var(*variable, new_value); Some(new_value) } - Statement::DeclareStatement(name, expr) => { - let value = self.translate_expr(expr); + Statement::DeclareStatement { + lhs: name, + rhs: expr, + .. + } => { + let value = self.translate_expr(&expr.expr); let variable = Variable::from_u32(self.variables.len() as u32); self.builder .declare_var(variable, self.translate_type(&expr.ty())); @@ -254,10 +287,12 @@ impl<'a> FunctionTranslator<'a> { Some(value) } - Statement::ReturnStatement(maybe_expr) => { + Statement::ReturnStatement(ReturnStatement { + expr: maybe_expr, .. + }) => { // TODO: investigate tail call let values = if let Some(expr) = maybe_expr { - vec![self.translate_expr(expr)] + vec![self.translate_expr(&expr.expr)] } else { // XXX: urgh Vec::with_capacity(0) @@ -269,12 +304,16 @@ impl<'a> FunctionTranslator<'a> { None } - Statement::CallStatement(call) => self.translate_call(call), + Statement::CallStatement { call, .. } => self.translate_call(call), - Statement::UseStatement(_) => todo!(), + Statement::UseStatement { .. } => todo!(), - Statement::IfStatement(cond, then_body) => { - let condition_value = self.translate_expr(cond); + Statement::IfStatement { + condition: cond, + then_block: then_body, + .. + } => { + let condition_value = self.translate_expr(&cond.expr); let then_block = self.builder.create_block(); let merge_block = self.builder.create_block(); @@ -294,7 +333,7 @@ impl<'a> FunctionTranslator<'a> { None } - Statement::WhileStatement(_, _) => todo!(), + Statement::WhileStatement { .. } => todo!(), } } @@ -302,11 +341,11 @@ impl<'a> FunctionTranslator<'a> { match expr { Expr::UnitLiteral => unreachable!(), - Expr::BooleanLiteral(imm) => self.builder.ins().iconst(types::I8, i64::from(*imm)), - Expr::IntegerLiteral(imm) => self.builder.ins().iconst(types::I32, *imm), - Expr::FloatLiteral(imm) => self.builder.ins().f64const(*imm), + Expr::BooleanLiteral(imm, ..) => self.builder.ins().iconst(types::I8, i64::from(*imm)), + Expr::IntegerLiteral(imm, ..) => self.builder.ins().iconst(types::I32, *imm), + Expr::FloatLiteral(imm, ..) => self.builder.ins().f64const(*imm), - Expr::StringLiteral(s) => { + Expr::StringLiteral(s, ..) => { let id = self.module.declare_anonymous_data(false, false).unwrap(); let bytes: Box<[u8]> = s.as_bytes().into(); self.data_desc.define(bytes); @@ -318,14 +357,9 @@ impl<'a> FunctionTranslator<'a> { .global_value(self.module.isa().pointer_type(), gv) } - Expr::BinaryExpression { - lhs, - op, - rhs, - typ: _, - } => { - let lhs_value = self.translate_expr(lhs); - let rhs_value = self.translate_expr(rhs); + Expr::BinaryExpression(BinaryExpression { lhs, rhs, op, .. }) => { + let lhs_value = self.translate_expr(&lhs.expr); + let rhs_value = self.translate_expr(&rhs.expr); match (lhs.ty(), lhs.ty()) { (Type::Int, Type::Int) => match op { @@ -361,8 +395,9 @@ impl<'a> FunctionTranslator<'a> { then_body, else_body, typ, + .. } => { - let condition_value = self.translate_expr(cond); + let condition_value = self.translate_expr(&cond.expr); let then_block = self.builder.create_block(); let else_block = self.builder.create_block(); @@ -384,11 +419,11 @@ impl<'a> FunctionTranslator<'a> { self.builder.switch_to_block(then_block); self.builder.seal_block(then_block); for stmt in &then_body.statements { - self.translate_statement(&stmt); + self.translate_statement(stmt); } let then_return_value = match &then_body.value { - Some(val) => vec![self.translate_expr(val)], + Some(val) => vec![self.translate_expr(&val.expr)], None => Vec::with_capacity(0), }; @@ -399,9 +434,9 @@ impl<'a> FunctionTranslator<'a> { self.builder.seal_block(else_block); // XXX: the else can be just an expression: do we always need to // make a second branch in that case? Or leave it to cranelift? - let else_return_value = match **else_body { + let else_return_value = match else_body.expr { Expr::UnitLiteral => Vec::with_capacity(0), - _ => vec![self.translate_expr(else_body)], + _ => vec![self.translate_expr(&else_body.expr)], }; // Jump to the merge block, passing it the block return value. @@ -420,8 +455,8 @@ impl<'a> FunctionTranslator<'a> { phi } - Expr::UnaryExpression { op, inner } => { - let inner_value = self.translate_expr(inner); + Expr::UnaryExpression { op, inner, .. } => { + let inner_value = self.translate_expr(&inner.expr); match op { // XXX: This should not be a literal translation UnaryOperator::Not => { @@ -431,13 +466,13 @@ impl<'a> FunctionTranslator<'a> { } } - Expr::Identifier { name, typ: _ } => { + Expr::Identifier { name, .. } => { self.builder.use_var(*self.variables.get(name).unwrap()) } - Expr::Call(call) => self.translate_call(call).unwrap(), + Expr::Call(call, ..) => self.translate_call(call).unwrap(), - Expr::Block(block) => self.translate_block(block).unwrap(), + Expr::Block(block, ..) => self.translate_block(block).unwrap(), } } @@ -445,16 +480,16 @@ impl<'a> FunctionTranslator<'a> { for stmt in &block.statements { self.translate_statement(stmt); } - if let Some(block_value) = &block.value { - Some(self.translate_expr(block_value)) - } else { - None - } + + block + .value + .as_ref() + .map(|block_value| self.translate_expr(&block_value.expr)) } fn translate_call(&mut self, call: &ast::Call) -> Option { - match call.callee.deref() { - Expr::Identifier { name, typ: _ } => { + match &call.callee.expr { + Expr::Identifier { name, .. } => { let func_ref = if let Some(func_or_data_id) = self.module.get_name(name.as_ref()) { if let FuncOrDataId::Func(func_id) = func_or_data_id { self.module.declare_func_in_func(func_id, self.builder.func) @@ -465,7 +500,11 @@ impl<'a> FunctionTranslator<'a> { todo!() }; - let args: Vec = call.args.iter().map(|a| self.translate_expr(a)).collect(); + let args: Vec = call + .args + .iter() + .map(|a| self.translate_expr(&a.expr)) + .collect(); let call_inst = self.builder.ins().call(func_ref, &args); let results = self.builder.inst_results(call_inst); diff --git a/src/main.rs b/src/main.rs index befe80b..9e0bb6f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,20 @@ pub mod ast; pub mod jit; pub mod parsing; +pub mod source; pub mod typing; -use clap::{Parser, Subcommand}; +use std::default::Default; +use std::path::PathBuf; + +use clap::{Parser as ClapParser, Subcommand}; use crate::ast::Module; +use crate::parsing::Parser; +use crate::source::SourceCache; /// Experimental compiler for lila -#[derive(Parser, Debug)] +#[derive(ClapParser, Debug)] #[command(author = "Romain P. ")] #[command(version, about, long_about = None)] struct Cli { @@ -51,21 +57,27 @@ enum Commands { }, } -fn parse(files: &Vec) -> Vec { +fn parse(files: &[String]) -> Vec { + let mut parser = parsing::DefaultParser::default(); let paths = files.iter().map(std::path::Path::new); paths - .map(|path| match parsing::parse_file(&path) { + .enumerate() + .map(|(i, path)| match parser.parse_file(path, i as u32) { Ok(module) => module, Err(e) => panic!("Parsing error: {:#?}", e), }) .collect() } -fn check(modules: &mut Vec) { - while let Some(module) = modules.pop() { - if let Err(e) = module.type_check() { - eprintln!("{}", e); - return; +fn check(modules: &mut Vec, source_cache: &mut SourceCache) { + for module in modules { + if let Err(errors) = module.type_check() { + for error in errors { + error + .to_report(module) + .eprint(&mut *source_cache) + .expect("cannot write error to stderr"); + } } } } @@ -83,20 +95,32 @@ fn main() { } println!("Parsing OK"); } + Commands::TypeCheck { files, dump_ast } => { + let mut source_cache = SourceCache { + paths: files.iter().map(PathBuf::from).collect(), + file_cache: ariadne::FileCache::default(), + }; let mut modules = parse(files); - check(&mut modules); + check(&mut modules, &mut source_cache); if *dump_ast { for module in &modules { println!("{:#?}", &module); } } } + Commands::Compile { files, dump_clir } | Commands::Run { files, dump_clir } => { + let mut source_cache = SourceCache { + paths: files.iter().map(PathBuf::from).collect(), + file_cache: ariadne::FileCache::default(), + }; + let mut jit = jit::JIT::default(); jit.dump_clir = *dump_clir; - for file in files { - match jit.compile_file(file) { + + for (id, file) in files.iter().enumerate() { + match jit.compile_file(file, id as u32, &mut source_cache) { Err(e) => eprintln!("{}", e), Ok(code) => { println!("Compiled {}", file); diff --git a/src/parsing/backend/pest/mod.rs b/src/parsing/backend/pest/mod.rs index 5bed115..82ba471 100644 --- a/src/parsing/backend/pest/mod.rs +++ b/src/parsing/backend/pest/mod.rs @@ -1,14 +1,13 @@ use std::fs; use std::path::Path; -use pest::error::Error; +use expr::BinaryExpression; use pest::iterators::Pair; use pest::pratt_parser::PrattParser; -use pest::Parser; +use pest::Parser as PestParser; +use ReturnStatement; -use crate::ast::Module; use crate::ast::*; -use crate::ast::{Import, ModulePath}; use crate::typing::Type; #[derive(pest_derive::Parser)] @@ -33,259 +32,358 @@ lazy_static::lazy_static! { }; } -pub fn parse_file(path: &Path) -> Result> { - let source = fs::read_to_string(&path).expect("could not read source file"); - let module_path = ModulePath::from(path); - let mut module = parse_as_module(&source, module_path)?; - module.file = Some(path.to_owned()); - Ok(module) +#[derive(Default)] +pub struct Parser { + source: SourceId, } -pub fn parse_as_module(source: &str, path: ModulePath) -> Result> { - let mut pairs = LilaParser::parse(Rule::source_file, &source)?; +impl crate::parsing::Parser for Parser { + fn parse_file(&mut self, path: &Path, id: SourceId) -> anyhow::Result { + let source = fs::read_to_string(path)?; + let module_path = ModulePath::from(path); + let mut module = self.parse_as_module(&source, module_path, id)?; + module.file = Some(path.to_owned()); + Ok(module) + } - assert!(pairs.len() == 1); - let module = parse_module(pairs.next().unwrap().into_inner().next().unwrap(), path); + fn parse_as_module( + &mut self, + source: &str, + path: ModulePath, + id: SourceId, + ) -> anyhow::Result { + self.source = id; + let mut pairs = LilaParser::parse(Rule::source_file, source)?; - Ok(module) + assert!(pairs.len() == 1); + let module = self.parse_module(pairs.next().unwrap().into_inner().next().unwrap(), path); + + Ok(module) + } } -pub fn parse_module(pair: Pair, path: ModulePath) -> Module { - assert!(pair.as_rule() == Rule::module_items); +impl Parser { + fn parse_module(&self, pair: Pair, path: ModulePath) -> Module { + assert!(pair.as_rule() == Rule::module_items); - let mut module = Module::new(path); + let mut module = Module::new(path); + + let pairs = pair.into_inner(); + for pair in pairs { + match pair.as_rule() { + Rule::definition => { + let def = self.parse_definition(pair.into_inner().next().unwrap()); + match def { + Definition::FunctionDefinition(func) => module.functions.push(func), + } + } + Rule::use_statement => { + let path = self.parse_import(pair.into_inner().next().unwrap()); + module.imports.push(path); + } + _ => panic!("unexpected rule in source_file: {:?}", pair.as_rule()), + } + } + + module + } + + fn parse_block(&self, pair: Pair) -> Block { + let mut statements = vec![]; + let mut value = None; + let span = self.make_span(&pair); + + for pair in pair.into_inner() { + match pair.as_rule() { + Rule::statement => statements.push(self.parse_statement(pair)), + Rule::expr => value = Some(self.parse_expression(pair)), + _ => panic!("unexpected rule {:?} in block", pair.as_rule()), + } + } + + Block { + statements, + value, + typ: Type::Undefined, + span: Some(span), + } + } + + fn parse_statement(&self, pair: Pair) -> Statement { + let pair = pair.into_inner().next().unwrap(); + let span = self.make_span(&pair); - let pairs = pair.into_inner(); - for pair in pairs { match pair.as_rule() { - Rule::definition => { - let def = parse_definition(pair.into_inner().next().unwrap()); - match def { - Definition::FunctionDefinition(func) => module.functions.push(func), + Rule::assign_statement => { + let mut pairs = pair.into_inner(); + let identifier = pairs.next().unwrap().as_str().to_string(); + let expr = self.parse_expression(pairs.next().unwrap()); + Statement::AssignStatement { + lhs: identifier, + rhs: Box::new(expr), + span, + } + } + Rule::declare_statement => { + let mut pairs = pair.into_inner(); + let identifier = pairs.next().unwrap().as_str().to_string(); + let expr = self.parse_expression(pairs.next().unwrap()); + Statement::DeclareStatement { + lhs: identifier, + rhs: Box::new(expr), + span, + } + } + Rule::return_statement => { + let expr = pair + .into_inner() + .next() + .map(|expr| self.parse_expression(expr)); + Statement::ReturnStatement(ReturnStatement { expr, span }) + } + Rule::call_statement => { + let call = self.parse_call(pair.into_inner().next().unwrap()); + Statement::CallStatement { + call: Box::new(call), + span, } } Rule::use_statement => { - let path = parse_import(pair.into_inner().next().unwrap()); - module.imports.push(path); - } - _ => panic!("unexpected rule in source_file: {:?}", pair.as_rule()), - } - } - - module -} - -fn parse_block(pair: Pair) -> Block { - let mut statements = vec![]; - let mut value = None; - - for pair in pair.into_inner() { - match pair.as_rule() { - Rule::statement => statements.push(parse_statement(pair)), - Rule::expr => value = Some(parse_expression(pair)), - _ => panic!("unexpected rule {:?} in block", pair.as_rule()), - } - } - - Block { - statements, - value, - typ: Type::Undefined, - } -} - -fn parse_statement(pair: Pair) -> Statement { - let pair = pair.into_inner().next().unwrap(); - match pair.as_rule() { - Rule::assign_statement => { - let mut pairs = pair.into_inner(); - let identifier = pairs.next().unwrap().as_str().to_string(); - let expr = parse_expression(pairs.next().unwrap()); - Statement::AssignStatement(identifier, Box::new(expr)) - } - Rule::declare_statement => { - let mut pairs = pair.into_inner(); - let identifier = pairs.next().unwrap().as_str().to_string(); - let expr = parse_expression(pairs.next().unwrap()); - Statement::DeclareStatement(identifier, Box::new(expr)) - } - Rule::return_statement => { - let expr = if let Some(pair) = pair.into_inner().next() { - Some(parse_expression(pair)) - } else { - None - }; - Statement::ReturnStatement(expr) - } - Rule::call_statement => { - let call = parse_call(pair.into_inner().next().unwrap()); - Statement::CallStatement(Box::new(call)) - } - Rule::use_statement => { - let import = parse_import(pair.into_inner().next().unwrap()); - Statement::UseStatement(Box::new(import)) - } - Rule::if_statement => { - let mut pairs = pair.into_inner(); - let condition = parse_expression(pairs.next().unwrap()); - let block = parse_block(pairs.next().unwrap()); - if pairs.next().is_some() { - todo!("implement if-statements with else branch (and else if)") - } - Statement::IfStatement(Box::new(condition), Box::new(block)) - } - Rule::while_statement => { - let mut pairs = pair.into_inner(); - let condition = parse_expression(pairs.next().unwrap()); - let block = parse_block(pairs.next().unwrap()); - Statement::WhileStatement(Box::new(condition), Box::new(block)) - } - _ => unreachable!("unexpected rule '{:?}' in parse_statement", pair.as_rule()), - } -} - -fn parse_import(pair: Pair) -> Import { - Import(pair.as_str().to_string()) -} - -fn parse_call(pair: Pair) -> Call { - let mut pairs = pair.into_inner(); - // TODO: support calls on more than identifiers (needs grammar change) - let callee = Expr::Identifier { - name: pairs.next().unwrap().as_str().to_string(), - typ: Type::Undefined, - }; - let args: Vec = pairs - .next() - .unwrap() - .into_inner() - .map(parse_expression) - .collect(); - - Call { - callee: Box::new(callee), - args, - typ: Type::Undefined, - } -} - -fn parse_expression(pair: Pair) -> Expr { - let pairs = pair.into_inner(); - PRATT_PARSER - .map_primary(|primary| match primary.as_rule() { - Rule::integer_literal => Expr::IntegerLiteral(primary.as_str().parse().unwrap()), - Rule::float_literal => Expr::FloatLiteral(primary.as_str().parse().unwrap()), - Rule::string_literal => Expr::StringLiteral( - primary - .into_inner() - .next() - .unwrap() - .as_str() - .parse() - .unwrap(), - ), - Rule::expr => parse_expression(primary), - Rule::ident => Expr::Identifier { - name: primary.as_str().to_string(), - typ: Type::Undefined, - }, - Rule::call => Expr::Call(Box::new(parse_call(primary))), - Rule::block => Expr::Block(Box::new(parse_block(primary))), - Rule::if_expr => { - let mut pairs = primary.into_inner(); - let condition = parse_expression(pairs.next().unwrap()); - let true_block = parse_block(pairs.next().unwrap()); - let else_value = parse_expression(pairs.next().unwrap()); - Expr::IfExpr { - cond: Box::new(condition), - then_body: Box::new(true_block), - else_body: Box::new(else_value), - typ: Type::Undefined, + let import = self.parse_import(pair.into_inner().next().unwrap()); + Statement::UseStatement { + import: Box::new(import), + span, } } - Rule::boolean_literal => Expr::BooleanLiteral(match primary.as_str() { - "true" => true, - "false" => false, - _ => unreachable!(), - }), - _ => unreachable!( - "Unexpected rule '{:?}' in primary expression", - primary.as_rule() - ), - }) - .map_infix(|lhs, op, rhs| { - let operator = match op.as_rule() { - Rule::add => BinaryOperator::Add, - Rule::subtract => BinaryOperator::Sub, - Rule::multiply => BinaryOperator::Mul, - Rule::divide => BinaryOperator::Div, - Rule::modulo => BinaryOperator::Modulo, - Rule::equal => BinaryOperator::Equal, - Rule::not_equal => BinaryOperator::NotEqual, - Rule::and => BinaryOperator::And, - Rule::or => BinaryOperator::Or, - _ => unreachable!(), - }; - Expr::BinaryExpression { - lhs: Box::new(lhs), - op: operator, - rhs: Box::new(rhs), - typ: Type::Undefined, + Rule::if_statement => { + let mut pairs = pair.into_inner(); + let condition = self.parse_expression(pairs.next().unwrap()); + let block = self.parse_block(pairs.next().unwrap()); + if pairs.next().is_some() { + todo!("implement if-statements with else branch (and else if)") + } + Statement::IfStatement { + condition: Box::new(condition), + then_block: Box::new(block), + span, + } } - }) - .map_prefix(|op, inner| { - let operator = match op.as_rule() { - Rule::not => UnaryOperator::Not, - _ => unreachable!(), - }; - Expr::UnaryExpression { - op: operator, - inner: Box::new(inner), + Rule::while_statement => { + let mut pairs = pair.into_inner(); + let condition = self.parse_expression(pairs.next().unwrap()); + let block = self.parse_block(pairs.next().unwrap()); + Statement::WhileStatement { + condition: Box::new(condition), + loop_block: Box::new(block), + span, + } } - }) - .parse(pairs) -} - -fn parse_parameter(pair: Pair) -> Parameter { - assert!(pair.as_rule() == Rule::parameter); - let mut pair = pair.into_inner(); - let name = pair.next().unwrap().as_str().to_string(); - let typ = Type::from(pair.next().unwrap().as_str()); - Parameter { name, typ } -} - -fn parse_definition(pair: Pair) -> Definition { - match pair.as_rule() { - Rule::func_def => { - let line_col = pair.line_col(); - let mut pairs = pair.into_inner(); - let name = pairs.next().unwrap().as_str().to_string(); - let parameters: Vec = pairs - .next() - .unwrap() - .into_inner() - .map(parse_parameter) - .collect(); - let pair = pairs.next().unwrap(); - // Before the block there is an optional return type - let (return_type, pair) = match pair.as_rule() { - Rule::ident => (Some(Type::from(pair.as_str())), pairs.next().unwrap()), - Rule::block => (None, pair), - _ => unreachable!( - "Unexpected rule '{:?}' in function definition, expected return type or block", - pair.as_rule() - ), - }; - let body = parse_block(pair); - let body = Box::new(body); - Definition::FunctionDefinition(FunctionDefinition { - name, - parameters, - return_type, - body, - location: Location { line_col }, - }) + _ => unreachable!("unexpected rule '{:?}' in parse_statement", pair.as_rule()), + } + } + + fn parse_import(&self, pair: Pair) -> Import { + Import(pair.as_str().to_string()) + } + + fn parse_call(&self, pair: Pair) -> Call { + let mut pairs = pair.into_inner(); + // TODO: support calls on more than identifiers (needs grammar change) + + let pair = pairs.next().unwrap(); + let callee = SExpr { + expr: Expr::Identifier { + name: pair.as_str().to_string(), + typ: Type::Undefined, + }, + span: self.make_span(&pair), + }; + let args: Vec = pairs + .next() + .unwrap() + .into_inner() + .map(|arg| self.parse_expression(arg)) + .collect(); + + Call { + callee: Box::new(callee), + args, + typ: Type::Undefined, + } + } + + fn parse_expression(&self, pair: Pair) -> SExpr { + let span = self.make_span(&pair); + let pairs = pair.into_inner(); + let mut map = PRATT_PARSER + .map_primary(|primary| { + let span = self.make_span(&primary); + match primary.as_rule() { + Rule::integer_literal => SExpr { + expr: Expr::IntegerLiteral(primary.as_str().parse().unwrap()), + span, + }, + + Rule::float_literal => SExpr { + expr: Expr::FloatLiteral(primary.as_str().parse().unwrap()), + span, + }, + + Rule::string_literal => SExpr { + expr: Expr::StringLiteral( + primary + .into_inner() + .next() + .unwrap() + .as_str() + .parse() + .unwrap(), + ), + span, + }, + + Rule::expr => self.parse_expression(primary), + + Rule::ident => SExpr { + expr: Expr::Identifier { + name: primary.as_str().to_string(), + typ: Type::Undefined, + }, + span, + }, + + Rule::call => SExpr { + expr: Expr::Call(Box::new(self.parse_call(primary))), + span, + }, + + Rule::block => SExpr { + expr: Expr::Block(Box::new(self.parse_block(primary))), + span, + }, + + Rule::if_expr => { + let mut pairs = primary.into_inner(); + let condition = self.parse_expression(pairs.next().unwrap()); + let true_block = self.parse_block(pairs.next().unwrap()); + let else_value = self.parse_expression(pairs.next().unwrap()); + SExpr { + expr: Expr::IfExpr { + cond: Box::new(condition), + then_body: Box::new(true_block), + else_body: Box::new(else_value), + typ: Type::Undefined, + }, + span, + } + } + + Rule::boolean_literal => SExpr { + expr: Expr::BooleanLiteral(match primary.as_str() { + "true" => true, + "false" => false, + _ => unreachable!(), + }), + span, + }, + + _ => unreachable!( + "Unexpected rule '{:?}' in primary expression", + primary.as_rule() + ), + } + }) + .map_infix(|lhs, op, rhs| { + let operator = match op.as_rule() { + Rule::add => BinaryOperator::Add, + Rule::subtract => BinaryOperator::Sub, + Rule::multiply => BinaryOperator::Mul, + Rule::divide => BinaryOperator::Div, + Rule::modulo => BinaryOperator::Modulo, + Rule::equal => BinaryOperator::Equal, + Rule::not_equal => BinaryOperator::NotEqual, + Rule::and => BinaryOperator::And, + Rule::or => BinaryOperator::Or, + _ => unreachable!(), + }; + let expr = Expr::BinaryExpression(BinaryExpression { + lhs: Box::new(lhs), + op: operator, + op_span: self.make_span(&op), + rhs: Box::new(rhs), + typ: Type::Undefined, + }); + SExpr { expr, span } + }) + .map_prefix(|op, inner| { + let operator = match op.as_rule() { + Rule::not => UnaryOperator::Not, + _ => unreachable!(), + }; + let expr = Expr::UnaryExpression { + op: operator, + inner: Box::new(inner), + }; + SExpr { expr, span } + }); + map.parse(pairs) + } + + fn parse_parameter(&self, pair: Pair) -> Parameter { + assert!(pair.as_rule() == Rule::parameter); + let mut pair = pair.into_inner(); + let name = pair.next().unwrap().as_str().to_string(); + let typ = Type::from(pair.next().unwrap().as_str()); + Parameter { name, typ } + } + + fn parse_definition(&self, pair: Pair) -> Definition { + match pair.as_rule() { + Rule::func_def => { + let span = self.make_span(&pair); + let mut pairs = pair.into_inner(); + let name = pairs.next().unwrap().as_str().to_string(); + let parameters: Vec = pairs + .next() + .unwrap() + .into_inner() + .map(|param| self.parse_parameter(param)) + .collect(); + let pair = pairs.next().unwrap(); + // Before the block there is an optional return type + let (return_type, return_type_span, pair) = match pair.as_rule() { + Rule::ident => ( + Some(Type::from(pair.as_str())), + Some(self.make_span(&pair)), + pairs.next().unwrap(), + ), + Rule::block => (None, None, pair), + _ => unreachable!( + "Unexpected rule '{:?}' in function definition, expected return type or block", + pair.as_rule() + ), + }; + let body = self.parse_block(pair); + let body = Box::new(body); + Definition::FunctionDefinition(FunctionDefinition { + name, + parameters, + return_type, + return_type_span, + span, + body, + }) + } + _ => panic!("unexpected node for definition: {:?}", pair.as_rule()), + } + } + + fn make_span(&self, pair: &Pair) -> Span { + let span = pair.as_span(); + Span { + source: self.source, + start: span.start(), + end: span.end(), } - _ => panic!("unexpected node for definition: {:?}", pair.as_rule()), } } diff --git a/src/parsing/mod.rs b/src/parsing/mod.rs index e7f2cd5..8cf64e4 100644 --- a/src/parsing/mod.rs +++ b/src/parsing/mod.rs @@ -1,4 +1,18 @@ mod backend; mod tests; -pub use self::backend::pest::{parse_file, parse_module, parse_as_module}; +use crate::ast::{Module, ModulePath, SourceId}; + +pub trait Parser: Default { + fn parse_file(&mut self, path: &std::path::Path, id: SourceId) -> anyhow::Result; + + fn parse_as_module( + &mut self, + source: &str, + path: ModulePath, + id: SourceId, + ) -> anyhow::Result; +} + +pub use self::backend::pest::Parser as PestParser; +pub use PestParser as DefaultParser; diff --git a/src/parsing/tests.rs b/src/parsing/tests.rs index 634de2a..0bedf1b 100644 --- a/src/parsing/tests.rs +++ b/src/parsing/tests.rs @@ -1,12 +1,17 @@ +#[cfg(test)] +use pretty_assertions::assert_eq; + #[test] fn test_addition_function() { - use crate::ast::{expr::Expr, *}; - use crate::parsing::backend::pest::parse_as_module; - use crate::typing::Type; + use crate::ast::*; + use crate::parsing::*; + use crate::typing::*; let source = "fn add(a: int, b: int) int { a + b }"; let path = ModulePath::from("test"); - let module = parse_as_module(&source, path.clone()).expect("parsing error"); + let module = DefaultParser::default() + .parse_as_module(source, path.clone(), 0) + .expect("parsing error"); let expected_module = Module { file: None, @@ -26,21 +31,61 @@ fn test_addition_function() { return_type: Some(Type::Int), body: Box::new(Block { statements: vec![], - value: Some(Expr::BinaryExpression { - lhs: Box::new(Expr::Identifier { - name: Identifier::from("a"), + value: Some(SExpr { + expr: Expr::BinaryExpression(BinaryExpression { + lhs: Box::new(SExpr { + expr: Expr::Identifier { + name: Identifier::from("a"), + typ: Type::Undefined, + }, + span: Span { + source: 0, + start: 29, + end: 30, + }, + }), + op: BinaryOperator::Add, + op_span: Span { + source: 0, + start: 31, + end: 32, + }, + rhs: Box::new(SExpr { + expr: Expr::Identifier { + name: Identifier::from("b"), + typ: Type::Undefined, + }, + span: Span { + source: 0, + start: 33, + end: 34, + }, + }), typ: Type::Undefined, }), - op: BinaryOperator::Add, - rhs: Box::new(Expr::Identifier { - name: Identifier::from("b"), - typ: Type::Undefined, - }), - typ: Type::Undefined, + span: Span { + source: 0, + start: 29, + end: 34, + }, }), typ: Type::Undefined, + span: Some(Span { + source: 0, + start: 27, + end: source.len(), + }), + }), + span: Span { + source: 0, + start: 0, + end: source.len(), + }, + return_type_span: Some(Span { + source: 0, + start: 23, + end: 26, }), - location: Location { line_col: (1, 1) }, }], path, }; diff --git a/src/source.rs b/src/source.rs new file mode 100644 index 0000000..9f53cd7 --- /dev/null +++ b/src/source.rs @@ -0,0 +1,22 @@ +use crate::ast::SourceId; +use ariadne::FileCache; + +pub struct SourceCache { + pub paths: Vec, + pub file_cache: FileCache, +} + +impl ariadne::Cache for SourceCache { + type Storage = String; + + fn fetch( + &mut self, + id: &SourceId, + ) -> Result<&ariadne::Source, Box> { + self.file_cache.fetch(&self.paths[*id as usize]) + } + + fn display<'a>(&self, id: &'a SourceId) -> Option> { + Some(Box::new(format!("{}", self.paths[*id as usize].display()))) + } +} diff --git a/src/typing/error.rs b/src/typing/error.rs index 7eea1a5..0260f97 100644 --- a/src/typing/error.rs +++ b/src/typing/error.rs @@ -1,8 +1,10 @@ -use crate::typing::{BinaryOperator, Identifier, ModulePath, Type, TypingContext}; +use ariadne::{ColorGenerator, Fmt, Label, Report, ReportKind, Span as _}; +use std::fmt::Debug; -use super::UnaryOperator; +use super::{Span, UnaryOperator}; +use crate::typing::{BinaryOperator, Identifier, ModulePath, Type}; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct TypeError { pub file: Option, pub module: ModulePath, @@ -10,72 +12,31 @@ pub struct TypeError { pub kind: TypeErrorKind, } -impl std::fmt::Display for TypeError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("Error\n")?; - if let Some(path) = &self.file { - f.write_fmt(format_args!(" in file {}\n", path.display()))?; - } - f.write_fmt(format_args!(" in module {}\n", self.module))?; - if let Some(name) = &self.function { - f.write_fmt(format_args!(" in function {}\n", name))?; - } - f.write_fmt(format_args!("{:#?}", self.kind))?; - Ok(()) - } +#[derive(PartialEq, Debug)] +pub struct TypeAndSpan { + pub ty: Type, + pub span: Span, } -#[derive(Default)] -pub struct TypeErrorBuilder { - file: Option, - module: Option, - function: Option, - kind: Option, -} - -impl TypeError { - pub fn builder() -> TypeErrorBuilder { - TypeErrorBuilder::default() - } -} - -impl TypeErrorBuilder { - pub fn context(mut self, ctx: &TypingContext) -> Self { - self.file = ctx.file.clone(); - self.module = Some(ctx.module.clone()); - self.function = ctx.function.clone(); - self - } - - pub fn kind(mut self, kind: TypeErrorKind) -> Self { - self.kind = Some(kind); - self - } - - pub fn build(self) -> TypeError { - TypeError { - file: self.file, - module: self.module.expect("TypeError builder is missing module"), - function: self.function, - kind: self.kind.expect("TypeError builder is missing kind"), - } - } +#[derive(PartialEq, Debug)] +pub struct BinOpAndSpan { + pub op: BinaryOperator, + pub span: Span, } #[derive(Debug, PartialEq)] pub enum TypeErrorKind { InvalidBinaryOperator { - operator: BinaryOperator, - lht: Type, - rht: Type, + operator: BinOpAndSpan, + lhs: TypeAndSpan, + rhs: TypeAndSpan, }, BlockTypeDoesNotMatchFunctionType { block_type: Type, - function_type: Type, }, ReturnTypeDoesNotMatchFunctionType { - function_type: Type, - return_type: Type, + return_expr: Option, + return_stmt: TypeAndSpan, }, UnknownIdentifier { identifier: String, @@ -86,7 +47,6 @@ pub enum TypeErrorKind { }, AssignUndeclared, VariableRedeclaration, - ReturnStatementsMismatch, UnknownFunctionCalled(Identifier), WrongFunctionArguments, ConditionIsNotBool, @@ -96,3 +56,132 @@ pub enum TypeErrorKind { inner: Type, }, } + +impl TypeError { + pub fn to_report(&self, ast: &crate::ast::Module) -> Report { + let mut colors = ColorGenerator::new(); + let c0 = colors.next(); + let c1 = colors.next(); + colors.next(); + let c2 = colors.next(); + + match &self.kind { + TypeErrorKind::InvalidBinaryOperator { operator, lhs, rhs } => { + Report::build(ReportKind::Error, 0u32, 0) + .with_message(format!( + "Invalid binary operation {} between {} and {}", + operator.op.to_string().fg(c0), + lhs.ty.to_string().fg(c1), + rhs.ty.to_string().fg(c2), + )) + .with_labels([ + Label::new(operator.span).with_color(c0), + Label::new(lhs.span) + .with_message(format!("This has type {}", lhs.ty.to_string().fg(c1))) + .with_color(c1) + .with_order(2), + Label::new(rhs.span) + .with_message(format!("This has type {}", rhs.ty.to_string().fg(c2))) + .with_color(c2) + .with_order(1), + ]) + .finish() + } + + TypeErrorKind::BlockTypeDoesNotMatchFunctionType { block_type } => { + let function = ast + .functions + .iter() + .find(|f| f.name == *self.function.as_ref().unwrap()) + .unwrap(); + + let block_color = c0; + let signature_color = c1; + + let span = function.body.value.as_ref().unwrap().span; + + let report = Report::build(ReportKind::Error, 0u32, 0) + .with_message("Function body does not match the signature") + .with_labels([ + Label::new(function.body.span.unwrap()) + .with_message("In this function's body") + .with_color(c2), + Label::new(span) + .with_message(format!( + "Returned expression has type {} but the function should return {}", + block_type.to_string().fg(block_color), + function + .return_type + .as_ref() + .unwrap_or(&Type::Unit) + .to_string() + .fg(signature_color) + )) + .with_color(block_color), + ]); + + let report = + report.with_note("The last expression of a function's body is returned"); + + let report = if function.return_type.is_none() { + report.with_help( + "You may need to add the return type to the function's signature", + ) + } else { + report + }; + + report.finish() + } + + TypeErrorKind::ReturnTypeDoesNotMatchFunctionType { + return_expr, + return_stmt, + } => { + let function = ast + .functions + .iter() + .find(|f| f.name == *self.function.as_ref().unwrap()) + .unwrap(); + + let is_bare_return = return_expr.is_none(); + + let report = Report::build(ReportKind::Error, *return_stmt.span.source(), 0) + .with_message("Return type does not match the function's signature") + .with_label( + Label::new(return_expr.as_ref().unwrap_or(return_stmt).span) + .with_color(c1) + .with_message(if is_bare_return { + format!("Bare return has type {}", Type::Unit.to_string().fg(c1)) + } else { + format!( + "This expression has type {}", + return_stmt.ty.to_string().fg(c1) + ) + }), + ); + + let report = if let Some(ret_ty_span) = function.return_type_span { + report.with_label(Label::new(ret_ty_span).with_color(c0).with_message(format!( + "The signature shows {}", + function.return_type.as_ref().unwrap().to_string().fg(c0) + ))) + } else { + report + }; + + let report = if function.return_type.is_none() { + report.with_help( + "You may need to add the return type to the function's signature", + ) + } else { + report + }; + + report.finish() + } + + _ => todo!(), + } + } +} diff --git a/src/typing/mod.rs b/src/typing/mod.rs index 14acbf1..5fce74a 100644 --- a/src/typing/mod.rs +++ b/src/typing/mod.rs @@ -1,11 +1,14 @@ use std::collections::HashMap; use std::fmt::Display; +use BinaryExpression; +use ReturnStatement; + use crate::ast::ModulePath; use crate::ast::*; mod error; -use crate::typing::error::{TypeError, TypeErrorKind}; +use crate::typing::error::{TypeAndSpan, TypeError, TypeErrorKind}; #[cfg(test)] mod tests; @@ -62,11 +65,11 @@ impl From<&str> for Type { #[derive(Debug, PartialEq, Clone)] pub struct Signature(Vec, Type); -impl Into for Signature { - fn into(self) -> Type { +impl From for Type { + fn from(val: Signature) -> Self { Type::Function { - params: self.0, - returns: Box::new(self.1), + params: val.0, + returns: Box::new(val.1), } } } @@ -79,12 +82,13 @@ impl FunctionDefinition { } } +#[derive(Debug, PartialEq)] pub struct CheckedModule(pub Module); impl Module { - pub fn type_check(mut self) -> Result { + pub fn type_check(&mut self) -> Result<(), Vec> { let mut ctx = TypingContext::new(self.path.clone()); - ctx.file = self.file.clone(); + ctx.file.clone_from(&self.file); // Register all function signatures for func in &self.functions { @@ -95,13 +99,21 @@ impl Module { // TODO: add signatures of imported functions (even if they have not been checked) + let mut errors = Vec::new(); + // Type-check the function bodies and complete all type placeholders for func in &mut self.functions { - func.typ(&mut ctx)?; + if let Err(e) = func.typ(&mut ctx) { + errors.push(e); + }; ctx.variables.clear(); } - Ok(CheckedModule(self)) + if errors.is_empty() { + Ok(()) + } else { + Err(errors) + } } } @@ -128,6 +140,15 @@ impl TypingContext { variables: Default::default(), } } + + pub fn make_error(&self, kind: TypeErrorKind) -> TypeError { + TypeError { + kind, + file: self.file.clone(), + module: self.module.clone(), + function: self.function.clone(), + } + } } /// Trait for nodes which have a deducible type. @@ -148,40 +169,13 @@ impl TypeCheck for FunctionDefinition { let body_type = self.body.typ(ctx)?; - // If the return type is not specified, it is unit. - if self.return_type.is_none() { - self.return_type = Some(Type::Unit) - } - // Check coherence with the body's type. - if *self.return_type.as_ref().unwrap() != body_type { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::BlockTypeDoesNotMatchFunctionType { + if *self.return_type.as_ref().unwrap_or(&Type::Unit) != body_type { + return Err( + ctx.make_error(TypeErrorKind::BlockTypeDoesNotMatchFunctionType { block_type: body_type.clone(), - function_type: self.return_type.as_ref().unwrap().clone(), - }) - .build()); - } - - // Check coherence with return statements. - - for statement in &mut self.body.statements { - if let Statement::ReturnStatement(value) = statement { - let ret_type = match value { - Some(expr) => expr.typ(ctx)?, - None => Type::Unit, - }; - if ret_type != *self.return_type.as_ref().unwrap() { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::ReturnTypeDoesNotMatchFunctionType { - function_type: self.return_type.as_ref().unwrap().clone(), - return_type: ret_type.clone(), - }) - .build()); - } - } + }), + ); } Ok(self.return_type.clone().unwrap()) @@ -190,79 +184,65 @@ impl TypeCheck for FunctionDefinition { impl TypeCheck for Block { fn typ(&mut self, ctx: &mut TypingContext) -> Result { - let mut return_typ: Option = None; - // Check declarations and assignments. for statement in &mut self.statements { match statement { - Statement::DeclareStatement(ident, expr) => { + Statement::DeclareStatement { + lhs: ident, + rhs: expr, + .. + } => { let typ = expr.typ(ctx)?; if let Some(_typ) = ctx.variables.insert(ident.clone(), typ.clone()) { // TODO: Shadowing? (illegal for now) - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::VariableRedeclaration) - .build()); + return Err(ctx.make_error(TypeErrorKind::VariableRedeclaration)); } } - Statement::AssignStatement(ident, expr) => { + Statement::AssignStatement { + lhs: ident, + rhs: expr, + .. + } => { let rhs_typ = expr.typ(ctx)?; let Some(lhs_typ) = ctx.variables.get(ident) else { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::AssignUndeclared) - .build()); + return Err(ctx.make_error(TypeErrorKind::AssignUndeclared)); }; // Ensure same type on both sides. if rhs_typ != *lhs_typ { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::AssignmentMismatch { - lht: lhs_typ.clone(), - rht: rhs_typ.clone(), - }) - .build()); + return Err(ctx.make_error(TypeErrorKind::AssignmentMismatch { + lht: lhs_typ.clone(), + rht: rhs_typ.clone(), + })); } } - Statement::ReturnStatement(maybe_expr) => { - let expr_typ = if let Some(expr) = maybe_expr { - expr.typ(ctx)? - } else { - Type::Unit - }; - if let Some(typ) = &return_typ { - if expr_typ != *typ { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::ReturnStatementsMismatch) - .build()); - } - } else { - return_typ = Some(expr_typ.clone()); - } + Statement::ReturnStatement(return_stmt) => { + return_stmt.typ(ctx)?; } - Statement::CallStatement(call) => { + Statement::CallStatement { call, span: _ } => { call.typ(ctx)?; } - Statement::UseStatement(_path) => { + Statement::UseStatement { .. } => { // TODO: import the signatures (and types) + todo!() } - Statement::IfStatement(cond, block) => { + Statement::IfStatement { + condition: cond, + then_block: block, + .. + } => { if cond.typ(ctx)? != Type::Bool { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::ConditionIsNotBool) - .build()); + return Err(ctx.make_error(TypeErrorKind::ConditionIsNotBool)); } block.typ(ctx)?; } - Statement::WhileStatement(cond, block) => { + Statement::WhileStatement { + condition: cond, + loop_block: block, + span: _, + } => { if cond.typ(ctx)? != Type::Bool { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::ConditionIsNotBool) - .build()); + return Err(ctx.make_error(TypeErrorKind::ConditionIsNotBool)); } block.typ(ctx)?; } @@ -277,25 +257,19 @@ impl TypeCheck for Block { self.typ = Type::Unit; Ok(Type::Unit) } - - // TODO/FIXME: find a way to return `return_typ` so that the - // top-level block (the function) can check if this return type - // (and eventually those from other block) matches the type of - // the function. } } impl TypeCheck for Call { fn typ(&mut self, ctx: &mut TypingContext) -> Result { - match &mut *self.callee { - Expr::Identifier { name, typ } => { + match &mut self.callee.expr { + Expr::Identifier { name, typ, .. } => { let signature = match ctx.functions.get(name) { Some(sgn) => sgn.clone(), None => { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::UnknownFunctionCalled(name.clone())) - .build()) + return Err( + ctx.make_error(TypeErrorKind::UnknownFunctionCalled(name.clone())) + ) } }; @@ -315,10 +289,7 @@ impl TypeCheck for Call { if args_types == *params_types { Ok(self.typ.clone()) } else { - Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::WrongFunctionArguments) - .build()) + Err(ctx.make_error(TypeErrorKind::WrongFunctionArguments)) } } _ => unimplemented!("cannot call on expression other than identifier"), @@ -329,36 +300,40 @@ impl TypeCheck for Call { impl TypeCheck for Expr { fn typ(&mut self, ctx: &mut TypingContext) -> Result { match self { - Expr::Identifier { name, typ } => { + Expr::Identifier { name, typ, .. } => { if let Some(ty) = ctx.variables.get(name) { *typ = ty.clone(); Ok(typ.clone()) } else { - Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::UnknownIdentifier { - identifier: name.clone(), - }) - .build()) + Err(ctx.make_error(TypeErrorKind::UnknownIdentifier { + identifier: name.clone(), + })) } } - Expr::BooleanLiteral(_) => Ok(Type::Bool), - Expr::IntegerLiteral(_) => Ok(Type::Int), - Expr::FloatLiteral(_) => Ok(Type::Float), - Expr::UnaryExpression { op, inner } => { + Expr::BooleanLiteral(..) => Ok(Type::Bool), + Expr::IntegerLiteral(..) => Ok(Type::Int), + Expr::FloatLiteral(..) => Ok(Type::Float), + Expr::UnaryExpression { op, inner, .. } => { let inner_type = &inner.typ(ctx)?; match (&op, inner_type) { (UnaryOperator::Not, Type::Bool) => Ok(Type::Bool), - _ => Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::InvalidUnaryOperator { - operator: *op, - inner: inner_type.clone(), - }) - .build()), + _ => Err(ctx.make_error(TypeErrorKind::InvalidUnaryOperator { + operator: *op, + inner: inner_type.clone(), + })), } } - Expr::BinaryExpression { lhs, op, rhs, typ } => { + Expr::BinaryExpression(BinaryExpression { + lhs, + op, + rhs, + typ, + op_span, + }) => { + let operator = error::BinOpAndSpan { + op: op.clone(), + span: *op_span, + }; let ty = match op { BinaryOperator::Add | BinaryOperator::Sub @@ -368,32 +343,39 @@ impl TypeCheck for Expr { | BinaryOperator::Or => { let left_type = &lhs.typ(ctx)?; let right_type = &rhs.typ(ctx)?; + match (left_type, right_type) { (Type::Int, Type::Int) => Ok(Type::Int), (Type::Float, Type::Float) => Ok(Type::Float), (Type::Bool, Type::Bool) => Ok(Type::Bool), - (_, _) => Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::InvalidBinaryOperator { - operator: op.clone(), - lht: left_type.clone(), - rht: right_type.clone(), - }) - .build()), + (_, _) => Err(ctx.make_error(TypeErrorKind::InvalidBinaryOperator { + operator, + lhs: TypeAndSpan { + ty: left_type.clone(), + span: lhs.span, + }, + rhs: TypeAndSpan { + ty: right_type.clone(), + span: rhs.span, + }, + })), } } BinaryOperator::Equal | BinaryOperator::NotEqual => { let lhs_type = lhs.typ(ctx)?; let rhs_type = rhs.typ(ctx)?; if lhs_type != rhs_type { - return Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::InvalidBinaryOperator { - operator: op.clone(), - lht: lhs_type.clone(), - rht: rhs_type.clone(), - }) - .build()); + return Err(ctx.make_error(TypeErrorKind::InvalidBinaryOperator { + operator, + lhs: TypeAndSpan { + ty: lhs_type.clone(), + span: lhs.span, + }, + rhs: TypeAndSpan { + ty: rhs_type.clone(), + span: rhs.span, + }, + })); } Ok(Type::Bool) } @@ -402,14 +384,17 @@ impl TypeCheck for Expr { let rhs_type = lhs.typ(ctx)?; match (&lhs_type, &rhs_type) { (Type::Int, Type::Int) => Ok(Type::Int), - _ => Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::InvalidBinaryOperator { - operator: op.clone(), - lht: lhs_type.clone(), - rht: rhs_type.clone(), - }) - .build()), + _ => Err(ctx.make_error(TypeErrorKind::InvalidBinaryOperator { + operator, + lhs: TypeAndSpan { + ty: lhs_type.clone(), + span: lhs.span, + }, + rhs: TypeAndSpan { + ty: rhs_type.clone(), + span: rhs.span, + }, + })), } } }; @@ -427,18 +412,12 @@ impl TypeCheck for Expr { typ, } => { if cond.typ(ctx)? != Type::Bool { - Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::ConditionIsNotBool) - .build()) + Err(ctx.make_error(TypeErrorKind::ConditionIsNotBool)) } else { let then_body_type = then_body.typ(ctx)?; let else_type = else_body.typ(ctx)?; if then_body_type != else_type { - Err(TypeError::builder() - .context(ctx) - .kind(TypeErrorKind::IfElseMismatch) - .build()) + Err(ctx.make_error(TypeErrorKind::IfElseMismatch)) } else { // XXX: opt: return ref to avoid cloning *typ = then_body_type.clone(); @@ -449,3 +428,39 @@ impl TypeCheck for Expr { } } } + +impl TypeCheck for ReturnStatement { + fn typ(&mut self, ctx: &mut TypingContext) -> Result { + let ty = if let Some(expr) = &mut self.expr { + expr.typ(ctx)? + } else { + Type::Unit + }; + + // Check if the returned type is coherent with the function's signature + let func_type = &ctx.functions.get(ctx.function.as_ref().unwrap()).unwrap().1; + if ty != *func_type { + return Err( + ctx.make_error(TypeErrorKind::ReturnTypeDoesNotMatchFunctionType { + return_expr: self.expr.as_ref().map(|e| TypeAndSpan { + ty: ty.clone(), + span: e.span, + }), + return_stmt: TypeAndSpan { + ty: ty.clone(), + span: self.span, + }, + }), + ); + }; + + Ok(ty) + } +} + +impl TypeCheck for SExpr { + #[inline] + fn typ(&mut self, ctx: &mut TypingContext) -> Result { + self.expr.typ(ctx) + } +} diff --git a/src/typing/tests.rs b/src/typing/tests.rs index e579554..79fa13e 100644 --- a/src/typing/tests.rs +++ b/src/typing/tests.rs @@ -1,33 +1,40 @@ use crate::{ ast::ModulePath, - parsing::parse_as_module, - typing::{ - error::{TypeError, TypeErrorKind}, - BinaryOperator, Type, - }, + parsing::{DefaultParser, Parser}, + typing::error::*, + typing::*, }; +#[cfg(test)] +use pretty_assertions::assert_eq; + #[test] fn addition_int_and_float() { let source = "fn add(a: int, b: float) int { a + b }"; - let mut ast = parse_as_module(source, ModulePath::default()).unwrap(); + let mut ast = DefaultParser::default() + .parse_as_module(source, ModulePath::default(), 0) + .unwrap(); let res = ast.type_check(); - assert!(res.is_err_and(|e| e.kind - == TypeErrorKind::InvalidBinaryOperator { - operator: BinaryOperator::Add, - lht: Type::Int, - rht: Type::Float - })); + assert!(res.is_err_and(|errors| errors.len() == 1 + && matches!(errors[0].kind, TypeErrorKind::InvalidBinaryOperator { .. }))); } #[test] fn return_int_instead_of_float() { let source = "fn add(a: int, b: int) float { a + b }"; - let mut ast = parse_as_module(source, ModulePath::default()).unwrap(); + let mut ast = DefaultParser::default() + .parse_as_module(source, ModulePath::default(), 0) + .unwrap(); let res = ast.type_check(); - assert!(res.is_err_and(|e| e.kind - == TypeErrorKind::BlockTypeDoesNotMatchFunctionType { - block_type: Type::Int, - function_type: Type::Float - })); + assert_eq!( + res, + Err(vec![TypeError { + file: None, + module: ModulePath::default(), + function: Some("add".to_string()), + kind: TypeErrorKind::BlockTypeDoesNotMatchFunctionType { + block_type: Type::Int, + } + }]) + ); }