add pretty diagnostics

This commit is contained in:
Romain Paquet 2024-07-03 19:59:12 +02:00
parent e157bf036a
commit f415c4abbe
12 changed files with 1037 additions and 603 deletions

View file

@ -5,10 +5,15 @@ edition = "2021"
[dependencies] [dependencies]
clap = { version = "4.5.7", features = ["derive"] } clap = { version = "4.5.7", features = ["derive"] }
cranelift = "0.108.1" cranelift = "0.109.0"
cranelift-jit = "0.108.1" cranelift-jit = "0.109.0"
cranelift-module = "0.108.1" cranelift-module = "0.109.0"
cranelift-native = "0.108.1" cranelift-native = "0.109.0"
lazy_static = "1.4.0" lazy_static = "1.4.0"
pest = "2.7.4" pest = "2.7.4"
pest_derive = "2.7.4" pest_derive = "2.7.4"
ariadne = "0.4.1"
anyhow = "1.0.86"
[dev-dependencies]
pretty_assertions = "1.4.0"

View file

@ -1,17 +1,27 @@
use crate::ast::*; use crate::ast::*;
use crate::typing::Type; use crate::typing::Type;
/// An expression together with the source span attached to it.
///
/// Wraps an [`Expr`] so that a location travels with the node
/// (presumably so diagnostics can point at it — confirm against the
/// error-reporting code).
#[derive(Debug, PartialEq)]
pub struct SExpr {
/// The expression node itself.
pub expr: Expr,
/// Span associated with `expr`.
pub span: Span,
}
/// Operands, operator and type of a binary operation.
///
/// This is the payload carried by `Expr::BinaryExpression`.
#[derive(Debug, PartialEq)]
pub struct BinaryExpression {
/// Left-hand operand.
pub lhs: Box<SExpr>,
/// The binary operator applied to `lhs` and `rhs`.
pub op: BinaryOperator,
/// Span recorded for the operator — presumably the operator token only,
/// separate from the operand spans; TODO confirm against the parser.
pub op_span: Span,
/// Right-hand operand.
pub rhs: Box<SExpr>,
/// Type of the whole expression; this is the value `Expr::ty` returns
/// for this variant.
pub typ: Type,
}
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum Expr { pub enum Expr {
BinaryExpression { BinaryExpression(BinaryExpression),
lhs: Box<Expr>,
op: BinaryOperator,
rhs: Box<Expr>,
typ: Type,
},
UnaryExpression { UnaryExpression {
op: UnaryOperator, op: UnaryOperator,
inner: Box<Expr>, inner: Box<SExpr>,
}, },
Identifier { Identifier {
name: String, name: String,
@ -21,9 +31,9 @@ pub enum Expr {
Block(Box<Block>), Block(Box<Block>),
/// Last field is either Expr::Block or Expr::IfExpr /// Last field is either Expr::Block or Expr::IfExpr
IfExpr { IfExpr {
cond: Box<Expr>, cond: Box<SExpr>,
then_body: Box<Block>, then_body: Box<Block>,
else_body: Box<Expr>, else_body: Box<SExpr>,
typ: Type, typ: Type,
}, },
// Literals // Literals
@ -45,22 +55,12 @@ impl Block {
impl Expr { impl Expr {
pub fn ty(&self) -> Type { pub fn ty(&self) -> Type {
match self { match self {
Expr::BinaryExpression { Expr::BinaryExpression(BinaryExpression { typ, .. }) => typ.clone(),
lhs: _, Expr::UnaryExpression { inner, .. } => inner.ty(), // XXX: problems will arise here
op: _, Expr::Identifier { typ, .. } => typ.clone(),
rhs: _,
typ,
} => typ.clone(),
Expr::UnaryExpression { op: _, inner } => inner.ty(), // XXX: problems will arise here
Expr::Identifier { name: _, typ } => typ.clone(),
Expr::Call(call) => call.typ.clone(), Expr::Call(call) => call.typ.clone(),
Expr::Block(block) => block.typ.clone(), Expr::Block(block) => block.typ.clone(),
Expr::IfExpr { Expr::IfExpr { typ, .. } => typ.clone(),
cond: _,
then_body: _,
else_body: _,
typ,
} => typ.clone(),
Expr::UnitLiteral => Type::Unit, Expr::UnitLiteral => Type::Unit,
Expr::BooleanLiteral(_) => Type::Bool, Expr::BooleanLiteral(_) => Type::Bool,
Expr::IntegerLiteral(_) => Type::Int, Expr::IntegerLiteral(_) => Type::Int,
@ -69,3 +69,10 @@ impl Expr {
} }
} }
} }
impl SExpr {
#[inline]
pub fn ty(&self) -> Type {
self.expr.ty()
}
}

View file

@ -1,9 +1,11 @@
pub mod expr; pub mod expr;
pub use expr::Expr; pub use expr::{BinaryExpression, Expr, SExpr};
use crate::typing::Type; use crate::typing::Type;
use std::path::Path;
use ariadne;
use std::{fmt::Display, path::Path};
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub enum BinaryOperator { pub enum BinaryOperator {
@ -20,6 +22,22 @@ pub enum BinaryOperator {
NotEqual, NotEqual,
} }
impl Display for BinaryOperator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(match self {
BinaryOperator::And => "&&",
BinaryOperator::Or => "||",
BinaryOperator::Add => "+",
BinaryOperator::Sub => "-",
BinaryOperator::Mul => "*",
BinaryOperator::Div => "/",
BinaryOperator::Modulo => "%",
BinaryOperator::Equal => "==",
BinaryOperator::NotEqual => "!=",
})
}
}
#[derive(Debug, PartialEq, Copy, Clone)] #[derive(Debug, PartialEq, Copy, Clone)]
pub enum UnaryOperator { pub enum UnaryOperator {
Not, Not,
@ -27,12 +45,32 @@ pub enum UnaryOperator {
pub type Identifier = String; pub type Identifier = String;
#[derive(Debug, PartialEq)] pub type SourceId = u32;
pub struct Location {
pub line_col: (usize, usize), #[derive(Debug, PartialEq, Clone, Copy)]
pub struct Span {
pub source: SourceId,
pub start: usize,
pub end: usize,
} }
#[derive(Debug, PartialEq, Clone, Default)] impl ariadne::Span for Span {
type SourceId = SourceId;
fn source(&self) -> &Self::SourceId {
&self.source
}
fn start(&self) -> usize {
self.start
}
fn end(&self) -> usize {
self.end
}
}
#[derive(Debug, PartialEq, Clone, Default, Eq, Hash)]
pub struct ModulePath { pub struct ModulePath {
components: Vec<String>, components: Vec<String>,
} }
@ -65,7 +103,7 @@ impl From<&Path> for ModulePath {
.map(|component| match component { .map(|component| match component {
std::path::Component::Normal(n) => { std::path::Component::Normal(n) => {
if meta.is_file() { if meta.is_file() {
n.to_str().unwrap().split(".").nth(0).unwrap().to_string() n.to_str().unwrap().split('.').nth(0).unwrap().to_string()
} else if meta.is_dir() { } else if meta.is_dir() {
n.to_str().unwrap().to_string() n.to_str().unwrap().to_string()
} else { } else {
@ -91,22 +129,51 @@ impl From<&str> for ModulePath {
#[derive(Eq, PartialEq, Debug)] #[derive(Eq, PartialEq, Debug)]
pub struct Import(pub String); pub struct Import(pub String);
/// A `return` statement; payload of `Statement::ReturnStatement`.
#[derive(Debug, PartialEq)]
pub struct ReturnStatement {
/// Returned expression, or `None` when the statement carries no expression.
pub expr: Option<SExpr>,
/// Span recorded for the statement.
pub span: Span,
}
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum Statement { pub enum Statement {
DeclareStatement(Identifier, Box<Expr>), DeclareStatement {
AssignStatement(Identifier, Box<Expr>), lhs: Identifier,
ReturnStatement(Option<Expr>), rhs: Box<SExpr>,
CallStatement(Box<Call>), span: Span,
UseStatement(Box<Import>), },
IfStatement(Box<Expr>, Box<Block>), AssignStatement {
WhileStatement(Box<Expr>, Box<Block>), lhs: Identifier,
rhs: Box<SExpr>,
span: Span,
},
ReturnStatement(ReturnStatement),
CallStatement {
call: Box<Call>,
span: Span,
},
UseStatement {
import: Box<Import>,
span: Span,
},
IfStatement {
condition: Box<SExpr>,
then_block: Box<Block>,
span: Span,
},
WhileStatement {
condition: Box<SExpr>,
loop_block: Box<Block>,
span: Span,
},
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub struct Block { pub struct Block {
pub statements: Vec<Statement>, pub statements: Vec<Statement>,
pub value: Option<Expr>, pub value: Option<SExpr>,
pub typ: Type, pub typ: Type,
pub span: Option<Span>,
} }
impl Block { impl Block {
@ -115,6 +182,7 @@ impl Block {
typ: Type::Unit, typ: Type::Unit,
statements: Vec::with_capacity(0), statements: Vec::with_capacity(0),
value: None, value: None,
span: None,
} }
} }
} }
@ -129,8 +197,9 @@ pub struct FunctionDefinition {
pub name: Identifier, pub name: Identifier,
pub parameters: Vec<Parameter>, pub parameters: Vec<Parameter>,
pub return_type: Option<Type>, pub return_type: Option<Type>,
pub return_type_span: Option<Span>,
pub body: Box<Block>, pub body: Box<Block>,
pub location: Location, pub span: Span,
} }
#[derive(Debug, PartialEq, Default)] #[derive(Debug, PartialEq, Default)]
@ -159,8 +228,8 @@ impl Module {
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub struct Call { pub struct Call {
pub callee: Box<Expr>, pub callee: Box<SExpr>,
pub args: Vec<Expr>, pub args: Vec<SExpr>,
pub typ: Type, pub typ: Type,
} }

View file

@ -1,15 +1,17 @@
use crate::{ use crate::{
ast::{ ast::{
self, BinaryOperator, ModulePath, UnaryOperator, self, expr::BinaryExpression, BinaryOperator, Expr, FunctionDefinition, ModulePath,
{expr::Expr, FunctionDefinition, Statement}, ReturnStatement, SourceId, Statement, UnaryOperator,
}, },
parsing, parsing::{DefaultParser, Parser},
typing::{CheckedModule, Type}, typing::Type,
SourceCache,
}; };
use ariadne::Cache as _;
use cranelift::{codegen::ir::UserFuncName, prelude::*}; use cranelift::{codegen::ir::UserFuncName, prelude::*};
use cranelift_jit::{JITBuilder, JITModule}; use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{DataDescription, FuncId, FuncOrDataId, Linkage, Module}; use cranelift_module::{DataDescription, FuncId, FuncOrDataId, Linkage, Module};
use std::{collections::HashMap, fs, ops::Deref}; use std::collections::HashMap;
/// The basic JIT class. /// The basic JIT class.
pub struct JIT { pub struct JIT {
@ -30,6 +32,9 @@ pub struct JIT {
/// Whether to print CLIR during compilation /// Whether to print CLIR during compilation
pub dump_clir: bool, pub dump_clir: bool,
/// Parser used to build the AST
pub parser: DefaultParser,
} }
impl Default for JIT { impl Default for JIT {
@ -59,18 +64,34 @@ impl Default for JIT {
data_desc: DataDescription::new(), data_desc: DataDescription::new(),
module, module,
dump_clir: false, dump_clir: false,
parser: DefaultParser::default(),
} }
} }
} }
impl JIT { impl JIT {
/// Compile source code into machine code. /// Compile source code into machine code.
pub fn compile(&mut self, input: &str, namespace: ModulePath) -> Result<*const u8, String> { pub fn compile(
&mut self,
input: &str,
namespace: ModulePath,
id: SourceId,
) -> Result<*const u8, String> {
let mut source_cache = (0u32, ariadne::Source::from(input));
// Parse the source code into an AST // Parse the source code into an AST
let ast = parsing::parse_as_module(input, namespace) let mut ast = self
.map_err(|x| format!("Parsing error: {x}"))? .parser
.type_check() .parse_as_module(input, namespace, id)
.map_err(|x| format!("Typing error: {x}"))?; .map_err(|x| format!("Parsing error: {x}"))?;
ast.type_check()
.map_err(|errors| {
errors
.iter()
.for_each(|e| e.to_report(&ast).eprint(&mut source_cache).unwrap());
})
.unwrap();
// Translate the AST into Cranelift IR // Translate the AST into Cranelift IR
self.translate(&ast)?; self.translate(&ast)?;
@ -89,20 +110,24 @@ impl JIT {
} }
} }
pub fn compile_file(&mut self, path: &str) -> Result<*const u8, String> { pub fn compile_file(
&mut self,
path: &str,
id: SourceId,
source_cache: &mut SourceCache,
) -> Result<*const u8, String> {
self.compile( self.compile(
fs::read_to_string(path) source_cache
.map_err(|x| format!("Cannot open {}: {}", path, x))? .fetch(&id)
.as_str(), .map(|s| s.text())
.map_err(|e| format!("{:?}", e))?,
AsRef::<std::path::Path>::as_ref(path).into(), AsRef::<std::path::Path>::as_ref(path).into(),
id,
) )
} }
/// Translate language AST into Cranelift IR. /// Translate language AST into Cranelift IR.
fn translate(&mut self, ast: &CheckedModule) -> Result<(), String> { fn translate(&mut self, ast: &ast::Module) -> Result<(), String> {
// Dump contract-holding wrapper type
let ast = &ast.0;
let mut signatures: Vec<Signature> = Vec::with_capacity(ast.functions.len()); let mut signatures: Vec<Signature> = Vec::with_capacity(ast.functions.len());
let mut func_ids: Vec<FuncId> = Vec::with_capacity(ast.functions.len()); let mut func_ids: Vec<FuncId> = Vec::with_capacity(ast.functions.len());
@ -199,7 +224,7 @@ impl JIT {
// Emit the final return instruction. // Emit the final return instruction.
if let Some(return_expr) = &function.body.value { if let Some(return_expr) = &function.body.value {
let return_value = translator.translate_expr(&return_expr); let return_value = translator.translate_expr(&return_expr.expr);
translator.builder.ins().return_(&[return_value]); translator.builder.ins().return_(&[return_value]);
} else { } else {
translator.builder.ins().return_(&[]); translator.builder.ins().return_(&[]);
@ -234,18 +259,26 @@ struct FunctionTranslator<'a> {
impl<'a> FunctionTranslator<'a> { impl<'a> FunctionTranslator<'a> {
fn translate_statement(&mut self, stmt: &Statement) -> Option<Value> { fn translate_statement(&mut self, stmt: &Statement) -> Option<Value> {
match stmt { match stmt {
Statement::AssignStatement(name, expr) => { Statement::AssignStatement {
lhs: name,
rhs: expr,
..
} => {
// `def_var` is used to write the value of a variable. Note that // `def_var` is used to write the value of a variable. Note that
// variables can have multiple definitions. Cranelift will // variables can have multiple definitions. Cranelift will
// convert them into SSA form for itself automatically. // convert them into SSA form for itself automatically.
let new_value = self.translate_expr(expr); let new_value = self.translate_expr(&expr.expr);
let variable = self.variables.get(name).unwrap(); let variable = self.variables.get(name).unwrap();
self.builder.def_var(*variable, new_value); self.builder.def_var(*variable, new_value);
Some(new_value) Some(new_value)
} }
Statement::DeclareStatement(name, expr) => { Statement::DeclareStatement {
let value = self.translate_expr(expr); lhs: name,
rhs: expr,
..
} => {
let value = self.translate_expr(&expr.expr);
let variable = Variable::from_u32(self.variables.len() as u32); let variable = Variable::from_u32(self.variables.len() as u32);
self.builder self.builder
.declare_var(variable, self.translate_type(&expr.ty())); .declare_var(variable, self.translate_type(&expr.ty()));
@ -254,10 +287,12 @@ impl<'a> FunctionTranslator<'a> {
Some(value) Some(value)
} }
Statement::ReturnStatement(maybe_expr) => { Statement::ReturnStatement(ReturnStatement {
expr: maybe_expr, ..
}) => {
// TODO: investigate tail call // TODO: investigate tail call
let values = if let Some(expr) = maybe_expr { let values = if let Some(expr) = maybe_expr {
vec![self.translate_expr(expr)] vec![self.translate_expr(&expr.expr)]
} else { } else {
// XXX: urgh // XXX: urgh
Vec::with_capacity(0) Vec::with_capacity(0)
@ -269,12 +304,16 @@ impl<'a> FunctionTranslator<'a> {
None None
} }
Statement::CallStatement(call) => self.translate_call(call), Statement::CallStatement { call, .. } => self.translate_call(call),
Statement::UseStatement(_) => todo!(), Statement::UseStatement { .. } => todo!(),
Statement::IfStatement(cond, then_body) => { Statement::IfStatement {
let condition_value = self.translate_expr(cond); condition: cond,
then_block: then_body,
..
} => {
let condition_value = self.translate_expr(&cond.expr);
let then_block = self.builder.create_block(); let then_block = self.builder.create_block();
let merge_block = self.builder.create_block(); let merge_block = self.builder.create_block();
@ -294,7 +333,7 @@ impl<'a> FunctionTranslator<'a> {
None None
} }
Statement::WhileStatement(_, _) => todo!(), Statement::WhileStatement { .. } => todo!(),
} }
} }
@ -302,11 +341,11 @@ impl<'a> FunctionTranslator<'a> {
match expr { match expr {
Expr::UnitLiteral => unreachable!(), Expr::UnitLiteral => unreachable!(),
Expr::BooleanLiteral(imm) => self.builder.ins().iconst(types::I8, i64::from(*imm)), Expr::BooleanLiteral(imm, ..) => self.builder.ins().iconst(types::I8, i64::from(*imm)),
Expr::IntegerLiteral(imm) => self.builder.ins().iconst(types::I32, *imm), Expr::IntegerLiteral(imm, ..) => self.builder.ins().iconst(types::I32, *imm),
Expr::FloatLiteral(imm) => self.builder.ins().f64const(*imm), Expr::FloatLiteral(imm, ..) => self.builder.ins().f64const(*imm),
Expr::StringLiteral(s) => { Expr::StringLiteral(s, ..) => {
let id = self.module.declare_anonymous_data(false, false).unwrap(); let id = self.module.declare_anonymous_data(false, false).unwrap();
let bytes: Box<[u8]> = s.as_bytes().into(); let bytes: Box<[u8]> = s.as_bytes().into();
self.data_desc.define(bytes); self.data_desc.define(bytes);
@ -318,14 +357,9 @@ impl<'a> FunctionTranslator<'a> {
.global_value(self.module.isa().pointer_type(), gv) .global_value(self.module.isa().pointer_type(), gv)
} }
Expr::BinaryExpression { Expr::BinaryExpression(BinaryExpression { lhs, rhs, op, .. }) => {
lhs, let lhs_value = self.translate_expr(&lhs.expr);
op, let rhs_value = self.translate_expr(&rhs.expr);
rhs,
typ: _,
} => {
let lhs_value = self.translate_expr(lhs);
let rhs_value = self.translate_expr(rhs);
match (lhs.ty(), lhs.ty()) { match (lhs.ty(), lhs.ty()) {
(Type::Int, Type::Int) => match op { (Type::Int, Type::Int) => match op {
@ -361,8 +395,9 @@ impl<'a> FunctionTranslator<'a> {
then_body, then_body,
else_body, else_body,
typ, typ,
..
} => { } => {
let condition_value = self.translate_expr(cond); let condition_value = self.translate_expr(&cond.expr);
let then_block = self.builder.create_block(); let then_block = self.builder.create_block();
let else_block = self.builder.create_block(); let else_block = self.builder.create_block();
@ -384,11 +419,11 @@ impl<'a> FunctionTranslator<'a> {
self.builder.switch_to_block(then_block); self.builder.switch_to_block(then_block);
self.builder.seal_block(then_block); self.builder.seal_block(then_block);
for stmt in &then_body.statements { for stmt in &then_body.statements {
self.translate_statement(&stmt); self.translate_statement(stmt);
} }
let then_return_value = match &then_body.value { let then_return_value = match &then_body.value {
Some(val) => vec![self.translate_expr(val)], Some(val) => vec![self.translate_expr(&val.expr)],
None => Vec::with_capacity(0), None => Vec::with_capacity(0),
}; };
@ -399,9 +434,9 @@ impl<'a> FunctionTranslator<'a> {
self.builder.seal_block(else_block); self.builder.seal_block(else_block);
// XXX: the else can be just an expression: do we always need to // XXX: the else can be just an expression: do we always need to
// make a second branch in that case? Or leave it to cranelift? // make a second branch in that case? Or leave it to cranelift?
let else_return_value = match **else_body { let else_return_value = match else_body.expr {
Expr::UnitLiteral => Vec::with_capacity(0), Expr::UnitLiteral => Vec::with_capacity(0),
_ => vec![self.translate_expr(else_body)], _ => vec![self.translate_expr(&else_body.expr)],
}; };
// Jump to the merge block, passing it the block return value. // Jump to the merge block, passing it the block return value.
@ -420,8 +455,8 @@ impl<'a> FunctionTranslator<'a> {
phi phi
} }
Expr::UnaryExpression { op, inner } => { Expr::UnaryExpression { op, inner, .. } => {
let inner_value = self.translate_expr(inner); let inner_value = self.translate_expr(&inner.expr);
match op { match op {
// XXX: This should not be a literal translation // XXX: This should not be a literal translation
UnaryOperator::Not => { UnaryOperator::Not => {
@ -431,13 +466,13 @@ impl<'a> FunctionTranslator<'a> {
} }
} }
Expr::Identifier { name, typ: _ } => { Expr::Identifier { name, .. } => {
self.builder.use_var(*self.variables.get(name).unwrap()) self.builder.use_var(*self.variables.get(name).unwrap())
} }
Expr::Call(call) => self.translate_call(call).unwrap(), Expr::Call(call, ..) => self.translate_call(call).unwrap(),
Expr::Block(block) => self.translate_block(block).unwrap(), Expr::Block(block, ..) => self.translate_block(block).unwrap(),
} }
} }
@ -445,16 +480,16 @@ impl<'a> FunctionTranslator<'a> {
for stmt in &block.statements { for stmt in &block.statements {
self.translate_statement(stmt); self.translate_statement(stmt);
} }
if let Some(block_value) = &block.value {
Some(self.translate_expr(block_value)) block
} else { .value
None .as_ref()
} .map(|block_value| self.translate_expr(&block_value.expr))
} }
fn translate_call(&mut self, call: &ast::Call) -> Option<Value> { fn translate_call(&mut self, call: &ast::Call) -> Option<Value> {
match call.callee.deref() { match &call.callee.expr {
Expr::Identifier { name, typ: _ } => { Expr::Identifier { name, .. } => {
let func_ref = if let Some(func_or_data_id) = self.module.get_name(name.as_ref()) { let func_ref = if let Some(func_or_data_id) = self.module.get_name(name.as_ref()) {
if let FuncOrDataId::Func(func_id) = func_or_data_id { if let FuncOrDataId::Func(func_id) = func_or_data_id {
self.module.declare_func_in_func(func_id, self.builder.func) self.module.declare_func_in_func(func_id, self.builder.func)
@ -465,7 +500,11 @@ impl<'a> FunctionTranslator<'a> {
todo!() todo!()
}; };
let args: Vec<Value> = call.args.iter().map(|a| self.translate_expr(a)).collect(); let args: Vec<Value> = call
.args
.iter()
.map(|a| self.translate_expr(&a.expr))
.collect();
let call_inst = self.builder.ins().call(func_ref, &args); let call_inst = self.builder.ins().call(func_ref, &args);
let results = self.builder.inst_results(call_inst); let results = self.builder.inst_results(call_inst);

View file

@ -1,14 +1,20 @@
pub mod ast; pub mod ast;
pub mod jit; pub mod jit;
pub mod parsing; pub mod parsing;
pub mod source;
pub mod typing; pub mod typing;
use clap::{Parser, Subcommand}; use std::default::Default;
use std::path::PathBuf;
use clap::{Parser as ClapParser, Subcommand};
use crate::ast::Module; use crate::ast::Module;
use crate::parsing::Parser;
use crate::source::SourceCache;
/// Experimental compiler for lila /// Experimental compiler for lila
#[derive(Parser, Debug)] #[derive(ClapParser, Debug)]
#[command(author = "Romain P. <rpqt@rpqt.fr>")] #[command(author = "Romain P. <rpqt@rpqt.fr>")]
#[command(version, about, long_about = None)] #[command(version, about, long_about = None)]
struct Cli { struct Cli {
@ -51,21 +57,27 @@ enum Commands {
}, },
} }
fn parse(files: &Vec<String>) -> Vec<Module> { fn parse(files: &[String]) -> Vec<Module> {
let mut parser = parsing::DefaultParser::default();
let paths = files.iter().map(std::path::Path::new); let paths = files.iter().map(std::path::Path::new);
paths paths
.map(|path| match parsing::parse_file(&path) { .enumerate()
.map(|(i, path)| match parser.parse_file(path, i as u32) {
Ok(module) => module, Ok(module) => module,
Err(e) => panic!("Parsing error: {:#?}", e), Err(e) => panic!("Parsing error: {:#?}", e),
}) })
.collect() .collect()
} }
fn check(modules: &mut Vec<Module>) { fn check(modules: &mut Vec<Module>, source_cache: &mut SourceCache) {
while let Some(module) = modules.pop() { for module in modules {
if let Err(e) = module.type_check() { if let Err(errors) = module.type_check() {
eprintln!("{}", e); for error in errors {
return; error
.to_report(module)
.eprint(&mut *source_cache)
.expect("cannot write error to stderr");
}
} }
} }
} }
@ -83,20 +95,32 @@ fn main() {
} }
println!("Parsing OK"); println!("Parsing OK");
} }
Commands::TypeCheck { files, dump_ast } => { Commands::TypeCheck { files, dump_ast } => {
let mut source_cache = SourceCache {
paths: files.iter().map(PathBuf::from).collect(),
file_cache: ariadne::FileCache::default(),
};
let mut modules = parse(files); let mut modules = parse(files);
check(&mut modules); check(&mut modules, &mut source_cache);
if *dump_ast { if *dump_ast {
for module in &modules { for module in &modules {
println!("{:#?}", &module); println!("{:#?}", &module);
} }
} }
} }
Commands::Compile { files, dump_clir } | Commands::Run { files, dump_clir } => { Commands::Compile { files, dump_clir } | Commands::Run { files, dump_clir } => {
let mut source_cache = SourceCache {
paths: files.iter().map(PathBuf::from).collect(),
file_cache: ariadne::FileCache::default(),
};
let mut jit = jit::JIT::default(); let mut jit = jit::JIT::default();
jit.dump_clir = *dump_clir; jit.dump_clir = *dump_clir;
for file in files {
match jit.compile_file(file) { for (id, file) in files.iter().enumerate() {
match jit.compile_file(file, id as u32, &mut source_cache) {
Err(e) => eprintln!("{}", e), Err(e) => eprintln!("{}", e),
Ok(code) => { Ok(code) => {
println!("Compiled {}", file); println!("Compiled {}", file);

View file

@ -1,14 +1,13 @@
use std::fs; use std::fs;
use std::path::Path; use std::path::Path;
use pest::error::Error; use expr::BinaryExpression;
use pest::iterators::Pair; use pest::iterators::Pair;
use pest::pratt_parser::PrattParser; use pest::pratt_parser::PrattParser;
use pest::Parser; use pest::Parser as PestParser;
use ReturnStatement;
use crate::ast::Module;
use crate::ast::*; use crate::ast::*;
use crate::ast::{Import, ModulePath};
use crate::typing::Type; use crate::typing::Type;
#[derive(pest_derive::Parser)] #[derive(pest_derive::Parser)]
@ -33,24 +32,38 @@ lazy_static::lazy_static! {
}; };
} }
pub fn parse_file(path: &Path) -> Result<Module, Error<Rule>> { #[derive(Default)]
let source = fs::read_to_string(&path).expect("could not read source file"); pub struct Parser {
source: SourceId,
}
impl crate::parsing::Parser for Parser {
fn parse_file(&mut self, path: &Path, id: SourceId) -> anyhow::Result<Module> {
let source = fs::read_to_string(path)?;
let module_path = ModulePath::from(path); let module_path = ModulePath::from(path);
let mut module = parse_as_module(&source, module_path)?; let mut module = self.parse_as_module(&source, module_path, id)?;
module.file = Some(path.to_owned()); module.file = Some(path.to_owned());
Ok(module) Ok(module)
} }
pub fn parse_as_module(source: &str, path: ModulePath) -> Result<Module, Error<Rule>> { fn parse_as_module(
let mut pairs = LilaParser::parse(Rule::source_file, &source)?; &mut self,
source: &str,
path: ModulePath,
id: SourceId,
) -> anyhow::Result<Module> {
self.source = id;
let mut pairs = LilaParser::parse(Rule::source_file, source)?;
assert!(pairs.len() == 1); assert!(pairs.len() == 1);
let module = parse_module(pairs.next().unwrap().into_inner().next().unwrap(), path); let module = self.parse_module(pairs.next().unwrap().into_inner().next().unwrap(), path);
Ok(module) Ok(module)
} }
}
pub fn parse_module(pair: Pair<Rule>, path: ModulePath) -> Module { impl Parser {
fn parse_module(&self, pair: Pair<Rule>, path: ModulePath) -> Module {
assert!(pair.as_rule() == Rule::module_items); assert!(pair.as_rule() == Rule::module_items);
let mut module = Module::new(path); let mut module = Module::new(path);
@ -59,13 +72,13 @@ pub fn parse_module(pair: Pair<Rule>, path: ModulePath) -> Module {
for pair in pairs { for pair in pairs {
match pair.as_rule() { match pair.as_rule() {
Rule::definition => { Rule::definition => {
let def = parse_definition(pair.into_inner().next().unwrap()); let def = self.parse_definition(pair.into_inner().next().unwrap());
match def { match def {
Definition::FunctionDefinition(func) => module.functions.push(func), Definition::FunctionDefinition(func) => module.functions.push(func),
} }
} }
Rule::use_statement => { Rule::use_statement => {
let path = parse_import(pair.into_inner().next().unwrap()); let path = self.parse_import(pair.into_inner().next().unwrap());
module.imports.push(path); module.imports.push(path);
} }
_ => panic!("unexpected rule in source_file: {:?}", pair.as_rule()), _ => panic!("unexpected rule in source_file: {:?}", pair.as_rule()),
@ -75,14 +88,15 @@ pub fn parse_module(pair: Pair<Rule>, path: ModulePath) -> Module {
module module
} }
fn parse_block(pair: Pair<Rule>) -> Block { fn parse_block(&self, pair: Pair<Rule>) -> Block {
let mut statements = vec![]; let mut statements = vec![];
let mut value = None; let mut value = None;
let span = self.make_span(&pair);
for pair in pair.into_inner() { for pair in pair.into_inner() {
match pair.as_rule() { match pair.as_rule() {
Rule::statement => statements.push(parse_statement(pair)), Rule::statement => statements.push(self.parse_statement(pair)),
Rule::expr => value = Some(parse_expression(pair)), Rule::expr => value = Some(self.parse_expression(pair)),
_ => panic!("unexpected rule {:?} in block", pair.as_rule()), _ => panic!("unexpected rule {:?} in block", pair.as_rule()),
} }
} }
@ -91,75 +105,104 @@ fn parse_block(pair: Pair<Rule>) -> Block {
statements, statements,
value, value,
typ: Type::Undefined, typ: Type::Undefined,
span: Some(span),
} }
} }
fn parse_statement(pair: Pair<Rule>) -> Statement { fn parse_statement(&self, pair: Pair<Rule>) -> Statement {
let pair = pair.into_inner().next().unwrap(); let pair = pair.into_inner().next().unwrap();
let span = self.make_span(&pair);
match pair.as_rule() { match pair.as_rule() {
Rule::assign_statement => { Rule::assign_statement => {
let mut pairs = pair.into_inner(); let mut pairs = pair.into_inner();
let identifier = pairs.next().unwrap().as_str().to_string(); let identifier = pairs.next().unwrap().as_str().to_string();
let expr = parse_expression(pairs.next().unwrap()); let expr = self.parse_expression(pairs.next().unwrap());
Statement::AssignStatement(identifier, Box::new(expr)) Statement::AssignStatement {
lhs: identifier,
rhs: Box::new(expr),
span,
}
} }
Rule::declare_statement => { Rule::declare_statement => {
let mut pairs = pair.into_inner(); let mut pairs = pair.into_inner();
let identifier = pairs.next().unwrap().as_str().to_string(); let identifier = pairs.next().unwrap().as_str().to_string();
let expr = parse_expression(pairs.next().unwrap()); let expr = self.parse_expression(pairs.next().unwrap());
Statement::DeclareStatement(identifier, Box::new(expr)) Statement::DeclareStatement {
lhs: identifier,
rhs: Box::new(expr),
span,
}
} }
Rule::return_statement => { Rule::return_statement => {
let expr = if let Some(pair) = pair.into_inner().next() { let expr = pair
Some(parse_expression(pair)) .into_inner()
} else { .next()
None .map(|expr| self.parse_expression(expr));
}; Statement::ReturnStatement(ReturnStatement { expr, span })
Statement::ReturnStatement(expr)
} }
Rule::call_statement => { Rule::call_statement => {
let call = parse_call(pair.into_inner().next().unwrap()); let call = self.parse_call(pair.into_inner().next().unwrap());
Statement::CallStatement(Box::new(call)) Statement::CallStatement {
call: Box::new(call),
span,
}
} }
Rule::use_statement => { Rule::use_statement => {
let import = parse_import(pair.into_inner().next().unwrap()); let import = self.parse_import(pair.into_inner().next().unwrap());
Statement::UseStatement(Box::new(import)) Statement::UseStatement {
import: Box::new(import),
span,
}
} }
Rule::if_statement => { Rule::if_statement => {
let mut pairs = pair.into_inner(); let mut pairs = pair.into_inner();
let condition = parse_expression(pairs.next().unwrap()); let condition = self.parse_expression(pairs.next().unwrap());
let block = parse_block(pairs.next().unwrap()); let block = self.parse_block(pairs.next().unwrap());
if pairs.next().is_some() { if pairs.next().is_some() {
todo!("implement if-statements with else branch (and else if)") todo!("implement if-statements with else branch (and else if)")
} }
Statement::IfStatement(Box::new(condition), Box::new(block)) Statement::IfStatement {
condition: Box::new(condition),
then_block: Box::new(block),
span,
}
} }
Rule::while_statement => { Rule::while_statement => {
let mut pairs = pair.into_inner(); let mut pairs = pair.into_inner();
let condition = parse_expression(pairs.next().unwrap()); let condition = self.parse_expression(pairs.next().unwrap());
let block = parse_block(pairs.next().unwrap()); let block = self.parse_block(pairs.next().unwrap());
Statement::WhileStatement(Box::new(condition), Box::new(block)) Statement::WhileStatement {
condition: Box::new(condition),
loop_block: Box::new(block),
span,
}
} }
_ => unreachable!("unexpected rule '{:?}' in parse_statement", pair.as_rule()), _ => unreachable!("unexpected rule '{:?}' in parse_statement", pair.as_rule()),
} }
} }
fn parse_import(pair: Pair<Rule>) -> Import { fn parse_import(&self, pair: Pair<Rule>) -> Import {
Import(pair.as_str().to_string()) Import(pair.as_str().to_string())
} }
fn parse_call(pair: Pair<Rule>) -> Call { fn parse_call(&self, pair: Pair<Rule>) -> Call {
let mut pairs = pair.into_inner(); let mut pairs = pair.into_inner();
// TODO: support calls on more than identifiers (needs grammar change) // TODO: support calls on more than identifiers (needs grammar change)
let callee = Expr::Identifier {
name: pairs.next().unwrap().as_str().to_string(), let pair = pairs.next().unwrap();
let callee = SExpr {
expr: Expr::Identifier {
name: pair.as_str().to_string(),
typ: Type::Undefined, typ: Type::Undefined,
},
span: self.make_span(&pair),
}; };
let args: Vec<Expr> = pairs let args: Vec<SExpr> = pairs
.next() .next()
.unwrap() .unwrap()
.into_inner() .into_inner()
.map(parse_expression) .map(|arg| self.parse_expression(arg))
.collect(); .collect();
Call { Call {
@ -169,13 +212,25 @@ fn parse_call(pair: Pair<Rule>) -> Call {
} }
} }
fn parse_expression(pair: Pair<Rule>) -> Expr { fn parse_expression(&self, pair: Pair<Rule>) -> SExpr {
let span = self.make_span(&pair);
let pairs = pair.into_inner(); let pairs = pair.into_inner();
PRATT_PARSER let mut map = PRATT_PARSER
.map_primary(|primary| match primary.as_rule() { .map_primary(|primary| {
Rule::integer_literal => Expr::IntegerLiteral(primary.as_str().parse().unwrap()), let span = self.make_span(&primary);
Rule::float_literal => Expr::FloatLiteral(primary.as_str().parse().unwrap()), match primary.as_rule() {
Rule::string_literal => Expr::StringLiteral( Rule::integer_literal => SExpr {
expr: Expr::IntegerLiteral(primary.as_str().parse().unwrap()),
span,
},
Rule::float_literal => SExpr {
expr: Expr::FloatLiteral(primary.as_str().parse().unwrap()),
span,
},
Rule::string_literal => SExpr {
expr: Expr::StringLiteral(
primary primary
.into_inner() .into_inner()
.next() .next()
@ -184,34 +239,59 @@ fn parse_expression(pair: Pair<Rule>) -> Expr {
.parse() .parse()
.unwrap(), .unwrap(),
), ),
Rule::expr => parse_expression(primary), span,
Rule::ident => Expr::Identifier { },
Rule::expr => self.parse_expression(primary),
Rule::ident => SExpr {
expr: Expr::Identifier {
name: primary.as_str().to_string(), name: primary.as_str().to_string(),
typ: Type::Undefined, typ: Type::Undefined,
}, },
Rule::call => Expr::Call(Box::new(parse_call(primary))), span,
Rule::block => Expr::Block(Box::new(parse_block(primary))), },
Rule::call => SExpr {
expr: Expr::Call(Box::new(self.parse_call(primary))),
span,
},
Rule::block => SExpr {
expr: Expr::Block(Box::new(self.parse_block(primary))),
span,
},
Rule::if_expr => { Rule::if_expr => {
let mut pairs = primary.into_inner(); let mut pairs = primary.into_inner();
let condition = parse_expression(pairs.next().unwrap()); let condition = self.parse_expression(pairs.next().unwrap());
let true_block = parse_block(pairs.next().unwrap()); let true_block = self.parse_block(pairs.next().unwrap());
let else_value = parse_expression(pairs.next().unwrap()); let else_value = self.parse_expression(pairs.next().unwrap());
Expr::IfExpr { SExpr {
expr: Expr::IfExpr {
cond: Box::new(condition), cond: Box::new(condition),
then_body: Box::new(true_block), then_body: Box::new(true_block),
else_body: Box::new(else_value), else_body: Box::new(else_value),
typ: Type::Undefined, typ: Type::Undefined,
},
span,
} }
} }
Rule::boolean_literal => Expr::BooleanLiteral(match primary.as_str() {
Rule::boolean_literal => SExpr {
expr: Expr::BooleanLiteral(match primary.as_str() {
"true" => true, "true" => true,
"false" => false, "false" => false,
_ => unreachable!(), _ => unreachable!(),
}), }),
span,
},
_ => unreachable!( _ => unreachable!(
"Unexpected rule '{:?}' in primary expression", "Unexpected rule '{:?}' in primary expression",
primary.as_rule() primary.as_rule()
), ),
}
}) })
.map_infix(|lhs, op, rhs| { .map_infix(|lhs, op, rhs| {
let operator = match op.as_rule() { let operator = match op.as_rule() {
@ -226,27 +306,30 @@ fn parse_expression(pair: Pair<Rule>) -> Expr {
Rule::or => BinaryOperator::Or, Rule::or => BinaryOperator::Or,
_ => unreachable!(), _ => unreachable!(),
}; };
Expr::BinaryExpression { let expr = Expr::BinaryExpression(BinaryExpression {
lhs: Box::new(lhs), lhs: Box::new(lhs),
op: operator, op: operator,
op_span: self.make_span(&op),
rhs: Box::new(rhs), rhs: Box::new(rhs),
typ: Type::Undefined, typ: Type::Undefined,
} });
SExpr { expr, span }
}) })
.map_prefix(|op, inner| { .map_prefix(|op, inner| {
let operator = match op.as_rule() { let operator = match op.as_rule() {
Rule::not => UnaryOperator::Not, Rule::not => UnaryOperator::Not,
_ => unreachable!(), _ => unreachable!(),
}; };
Expr::UnaryExpression { let expr = Expr::UnaryExpression {
op: operator, op: operator,
inner: Box::new(inner), inner: Box::new(inner),
} };
}) SExpr { expr, span }
.parse(pairs) });
map.parse(pairs)
} }
fn parse_parameter(pair: Pair<Rule>) -> Parameter { fn parse_parameter(&self, pair: Pair<Rule>) -> Parameter {
assert!(pair.as_rule() == Rule::parameter); assert!(pair.as_rule() == Rule::parameter);
let mut pair = pair.into_inner(); let mut pair = pair.into_inner();
let name = pair.next().unwrap().as_str().to_string(); let name = pair.next().unwrap().as_str().to_string();
@ -254,38 +337,53 @@ fn parse_parameter(pair: Pair<Rule>) -> Parameter {
Parameter { name, typ } Parameter { name, typ }
} }
fn parse_definition(pair: Pair<Rule>) -> Definition { fn parse_definition(&self, pair: Pair<Rule>) -> Definition {
match pair.as_rule() { match pair.as_rule() {
Rule::func_def => { Rule::func_def => {
let line_col = pair.line_col(); let span = self.make_span(&pair);
let mut pairs = pair.into_inner(); let mut pairs = pair.into_inner();
let name = pairs.next().unwrap().as_str().to_string(); let name = pairs.next().unwrap().as_str().to_string();
let parameters: Vec<Parameter> = pairs let parameters: Vec<Parameter> = pairs
.next() .next()
.unwrap() .unwrap()
.into_inner() .into_inner()
.map(parse_parameter) .map(|param| self.parse_parameter(param))
.collect(); .collect();
let pair = pairs.next().unwrap(); let pair = pairs.next().unwrap();
// Before the block there is an optional return type // Before the block there is an optional return type
let (return_type, pair) = match pair.as_rule() { let (return_type, return_type_span, pair) = match pair.as_rule() {
Rule::ident => (Some(Type::from(pair.as_str())), pairs.next().unwrap()), Rule::ident => (
Rule::block => (None, pair), Some(Type::from(pair.as_str())),
Some(self.make_span(&pair)),
pairs.next().unwrap(),
),
Rule::block => (None, None, pair),
_ => unreachable!( _ => unreachable!(
"Unexpected rule '{:?}' in function definition, expected return type or block", "Unexpected rule '{:?}' in function definition, expected return type or block",
pair.as_rule() pair.as_rule()
), ),
}; };
let body = parse_block(pair); let body = self.parse_block(pair);
let body = Box::new(body); let body = Box::new(body);
Definition::FunctionDefinition(FunctionDefinition { Definition::FunctionDefinition(FunctionDefinition {
name, name,
parameters, parameters,
return_type, return_type,
return_type_span,
span,
body, body,
location: Location { line_col },
}) })
} }
_ => panic!("unexpected node for definition: {:?}", pair.as_rule()), _ => panic!("unexpected node for definition: {:?}", pair.as_rule()),
} }
} }
/// Converts the pest span of `pair` into this crate's [`Span`].
///
/// The resulting span carries the parser's current source id together with
/// the start and end positions reported by pest for `pair`.
fn make_span(&self, pair: &Pair<Rule>) -> Span {
    let pest_span = pair.as_span();
    let (start, end) = (pest_span.start(), pest_span.end());
    Span {
        source: self.source,
        start,
        end,
    }
}
}

View file

@ -1,4 +1,18 @@
mod backend; mod backend;
mod tests; mod tests;
pub use self::backend::pest::{parse_file, parse_module, parse_as_module}; use crate::ast::{Module, ModulePath, SourceId};
/// Interface shared by the parser backends: each entry point produces a
/// [`Module`] or an [`anyhow`] error.
///
/// Implementors must also be [`Default`]-constructible (supertrait bound).
pub trait Parser: Default {
/// Parses the file at `path` into a [`Module`], associating it with the source `id`.
// NOTE(review): presumably `id` is the value recorded in every `Span` of the
// resulting module — confirm against the backend implementation.
fn parse_file(&mut self, path: &std::path::Path, id: SourceId) -> anyhow::Result<Module>;
/// Parses the in-memory `source` text into a [`Module`].
///
/// `path` is the module path given to the result and `id` is the source id
/// associated with `source`.
fn parse_as_module(
&mut self,
source: &str,
path: ModulePath,
id: SourceId,
) -> anyhow::Result<Module>;
}
pub use self::backend::pest::Parser as PestParser;
pub use PestParser as DefaultParser;

View file

@ -1,12 +1,17 @@
#[cfg(test)]
use pretty_assertions::assert_eq;
#[test] #[test]
fn test_addition_function() { fn test_addition_function() {
use crate::ast::{expr::Expr, *}; use crate::ast::*;
use crate::parsing::backend::pest::parse_as_module; use crate::parsing::*;
use crate::typing::Type; use crate::typing::*;
let source = "fn add(a: int, b: int) int { a + b }"; let source = "fn add(a: int, b: int) int { a + b }";
let path = ModulePath::from("test"); let path = ModulePath::from("test");
let module = parse_as_module(&source, path.clone()).expect("parsing error"); let module = DefaultParser::default()
.parse_as_module(source, path.clone(), 0)
.expect("parsing error");
let expected_module = Module { let expected_module = Module {
file: None, file: None,
@ -26,21 +31,61 @@ fn test_addition_function() {
return_type: Some(Type::Int), return_type: Some(Type::Int),
body: Box::new(Block { body: Box::new(Block {
statements: vec![], statements: vec![],
value: Some(Expr::BinaryExpression { value: Some(SExpr {
lhs: Box::new(Expr::Identifier { expr: Expr::BinaryExpression(BinaryExpression {
lhs: Box::new(SExpr {
expr: Expr::Identifier {
name: Identifier::from("a"), name: Identifier::from("a"),
typ: Type::Undefined, typ: Type::Undefined,
},
span: Span {
source: 0,
start: 29,
end: 30,
},
}), }),
op: BinaryOperator::Add, op: BinaryOperator::Add,
rhs: Box::new(Expr::Identifier { op_span: Span {
source: 0,
start: 31,
end: 32,
},
rhs: Box::new(SExpr {
expr: Expr::Identifier {
name: Identifier::from("b"), name: Identifier::from("b"),
typ: Type::Undefined, typ: Type::Undefined,
},
span: Span {
source: 0,
start: 33,
end: 34,
},
}), }),
typ: Type::Undefined, typ: Type::Undefined,
}), }),
typ: Type::Undefined, span: Span {
source: 0,
start: 29,
end: 34,
},
}),
typ: Type::Undefined,
span: Some(Span {
source: 0,
start: 27,
end: source.len(),
}),
}),
span: Span {
source: 0,
start: 0,
end: source.len(),
},
return_type_span: Some(Span {
source: 0,
start: 23,
end: 26,
}), }),
location: Location { line_col: (1, 1) },
}], }],
path, path,
}; };

22
src/source.rs Normal file
View file

@ -0,0 +1,22 @@
use crate::ast::SourceId;
use ariadne::FileCache;
/// Source lookup keyed by [`SourceId`]: the id selects a path in `paths`, and
/// that path is then fetched through ariadne's [`FileCache`].
pub struct SourceCache {
/// File paths indexed by source id (looked up as `paths[id as usize]`).
pub paths: Vec<std::path::PathBuf>,
/// ariadne file cache used to fetch the source stored at a path.
pub file_cache: FileCache,
}
impl ariadne::Cache<SourceId> for SourceCache {
    type Storage = String;

    /// Fetches the source registered under `id`.
    ///
    /// The id is used as an index into `paths`; the selected path is then
    /// loaded through `file_cache`.
    ///
    /// # Errors
    ///
    /// Returns an error when no path is registered for `id` (instead of
    /// panicking on an out-of-bounds index), and forwards any error reported
    /// by the underlying file cache.
    fn fetch(
        &mut self,
        id: &SourceId,
    ) -> Result<&ariadne::Source<Self::Storage>, Box<dyn std::fmt::Debug + '_>> {
        let index = *id as usize;
        let Some(path) = self.paths.get(index) else {
            return Err(Box::new(format!("no source registered for id {index}")));
        };
        self.file_cache.fetch(path)
    }

    /// Returns the displayable name of the source `id` (its file path), or
    /// `None` when no path is registered for that id.
    fn display<'a>(&self, id: &'a SourceId) -> Option<Box<dyn std::fmt::Display + 'a>> {
        let path = self.paths.get(*id as usize)?;
        Some(Box::new(path.display().to_string()))
    }
}

View file

@ -1,8 +1,10 @@
use crate::typing::{BinaryOperator, Identifier, ModulePath, Type, TypingContext}; use ariadne::{ColorGenerator, Fmt, Label, Report, ReportKind, Span as _};
use std::fmt::Debug;
use super::UnaryOperator; use super::{Span, UnaryOperator};
use crate::typing::{BinaryOperator, Identifier, ModulePath, Type};
#[derive(Debug)] #[derive(Debug, PartialEq)]
pub struct TypeError { pub struct TypeError {
pub file: Option<std::path::PathBuf>, pub file: Option<std::path::PathBuf>,
pub module: ModulePath, pub module: ModulePath,
@ -10,72 +12,31 @@ pub struct TypeError {
pub kind: TypeErrorKind, pub kind: TypeErrorKind,
} }
impl std::fmt::Display for TypeError { #[derive(PartialEq, Debug)]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { pub struct TypeAndSpan {
f.write_str("Error\n")?; pub ty: Type,
if let Some(path) = &self.file { pub span: Span,
f.write_fmt(format_args!(" in file {}\n", path.display()))?;
}
f.write_fmt(format_args!(" in module {}\n", self.module))?;
if let Some(name) = &self.function {
f.write_fmt(format_args!(" in function {}\n", name))?;
}
f.write_fmt(format_args!("{:#?}", self.kind))?;
Ok(())
}
} }
#[derive(Default)] #[derive(PartialEq, Debug)]
pub struct TypeErrorBuilder { pub struct BinOpAndSpan {
file: Option<std::path::PathBuf>, pub op: BinaryOperator,
module: Option<ModulePath>, pub span: Span,
function: Option<String>,
kind: Option<TypeErrorKind>,
}
impl TypeError {
pub fn builder() -> TypeErrorBuilder {
TypeErrorBuilder::default()
}
}
impl TypeErrorBuilder {
pub fn context(mut self, ctx: &TypingContext) -> Self {
self.file = ctx.file.clone();
self.module = Some(ctx.module.clone());
self.function = ctx.function.clone();
self
}
pub fn kind(mut self, kind: TypeErrorKind) -> Self {
self.kind = Some(kind);
self
}
pub fn build(self) -> TypeError {
TypeError {
file: self.file,
module: self.module.expect("TypeError builder is missing module"),
function: self.function,
kind: self.kind.expect("TypeError builder is missing kind"),
}
}
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum TypeErrorKind { pub enum TypeErrorKind {
InvalidBinaryOperator { InvalidBinaryOperator {
operator: BinaryOperator, operator: BinOpAndSpan,
lht: Type, lhs: TypeAndSpan,
rht: Type, rhs: TypeAndSpan,
}, },
BlockTypeDoesNotMatchFunctionType { BlockTypeDoesNotMatchFunctionType {
block_type: Type, block_type: Type,
function_type: Type,
}, },
ReturnTypeDoesNotMatchFunctionType { ReturnTypeDoesNotMatchFunctionType {
function_type: Type, return_expr: Option<TypeAndSpan>,
return_type: Type, return_stmt: TypeAndSpan,
}, },
UnknownIdentifier { UnknownIdentifier {
identifier: String, identifier: String,
@ -86,7 +47,6 @@ pub enum TypeErrorKind {
}, },
AssignUndeclared, AssignUndeclared,
VariableRedeclaration, VariableRedeclaration,
ReturnStatementsMismatch,
UnknownFunctionCalled(Identifier), UnknownFunctionCalled(Identifier),
WrongFunctionArguments, WrongFunctionArguments,
ConditionIsNotBool, ConditionIsNotBool,
@ -96,3 +56,132 @@ pub enum TypeErrorKind {
inner: Type, inner: Type,
}, },
} }
impl TypeError {
pub fn to_report(&self, ast: &crate::ast::Module) -> Report<Span> {
let mut colors = ColorGenerator::new();
let c0 = colors.next();
let c1 = colors.next();
colors.next();
let c2 = colors.next();
match &self.kind {
TypeErrorKind::InvalidBinaryOperator { operator, lhs, rhs } => {
Report::build(ReportKind::Error, 0u32, 0)
.with_message(format!(
"Invalid binary operation {} between {} and {}",
operator.op.to_string().fg(c0),
lhs.ty.to_string().fg(c1),
rhs.ty.to_string().fg(c2),
))
.with_labels([
Label::new(operator.span).with_color(c0),
Label::new(lhs.span)
.with_message(format!("This has type {}", lhs.ty.to_string().fg(c1)))
.with_color(c1)
.with_order(2),
Label::new(rhs.span)
.with_message(format!("This has type {}", rhs.ty.to_string().fg(c2)))
.with_color(c2)
.with_order(1),
])
.finish()
}
TypeErrorKind::BlockTypeDoesNotMatchFunctionType { block_type } => {
let function = ast
.functions
.iter()
.find(|f| f.name == *self.function.as_ref().unwrap())
.unwrap();
let block_color = c0;
let signature_color = c1;
let span = function.body.value.as_ref().unwrap().span;
let report = Report::build(ReportKind::Error, 0u32, 0)
.with_message("Function body does not match the signature")
.with_labels([
Label::new(function.body.span.unwrap())
.with_message("In this function's body")
.with_color(c2),
Label::new(span)
.with_message(format!(
"Returned expression has type {} but the function should return {}",
block_type.to_string().fg(block_color),
function
.return_type
.as_ref()
.unwrap_or(&Type::Unit)
.to_string()
.fg(signature_color)
))
.with_color(block_color),
]);
let report =
report.with_note("The last expression of a function's body is returned");
let report = if function.return_type.is_none() {
report.with_help(
"You may need to add the return type to the function's signature",
)
} else {
report
};
report.finish()
}
TypeErrorKind::ReturnTypeDoesNotMatchFunctionType {
return_expr,
return_stmt,
} => {
let function = ast
.functions
.iter()
.find(|f| f.name == *self.function.as_ref().unwrap())
.unwrap();
let is_bare_return = return_expr.is_none();
let report = Report::build(ReportKind::Error, *return_stmt.span.source(), 0)
.with_message("Return type does not match the function's signature")
.with_label(
Label::new(return_expr.as_ref().unwrap_or(return_stmt).span)
.with_color(c1)
.with_message(if is_bare_return {
format!("Bare return has type {}", Type::Unit.to_string().fg(c1))
} else {
format!(
"This expression has type {}",
return_stmt.ty.to_string().fg(c1)
)
}),
);
let report = if let Some(ret_ty_span) = function.return_type_span {
report.with_label(Label::new(ret_ty_span).with_color(c0).with_message(format!(
"The signature shows {}",
function.return_type.as_ref().unwrap().to_string().fg(c0)
)))
} else {
report
};
let report = if function.return_type.is_none() {
report.with_help(
"You may need to add the return type to the function's signature",
)
} else {
report
};
report.finish()
}
_ => todo!(),
}
}
}

View file

@ -1,11 +1,14 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Display; use std::fmt::Display;
use BinaryExpression;
use ReturnStatement;
use crate::ast::ModulePath; use crate::ast::ModulePath;
use crate::ast::*; use crate::ast::*;
mod error; mod error;
use crate::typing::error::{TypeError, TypeErrorKind}; use crate::typing::error::{TypeAndSpan, TypeError, TypeErrorKind};
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
@ -62,11 +65,11 @@ impl From<&str> for Type {
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub struct Signature(Vec<Type>, Type); pub struct Signature(Vec<Type>, Type);
impl Into<Type> for Signature { impl From<Signature> for Type {
fn into(self) -> Type { fn from(val: Signature) -> Self {
Type::Function { Type::Function {
params: self.0, params: val.0,
returns: Box::new(self.1), returns: Box::new(val.1),
} }
} }
} }
@ -79,12 +82,13 @@ impl FunctionDefinition {
} }
} }
#[derive(Debug, PartialEq)]
pub struct CheckedModule(pub Module); pub struct CheckedModule(pub Module);
impl Module { impl Module {
pub fn type_check(mut self) -> Result<CheckedModule, TypeError> { pub fn type_check(&mut self) -> Result<(), Vec<TypeError>> {
let mut ctx = TypingContext::new(self.path.clone()); let mut ctx = TypingContext::new(self.path.clone());
ctx.file = self.file.clone(); ctx.file.clone_from(&self.file);
// Register all function signatures // Register all function signatures
for func in &self.functions { for func in &self.functions {
@ -95,13 +99,21 @@ impl Module {
// TODO: add signatures of imported functions (even if they have not been checked) // TODO: add signatures of imported functions (even if they have not been checked)
let mut errors = Vec::new();
// Type-check the function bodies and complete all type placeholders // Type-check the function bodies and complete all type placeholders
for func in &mut self.functions { for func in &mut self.functions {
func.typ(&mut ctx)?; if let Err(e) = func.typ(&mut ctx) {
errors.push(e);
};
ctx.variables.clear(); ctx.variables.clear();
} }
Ok(CheckedModule(self)) if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
} }
} }
@ -128,6 +140,15 @@ impl TypingContext {
variables: Default::default(), variables: Default::default(),
} }
} }
pub fn make_error(&self, kind: TypeErrorKind) -> TypeError {
TypeError {
kind,
file: self.file.clone(),
module: self.module.clone(),
function: self.function.clone(),
}
}
} }
/// Trait for nodes which have a deducible type. /// Trait for nodes which have a deducible type.
@ -148,40 +169,13 @@ impl TypeCheck for FunctionDefinition {
let body_type = self.body.typ(ctx)?; let body_type = self.body.typ(ctx)?;
// If the return type is not specified, it is unit.
if self.return_type.is_none() {
self.return_type = Some(Type::Unit)
}
// Check coherence with the body's type. // Check coherence with the body's type.
if *self.return_type.as_ref().unwrap() != body_type { if *self.return_type.as_ref().unwrap_or(&Type::Unit) != body_type {
return Err(TypeError::builder() return Err(
.context(ctx) ctx.make_error(TypeErrorKind::BlockTypeDoesNotMatchFunctionType {
.kind(TypeErrorKind::BlockTypeDoesNotMatchFunctionType {
block_type: body_type.clone(), block_type: body_type.clone(),
function_type: self.return_type.as_ref().unwrap().clone(), }),
}) );
.build());
}
// Check coherence with return statements.
for statement in &mut self.body.statements {
if let Statement::ReturnStatement(value) = statement {
let ret_type = match value {
Some(expr) => expr.typ(ctx)?,
None => Type::Unit,
};
if ret_type != *self.return_type.as_ref().unwrap() {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::ReturnTypeDoesNotMatchFunctionType {
function_type: self.return_type.as_ref().unwrap().clone(),
return_type: ret_type.clone(),
})
.build());
}
}
} }
Ok(self.return_type.clone().unwrap()) Ok(self.return_type.clone().unwrap())
@ -190,79 +184,65 @@ impl TypeCheck for FunctionDefinition {
impl TypeCheck for Block { impl TypeCheck for Block {
fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> { fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
let mut return_typ: Option<Type> = None;
// Check declarations and assignments. // Check declarations and assignments.
for statement in &mut self.statements { for statement in &mut self.statements {
match statement { match statement {
Statement::DeclareStatement(ident, expr) => { Statement::DeclareStatement {
lhs: ident,
rhs: expr,
..
} => {
let typ = expr.typ(ctx)?; let typ = expr.typ(ctx)?;
if let Some(_typ) = ctx.variables.insert(ident.clone(), typ.clone()) { if let Some(_typ) = ctx.variables.insert(ident.clone(), typ.clone()) {
// TODO: Shadowing? (illegal for now) // TODO: Shadowing? (illegal for now)
return Err(TypeError::builder() return Err(ctx.make_error(TypeErrorKind::VariableRedeclaration));
.context(ctx)
.kind(TypeErrorKind::VariableRedeclaration)
.build());
} }
} }
Statement::AssignStatement(ident, expr) => { Statement::AssignStatement {
lhs: ident,
rhs: expr,
..
} => {
let rhs_typ = expr.typ(ctx)?; let rhs_typ = expr.typ(ctx)?;
let Some(lhs_typ) = ctx.variables.get(ident) else { let Some(lhs_typ) = ctx.variables.get(ident) else {
return Err(TypeError::builder() return Err(ctx.make_error(TypeErrorKind::AssignUndeclared));
.context(ctx)
.kind(TypeErrorKind::AssignUndeclared)
.build());
}; };
// Ensure same type on both sides. // Ensure same type on both sides.
if rhs_typ != *lhs_typ { if rhs_typ != *lhs_typ {
return Err(TypeError::builder() return Err(ctx.make_error(TypeErrorKind::AssignmentMismatch {
.context(ctx)
.kind(TypeErrorKind::AssignmentMismatch {
lht: lhs_typ.clone(), lht: lhs_typ.clone(),
rht: rhs_typ.clone(), rht: rhs_typ.clone(),
}) }));
.build());
} }
} }
Statement::ReturnStatement(maybe_expr) => { Statement::ReturnStatement(return_stmt) => {
let expr_typ = if let Some(expr) = maybe_expr { return_stmt.typ(ctx)?;
expr.typ(ctx)?
} else {
Type::Unit
};
if let Some(typ) = &return_typ {
if expr_typ != *typ {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::ReturnStatementsMismatch)
.build());
} }
} else { Statement::CallStatement { call, span: _ } => {
return_typ = Some(expr_typ.clone());
}
}
Statement::CallStatement(call) => {
call.typ(ctx)?; call.typ(ctx)?;
} }
Statement::UseStatement(_path) => { Statement::UseStatement { .. } => {
// TODO: import the signatures (and types) // TODO: import the signatures (and types)
todo!()
} }
Statement::IfStatement(cond, block) => { Statement::IfStatement {
condition: cond,
then_block: block,
..
} => {
if cond.typ(ctx)? != Type::Bool { if cond.typ(ctx)? != Type::Bool {
return Err(TypeError::builder() return Err(ctx.make_error(TypeErrorKind::ConditionIsNotBool));
.context(ctx)
.kind(TypeErrorKind::ConditionIsNotBool)
.build());
} }
block.typ(ctx)?; block.typ(ctx)?;
} }
Statement::WhileStatement(cond, block) => { Statement::WhileStatement {
condition: cond,
loop_block: block,
span: _,
} => {
if cond.typ(ctx)? != Type::Bool { if cond.typ(ctx)? != Type::Bool {
return Err(TypeError::builder() return Err(ctx.make_error(TypeErrorKind::ConditionIsNotBool));
.context(ctx)
.kind(TypeErrorKind::ConditionIsNotBool)
.build());
} }
block.typ(ctx)?; block.typ(ctx)?;
} }
@ -277,25 +257,19 @@ impl TypeCheck for Block {
self.typ = Type::Unit; self.typ = Type::Unit;
Ok(Type::Unit) Ok(Type::Unit)
} }
// TODO/FIXME: find a way to return `return_typ` so that the
// top-level block (the function) can check if this return type
// (and eventually those from other block) matches the type of
// the function.
} }
} }
impl TypeCheck for Call { impl TypeCheck for Call {
fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> { fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
match &mut *self.callee { match &mut self.callee.expr {
Expr::Identifier { name, typ } => { Expr::Identifier { name, typ, .. } => {
let signature = match ctx.functions.get(name) { let signature = match ctx.functions.get(name) {
Some(sgn) => sgn.clone(), Some(sgn) => sgn.clone(),
None => { None => {
return Err(TypeError::builder() return Err(
.context(ctx) ctx.make_error(TypeErrorKind::UnknownFunctionCalled(name.clone()))
.kind(TypeErrorKind::UnknownFunctionCalled(name.clone())) )
.build())
} }
}; };
@ -315,10 +289,7 @@ impl TypeCheck for Call {
if args_types == *params_types { if args_types == *params_types {
Ok(self.typ.clone()) Ok(self.typ.clone())
} else { } else {
Err(TypeError::builder() Err(ctx.make_error(TypeErrorKind::WrongFunctionArguments))
.context(ctx)
.kind(TypeErrorKind::WrongFunctionArguments)
.build())
} }
} }
_ => unimplemented!("cannot call on expression other than identifier"), _ => unimplemented!("cannot call on expression other than identifier"),
@ -329,36 +300,40 @@ impl TypeCheck for Call {
impl TypeCheck for Expr { impl TypeCheck for Expr {
fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> { fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
match self { match self {
Expr::Identifier { name, typ } => { Expr::Identifier { name, typ, .. } => {
if let Some(ty) = ctx.variables.get(name) { if let Some(ty) = ctx.variables.get(name) {
*typ = ty.clone(); *typ = ty.clone();
Ok(typ.clone()) Ok(typ.clone())
} else { } else {
Err(TypeError::builder() Err(ctx.make_error(TypeErrorKind::UnknownIdentifier {
.context(ctx)
.kind(TypeErrorKind::UnknownIdentifier {
identifier: name.clone(), identifier: name.clone(),
}) }))
.build())
} }
} }
Expr::BooleanLiteral(_) => Ok(Type::Bool), Expr::BooleanLiteral(..) => Ok(Type::Bool),
Expr::IntegerLiteral(_) => Ok(Type::Int), Expr::IntegerLiteral(..) => Ok(Type::Int),
Expr::FloatLiteral(_) => Ok(Type::Float), Expr::FloatLiteral(..) => Ok(Type::Float),
Expr::UnaryExpression { op, inner } => { Expr::UnaryExpression { op, inner, .. } => {
let inner_type = &inner.typ(ctx)?; let inner_type = &inner.typ(ctx)?;
match (&op, inner_type) { match (&op, inner_type) {
(UnaryOperator::Not, Type::Bool) => Ok(Type::Bool), (UnaryOperator::Not, Type::Bool) => Ok(Type::Bool),
_ => Err(TypeError::builder() _ => Err(ctx.make_error(TypeErrorKind::InvalidUnaryOperator {
.context(ctx)
.kind(TypeErrorKind::InvalidUnaryOperator {
operator: *op, operator: *op,
inner: inner_type.clone(), inner: inner_type.clone(),
}) })),
.build()),
} }
} }
Expr::BinaryExpression { lhs, op, rhs, typ } => { Expr::BinaryExpression(BinaryExpression {
lhs,
op,
rhs,
typ,
op_span,
}) => {
let operator = error::BinOpAndSpan {
op: op.clone(),
span: *op_span,
};
let ty = match op { let ty = match op {
BinaryOperator::Add BinaryOperator::Add
| BinaryOperator::Sub | BinaryOperator::Sub
@ -368,32 +343,39 @@ impl TypeCheck for Expr {
| BinaryOperator::Or => { | BinaryOperator::Or => {
let left_type = &lhs.typ(ctx)?; let left_type = &lhs.typ(ctx)?;
let right_type = &rhs.typ(ctx)?; let right_type = &rhs.typ(ctx)?;
match (left_type, right_type) { match (left_type, right_type) {
(Type::Int, Type::Int) => Ok(Type::Int), (Type::Int, Type::Int) => Ok(Type::Int),
(Type::Float, Type::Float) => Ok(Type::Float), (Type::Float, Type::Float) => Ok(Type::Float),
(Type::Bool, Type::Bool) => Ok(Type::Bool), (Type::Bool, Type::Bool) => Ok(Type::Bool),
(_, _) => Err(TypeError::builder() (_, _) => Err(ctx.make_error(TypeErrorKind::InvalidBinaryOperator {
.context(ctx) operator,
.kind(TypeErrorKind::InvalidBinaryOperator { lhs: TypeAndSpan {
operator: op.clone(), ty: left_type.clone(),
lht: left_type.clone(), span: lhs.span,
rht: right_type.clone(), },
}) rhs: TypeAndSpan {
.build()), ty: right_type.clone(),
span: rhs.span,
},
})),
} }
} }
BinaryOperator::Equal | BinaryOperator::NotEqual => { BinaryOperator::Equal | BinaryOperator::NotEqual => {
let lhs_type = lhs.typ(ctx)?; let lhs_type = lhs.typ(ctx)?;
let rhs_type = rhs.typ(ctx)?; let rhs_type = rhs.typ(ctx)?;
if lhs_type != rhs_type { if lhs_type != rhs_type {
return Err(TypeError::builder() return Err(ctx.make_error(TypeErrorKind::InvalidBinaryOperator {
.context(ctx) operator,
.kind(TypeErrorKind::InvalidBinaryOperator { lhs: TypeAndSpan {
operator: op.clone(), ty: lhs_type.clone(),
lht: lhs_type.clone(), span: lhs.span,
rht: rhs_type.clone(), },
}) rhs: TypeAndSpan {
.build()); ty: rhs_type.clone(),
span: rhs.span,
},
}));
} }
Ok(Type::Bool) Ok(Type::Bool)
} }
@ -402,14 +384,17 @@ impl TypeCheck for Expr {
let rhs_type = lhs.typ(ctx)?; let rhs_type = lhs.typ(ctx)?;
match (&lhs_type, &rhs_type) { match (&lhs_type, &rhs_type) {
(Type::Int, Type::Int) => Ok(Type::Int), (Type::Int, Type::Int) => Ok(Type::Int),
_ => Err(TypeError::builder() _ => Err(ctx.make_error(TypeErrorKind::InvalidBinaryOperator {
.context(ctx) operator,
.kind(TypeErrorKind::InvalidBinaryOperator { lhs: TypeAndSpan {
operator: op.clone(), ty: lhs_type.clone(),
lht: lhs_type.clone(), span: lhs.span,
rht: rhs_type.clone(), },
}) rhs: TypeAndSpan {
.build()), ty: rhs_type.clone(),
span: rhs.span,
},
})),
} }
} }
}; };
@ -427,18 +412,12 @@ impl TypeCheck for Expr {
typ, typ,
} => { } => {
if cond.typ(ctx)? != Type::Bool { if cond.typ(ctx)? != Type::Bool {
Err(TypeError::builder() Err(ctx.make_error(TypeErrorKind::ConditionIsNotBool))
.context(ctx)
.kind(TypeErrorKind::ConditionIsNotBool)
.build())
} else { } else {
let then_body_type = then_body.typ(ctx)?; let then_body_type = then_body.typ(ctx)?;
let else_type = else_body.typ(ctx)?; let else_type = else_body.typ(ctx)?;
if then_body_type != else_type { if then_body_type != else_type {
Err(TypeError::builder() Err(ctx.make_error(TypeErrorKind::IfElseMismatch))
.context(ctx)
.kind(TypeErrorKind::IfElseMismatch)
.build())
} else { } else {
// XXX: opt: return ref to avoid cloning // XXX: opt: return ref to avoid cloning
*typ = then_body_type.clone(); *typ = then_body_type.clone();
@ -449,3 +428,39 @@ impl TypeCheck for Expr {
} }
} }
} }
impl TypeCheck for ReturnStatement {
    /// Deduces the type of the returned value (`Type::Unit` for a bare
    /// `return`) and checks it against the return type of the signature
    /// registered for the function currently being checked.
    ///
    /// On mismatch, a `ReturnTypeDoesNotMatchFunctionType` error is produced
    /// carrying the spans of the returned expression (if any) and of the whole
    /// statement.
    fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
        let returned = match self.expr.as_mut() {
            Some(value) => value.typ(ctx)?,
            None => Type::Unit,
        };

        // Second field of the registered signature: the declared return type.
        let current_function = ctx.function.as_ref().unwrap();
        let expected = &ctx.functions.get(current_function).unwrap().1;
        if returned == *expected {
            return Ok(returned);
        }

        let return_expr = self.expr.as_ref().map(|value| TypeAndSpan {
            ty: returned.clone(),
            span: value.span,
        });
        let return_stmt = TypeAndSpan {
            ty: returned,
            span: self.span,
        };
        Err(
            ctx.make_error(TypeErrorKind::ReturnTypeDoesNotMatchFunctionType {
                return_expr,
                return_stmt,
            }),
        )
    }
}
/// Type-checking a spanned expression is type-checking the expression it wraps.
impl TypeCheck for SExpr {
/// Delegates to the inner `Expr`; the span is not consulted here.
#[inline]
fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
self.expr.typ(ctx)
}
}

View file

@ -1,33 +1,40 @@
use crate::{ use crate::{
ast::ModulePath, ast::ModulePath,
parsing::parse_as_module, parsing::{DefaultParser, Parser},
typing::{ typing::error::*,
error::{TypeError, TypeErrorKind}, typing::*,
BinaryOperator, Type,
},
}; };
#[cfg(test)]
use pretty_assertions::assert_eq;
#[test] #[test]
fn addition_int_and_float() { fn addition_int_and_float() {
let source = "fn add(a: int, b: float) int { a + b }"; let source = "fn add(a: int, b: float) int { a + b }";
let mut ast = parse_as_module(source, ModulePath::default()).unwrap(); let mut ast = DefaultParser::default()
.parse_as_module(source, ModulePath::default(), 0)
.unwrap();
let res = ast.type_check(); let res = ast.type_check();
assert!(res.is_err_and(|e| e.kind assert!(res.is_err_and(|errors| errors.len() == 1
== TypeErrorKind::InvalidBinaryOperator { && matches!(errors[0].kind, TypeErrorKind::InvalidBinaryOperator { .. })));
operator: BinaryOperator::Add,
lht: Type::Int,
rht: Type::Float
}));
} }
#[test] #[test]
fn return_int_instead_of_float() { fn return_int_instead_of_float() {
let source = "fn add(a: int, b: int) float { a + b }"; let source = "fn add(a: int, b: int) float { a + b }";
let mut ast = parse_as_module(source, ModulePath::default()).unwrap(); let mut ast = DefaultParser::default()
.parse_as_module(source, ModulePath::default(), 0)
.unwrap();
let res = ast.type_check(); let res = ast.type_check();
assert!(res.is_err_and(|e| e.kind assert_eq!(
== TypeErrorKind::BlockTypeDoesNotMatchFunctionType { res,
Err(vec![TypeError {
file: None,
module: ModulePath::default(),
function: Some("add".to_string()),
kind: TypeErrorKind::BlockTypeDoesNotMatchFunctionType {
block_type: Type::Int, block_type: Type::Int,
function_type: Type::Float }
})); }])
);
} }