add pretty diagnostics

This commit is contained in:
Romain Paquet 2024-07-03 19:59:12 +02:00
parent e157bf036a
commit f415c4abbe
12 changed files with 1037 additions and 603 deletions

View file

@ -5,10 +5,15 @@ edition = "2021"
[dependencies]
clap = { version = "4.5.7", features = ["derive"] }
cranelift = "0.108.1"
cranelift-jit = "0.108.1"
cranelift-module = "0.108.1"
cranelift-native = "0.108.1"
cranelift = "0.109.0"
cranelift-jit = "0.109.0"
cranelift-module = "0.109.0"
cranelift-native = "0.109.0"
lazy_static = "1.4.0"
pest = "2.7.4"
pest_derive = "2.7.4"
ariadne = "0.4.1"
anyhow = "1.0.86"
[dev-dependencies]
pretty_assertions = "1.4.0"

View file

@ -1,17 +1,27 @@
use crate::ast::*;
use crate::typing::Type;
#[derive(Debug, PartialEq)]
pub struct SExpr {
pub expr: Expr,
pub span: Span,
}
#[derive(Debug, PartialEq)]
pub struct BinaryExpression {
pub lhs: Box<SExpr>,
pub op: BinaryOperator,
pub op_span: Span,
pub rhs: Box<SExpr>,
pub typ: Type,
}
#[derive(Debug, PartialEq)]
pub enum Expr {
BinaryExpression {
lhs: Box<Expr>,
op: BinaryOperator,
rhs: Box<Expr>,
typ: Type,
},
BinaryExpression(BinaryExpression),
UnaryExpression {
op: UnaryOperator,
inner: Box<Expr>,
inner: Box<SExpr>,
},
Identifier {
name: String,
@ -21,9 +31,9 @@ pub enum Expr {
Block(Box<Block>),
/// Last field is either Expr::Block or Expr::IfExpr
IfExpr {
cond: Box<Expr>,
cond: Box<SExpr>,
then_body: Box<Block>,
else_body: Box<Expr>,
else_body: Box<SExpr>,
typ: Type,
},
// Literals
@ -45,22 +55,12 @@ impl Block {
impl Expr {
pub fn ty(&self) -> Type {
match self {
Expr::BinaryExpression {
lhs: _,
op: _,
rhs: _,
typ,
} => typ.clone(),
Expr::UnaryExpression { op: _, inner } => inner.ty(), // XXX: problems will arise here
Expr::Identifier { name: _, typ } => typ.clone(),
Expr::BinaryExpression(BinaryExpression { typ, .. }) => typ.clone(),
Expr::UnaryExpression { inner, .. } => inner.ty(), // XXX: problems will arise here
Expr::Identifier { typ, .. } => typ.clone(),
Expr::Call(call) => call.typ.clone(),
Expr::Block(block) => block.typ.clone(),
Expr::IfExpr {
cond: _,
then_body: _,
else_body: _,
typ,
} => typ.clone(),
Expr::IfExpr { typ, .. } => typ.clone(),
Expr::UnitLiteral => Type::Unit,
Expr::BooleanLiteral(_) => Type::Bool,
Expr::IntegerLiteral(_) => Type::Int,
@ -69,3 +69,10 @@ impl Expr {
}
}
}
impl SExpr {
#[inline]
pub fn ty(&self) -> Type {
self.expr.ty()
}
}

View file

@ -1,9 +1,11 @@
pub mod expr;
pub use expr::Expr;
pub use expr::{BinaryExpression, Expr, SExpr};
use crate::typing::Type;
use std::path::Path;
use ariadne;
use std::{fmt::Display, path::Path};
#[derive(Debug, PartialEq, Clone)]
pub enum BinaryOperator {
@ -20,6 +22,22 @@ pub enum BinaryOperator {
NotEqual,
}
impl Display for BinaryOperator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(match self {
BinaryOperator::And => "&&",
BinaryOperator::Or => "||",
BinaryOperator::Add => "+",
BinaryOperator::Sub => "-",
BinaryOperator::Mul => "*",
BinaryOperator::Div => "/",
BinaryOperator::Modulo => "%",
BinaryOperator::Equal => "==",
BinaryOperator::NotEqual => "!=",
})
}
}
#[derive(Debug, PartialEq, Copy, Clone)]
pub enum UnaryOperator {
Not,
@ -27,12 +45,32 @@ pub enum UnaryOperator {
pub type Identifier = String;
#[derive(Debug, PartialEq)]
pub struct Location {
pub line_col: (usize, usize),
pub type SourceId = u32;
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct Span {
pub source: SourceId,
pub start: usize,
pub end: usize,
}
#[derive(Debug, PartialEq, Clone, Default)]
impl ariadne::Span for Span {
type SourceId = SourceId;
fn source(&self) -> &Self::SourceId {
&self.source
}
fn start(&self) -> usize {
self.start
}
fn end(&self) -> usize {
self.end
}
}
#[derive(Debug, PartialEq, Clone, Default, Eq, Hash)]
pub struct ModulePath {
components: Vec<String>,
}
@ -65,7 +103,7 @@ impl From<&Path> for ModulePath {
.map(|component| match component {
std::path::Component::Normal(n) => {
if meta.is_file() {
n.to_str().unwrap().split(".").nth(0).unwrap().to_string()
n.to_str().unwrap().split('.').nth(0).unwrap().to_string()
} else if meta.is_dir() {
n.to_str().unwrap().to_string()
} else {
@ -91,22 +129,51 @@ impl From<&str> for ModulePath {
#[derive(Eq, PartialEq, Debug)]
pub struct Import(pub String);
#[derive(Debug, PartialEq)]
pub struct ReturnStatement {
pub expr: Option<SExpr>,
pub span: Span,
}
#[derive(Debug, PartialEq)]
pub enum Statement {
DeclareStatement(Identifier, Box<Expr>),
AssignStatement(Identifier, Box<Expr>),
ReturnStatement(Option<Expr>),
CallStatement(Box<Call>),
UseStatement(Box<Import>),
IfStatement(Box<Expr>, Box<Block>),
WhileStatement(Box<Expr>, Box<Block>),
DeclareStatement {
lhs: Identifier,
rhs: Box<SExpr>,
span: Span,
},
AssignStatement {
lhs: Identifier,
rhs: Box<SExpr>,
span: Span,
},
ReturnStatement(ReturnStatement),
CallStatement {
call: Box<Call>,
span: Span,
},
UseStatement {
import: Box<Import>,
span: Span,
},
IfStatement {
condition: Box<SExpr>,
then_block: Box<Block>,
span: Span,
},
WhileStatement {
condition: Box<SExpr>,
loop_block: Box<Block>,
span: Span,
},
}
#[derive(Debug, PartialEq)]
pub struct Block {
pub statements: Vec<Statement>,
pub value: Option<Expr>,
pub value: Option<SExpr>,
pub typ: Type,
pub span: Option<Span>,
}
impl Block {
@ -115,6 +182,7 @@ impl Block {
typ: Type::Unit,
statements: Vec::with_capacity(0),
value: None,
span: None,
}
}
}
@ -129,8 +197,9 @@ pub struct FunctionDefinition {
pub name: Identifier,
pub parameters: Vec<Parameter>,
pub return_type: Option<Type>,
pub return_type_span: Option<Span>,
pub body: Box<Block>,
pub location: Location,
pub span: Span,
}
#[derive(Debug, PartialEq, Default)]
@ -159,8 +228,8 @@ impl Module {
#[derive(Debug, PartialEq)]
pub struct Call {
pub callee: Box<Expr>,
pub args: Vec<Expr>,
pub callee: Box<SExpr>,
pub args: Vec<SExpr>,
pub typ: Type,
}

View file

@ -1,15 +1,17 @@
use crate::{
ast::{
self, BinaryOperator, ModulePath, UnaryOperator,
{expr::Expr, FunctionDefinition, Statement},
self, expr::BinaryExpression, BinaryOperator, Expr, FunctionDefinition, ModulePath,
ReturnStatement, SourceId, Statement, UnaryOperator,
},
parsing,
typing::{CheckedModule, Type},
parsing::{DefaultParser, Parser},
typing::Type,
SourceCache,
};
use ariadne::Cache as _;
use cranelift::{codegen::ir::UserFuncName, prelude::*};
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{DataDescription, FuncId, FuncOrDataId, Linkage, Module};
use std::{collections::HashMap, fs, ops::Deref};
use std::collections::HashMap;
/// The basic JIT class.
pub struct JIT {
@ -30,6 +32,9 @@ pub struct JIT {
/// Whether to print CLIR during compilation
pub dump_clir: bool,
/// Parser used to build the AST
pub parser: DefaultParser,
}
impl Default for JIT {
@ -59,18 +64,34 @@ impl Default for JIT {
data_desc: DataDescription::new(),
module,
dump_clir: false,
parser: DefaultParser::default(),
}
}
}
impl JIT {
/// Compile source code into machine code.
pub fn compile(&mut self, input: &str, namespace: ModulePath) -> Result<*const u8, String> {
pub fn compile(
&mut self,
input: &str,
namespace: ModulePath,
id: SourceId,
) -> Result<*const u8, String> {
let mut source_cache = (0u32, ariadne::Source::from(input));
// Parse the source code into an AST
let ast = parsing::parse_as_module(input, namespace)
.map_err(|x| format!("Parsing error: {x}"))?
.type_check()
.map_err(|x| format!("Typing error: {x}"))?;
let mut ast = self
.parser
.parse_as_module(input, namespace, id)
.map_err(|x| format!("Parsing error: {x}"))?;
ast.type_check()
.map_err(|errors| {
errors
.iter()
.for_each(|e| e.to_report(&ast).eprint(&mut source_cache).unwrap());
})
.unwrap();
// Translate the AST into Cranelift IR
self.translate(&ast)?;
@ -89,20 +110,24 @@ impl JIT {
}
}
pub fn compile_file(&mut self, path: &str) -> Result<*const u8, String> {
pub fn compile_file(
&mut self,
path: &str,
id: SourceId,
source_cache: &mut SourceCache,
) -> Result<*const u8, String> {
self.compile(
fs::read_to_string(path)
.map_err(|x| format!("Cannot open {}: {}", path, x))?
.as_str(),
source_cache
.fetch(&id)
.map(|s| s.text())
.map_err(|e| format!("{:?}", e))?,
AsRef::<std::path::Path>::as_ref(path).into(),
id,
)
}
/// Translate language AST into Cranelift IR.
fn translate(&mut self, ast: &CheckedModule) -> Result<(), String> {
// Dump contract-holding wrapper type
let ast = &ast.0;
fn translate(&mut self, ast: &ast::Module) -> Result<(), String> {
let mut signatures: Vec<Signature> = Vec::with_capacity(ast.functions.len());
let mut func_ids: Vec<FuncId> = Vec::with_capacity(ast.functions.len());
@ -199,7 +224,7 @@ impl JIT {
// Emit the final return instruction.
if let Some(return_expr) = &function.body.value {
let return_value = translator.translate_expr(&return_expr);
let return_value = translator.translate_expr(&return_expr.expr);
translator.builder.ins().return_(&[return_value]);
} else {
translator.builder.ins().return_(&[]);
@ -234,18 +259,26 @@ struct FunctionTranslator<'a> {
impl<'a> FunctionTranslator<'a> {
fn translate_statement(&mut self, stmt: &Statement) -> Option<Value> {
match stmt {
Statement::AssignStatement(name, expr) => {
Statement::AssignStatement {
lhs: name,
rhs: expr,
..
} => {
// `def_var` is used to write the value of a variable. Note that
// variables can have multiple definitions. Cranelift will
// convert them into SSA form for itself automatically.
let new_value = self.translate_expr(expr);
let new_value = self.translate_expr(&expr.expr);
let variable = self.variables.get(name).unwrap();
self.builder.def_var(*variable, new_value);
Some(new_value)
}
Statement::DeclareStatement(name, expr) => {
let value = self.translate_expr(expr);
Statement::DeclareStatement {
lhs: name,
rhs: expr,
..
} => {
let value = self.translate_expr(&expr.expr);
let variable = Variable::from_u32(self.variables.len() as u32);
self.builder
.declare_var(variable, self.translate_type(&expr.ty()));
@ -254,10 +287,12 @@ impl<'a> FunctionTranslator<'a> {
Some(value)
}
Statement::ReturnStatement(maybe_expr) => {
Statement::ReturnStatement(ReturnStatement {
expr: maybe_expr, ..
}) => {
// TODO: investigate tail call
let values = if let Some(expr) = maybe_expr {
vec![self.translate_expr(expr)]
vec![self.translate_expr(&expr.expr)]
} else {
// XXX: urgh
Vec::with_capacity(0)
@ -269,12 +304,16 @@ impl<'a> FunctionTranslator<'a> {
None
}
Statement::CallStatement(call) => self.translate_call(call),
Statement::CallStatement { call, .. } => self.translate_call(call),
Statement::UseStatement(_) => todo!(),
Statement::UseStatement { .. } => todo!(),
Statement::IfStatement(cond, then_body) => {
let condition_value = self.translate_expr(cond);
Statement::IfStatement {
condition: cond,
then_block: then_body,
..
} => {
let condition_value = self.translate_expr(&cond.expr);
let then_block = self.builder.create_block();
let merge_block = self.builder.create_block();
@ -294,7 +333,7 @@ impl<'a> FunctionTranslator<'a> {
None
}
Statement::WhileStatement(_, _) => todo!(),
Statement::WhileStatement { .. } => todo!(),
}
}
@ -302,11 +341,11 @@ impl<'a> FunctionTranslator<'a> {
match expr {
Expr::UnitLiteral => unreachable!(),
Expr::BooleanLiteral(imm) => self.builder.ins().iconst(types::I8, i64::from(*imm)),
Expr::IntegerLiteral(imm) => self.builder.ins().iconst(types::I32, *imm),
Expr::FloatLiteral(imm) => self.builder.ins().f64const(*imm),
Expr::BooleanLiteral(imm, ..) => self.builder.ins().iconst(types::I8, i64::from(*imm)),
Expr::IntegerLiteral(imm, ..) => self.builder.ins().iconst(types::I32, *imm),
Expr::FloatLiteral(imm, ..) => self.builder.ins().f64const(*imm),
Expr::StringLiteral(s) => {
Expr::StringLiteral(s, ..) => {
let id = self.module.declare_anonymous_data(false, false).unwrap();
let bytes: Box<[u8]> = s.as_bytes().into();
self.data_desc.define(bytes);
@ -318,14 +357,9 @@ impl<'a> FunctionTranslator<'a> {
.global_value(self.module.isa().pointer_type(), gv)
}
Expr::BinaryExpression {
lhs,
op,
rhs,
typ: _,
} => {
let lhs_value = self.translate_expr(lhs);
let rhs_value = self.translate_expr(rhs);
Expr::BinaryExpression(BinaryExpression { lhs, rhs, op, .. }) => {
let lhs_value = self.translate_expr(&lhs.expr);
let rhs_value = self.translate_expr(&rhs.expr);
match (lhs.ty(), lhs.ty()) {
(Type::Int, Type::Int) => match op {
@ -361,8 +395,9 @@ impl<'a> FunctionTranslator<'a> {
then_body,
else_body,
typ,
..
} => {
let condition_value = self.translate_expr(cond);
let condition_value = self.translate_expr(&cond.expr);
let then_block = self.builder.create_block();
let else_block = self.builder.create_block();
@ -384,11 +419,11 @@ impl<'a> FunctionTranslator<'a> {
self.builder.switch_to_block(then_block);
self.builder.seal_block(then_block);
for stmt in &then_body.statements {
self.translate_statement(&stmt);
self.translate_statement(stmt);
}
let then_return_value = match &then_body.value {
Some(val) => vec![self.translate_expr(val)],
Some(val) => vec![self.translate_expr(&val.expr)],
None => Vec::with_capacity(0),
};
@ -399,9 +434,9 @@ impl<'a> FunctionTranslator<'a> {
self.builder.seal_block(else_block);
// XXX: the else can be just an expression: do we always need to
// make a second branch in that case? Or leave it to cranelift?
let else_return_value = match **else_body {
let else_return_value = match else_body.expr {
Expr::UnitLiteral => Vec::with_capacity(0),
_ => vec![self.translate_expr(else_body)],
_ => vec![self.translate_expr(&else_body.expr)],
};
// Jump to the merge block, passing it the block return value.
@ -420,8 +455,8 @@ impl<'a> FunctionTranslator<'a> {
phi
}
Expr::UnaryExpression { op, inner } => {
let inner_value = self.translate_expr(inner);
Expr::UnaryExpression { op, inner, .. } => {
let inner_value = self.translate_expr(&inner.expr);
match op {
// XXX: This should not be a literal translation
UnaryOperator::Not => {
@ -431,13 +466,13 @@ impl<'a> FunctionTranslator<'a> {
}
}
Expr::Identifier { name, typ: _ } => {
Expr::Identifier { name, .. } => {
self.builder.use_var(*self.variables.get(name).unwrap())
}
Expr::Call(call) => self.translate_call(call).unwrap(),
Expr::Call(call, ..) => self.translate_call(call).unwrap(),
Expr::Block(block) => self.translate_block(block).unwrap(),
Expr::Block(block, ..) => self.translate_block(block).unwrap(),
}
}
@ -445,16 +480,16 @@ impl<'a> FunctionTranslator<'a> {
for stmt in &block.statements {
self.translate_statement(stmt);
}
if let Some(block_value) = &block.value {
Some(self.translate_expr(block_value))
} else {
None
}
block
.value
.as_ref()
.map(|block_value| self.translate_expr(&block_value.expr))
}
fn translate_call(&mut self, call: &ast::Call) -> Option<Value> {
match call.callee.deref() {
Expr::Identifier { name, typ: _ } => {
match &call.callee.expr {
Expr::Identifier { name, .. } => {
let func_ref = if let Some(func_or_data_id) = self.module.get_name(name.as_ref()) {
if let FuncOrDataId::Func(func_id) = func_or_data_id {
self.module.declare_func_in_func(func_id, self.builder.func)
@ -465,7 +500,11 @@ impl<'a> FunctionTranslator<'a> {
todo!()
};
let args: Vec<Value> = call.args.iter().map(|a| self.translate_expr(a)).collect();
let args: Vec<Value> = call
.args
.iter()
.map(|a| self.translate_expr(&a.expr))
.collect();
let call_inst = self.builder.ins().call(func_ref, &args);
let results = self.builder.inst_results(call_inst);

View file

@ -1,14 +1,20 @@
pub mod ast;
pub mod jit;
pub mod parsing;
pub mod source;
pub mod typing;
use clap::{Parser, Subcommand};
use std::default::Default;
use std::path::PathBuf;
use clap::{Parser as ClapParser, Subcommand};
use crate::ast::Module;
use crate::parsing::Parser;
use crate::source::SourceCache;
/// Experimental compiler for lila
#[derive(Parser, Debug)]
#[derive(ClapParser, Debug)]
#[command(author = "Romain P. <rpqt@rpqt.fr>")]
#[command(version, about, long_about = None)]
struct Cli {
@ -51,21 +57,27 @@ enum Commands {
},
}
fn parse(files: &Vec<String>) -> Vec<Module> {
fn parse(files: &[String]) -> Vec<Module> {
let mut parser = parsing::DefaultParser::default();
let paths = files.iter().map(std::path::Path::new);
paths
.map(|path| match parsing::parse_file(&path) {
.enumerate()
.map(|(i, path)| match parser.parse_file(path, i as u32) {
Ok(module) => module,
Err(e) => panic!("Parsing error: {:#?}", e),
})
.collect()
}
fn check(modules: &mut Vec<Module>) {
while let Some(module) = modules.pop() {
if let Err(e) = module.type_check() {
eprintln!("{}", e);
return;
fn check(modules: &mut Vec<Module>, source_cache: &mut SourceCache) {
for module in modules {
if let Err(errors) = module.type_check() {
for error in errors {
error
.to_report(module)
.eprint(&mut *source_cache)
.expect("cannot write error to stderr");
}
}
}
}
@ -83,20 +95,32 @@ fn main() {
}
println!("Parsing OK");
}
Commands::TypeCheck { files, dump_ast } => {
let mut source_cache = SourceCache {
paths: files.iter().map(PathBuf::from).collect(),
file_cache: ariadne::FileCache::default(),
};
let mut modules = parse(files);
check(&mut modules);
check(&mut modules, &mut source_cache);
if *dump_ast {
for module in &modules {
println!("{:#?}", &module);
}
}
}
Commands::Compile { files, dump_clir } | Commands::Run { files, dump_clir } => {
let mut source_cache = SourceCache {
paths: files.iter().map(PathBuf::from).collect(),
file_cache: ariadne::FileCache::default(),
};
let mut jit = jit::JIT::default();
jit.dump_clir = *dump_clir;
for file in files {
match jit.compile_file(file) {
for (id, file) in files.iter().enumerate() {
match jit.compile_file(file, id as u32, &mut source_cache) {
Err(e) => eprintln!("{}", e),
Ok(code) => {
println!("Compiled {}", file);

View file

@ -1,14 +1,13 @@
use std::fs;
use std::path::Path;
use pest::error::Error;
use expr::BinaryExpression;
use pest::iterators::Pair;
use pest::pratt_parser::PrattParser;
use pest::Parser;
use pest::Parser as PestParser;
use ReturnStatement;
use crate::ast::Module;
use crate::ast::*;
use crate::ast::{Import, ModulePath};
use crate::typing::Type;
#[derive(pest_derive::Parser)]
@ -33,24 +32,38 @@ lazy_static::lazy_static! {
};
}
pub fn parse_file(path: &Path) -> Result<Module, Error<Rule>> {
let source = fs::read_to_string(&path).expect("could not read source file");
#[derive(Default)]
pub struct Parser {
source: SourceId,
}
impl crate::parsing::Parser for Parser {
fn parse_file(&mut self, path: &Path, id: SourceId) -> anyhow::Result<Module> {
let source = fs::read_to_string(path)?;
let module_path = ModulePath::from(path);
let mut module = parse_as_module(&source, module_path)?;
let mut module = self.parse_as_module(&source, module_path, id)?;
module.file = Some(path.to_owned());
Ok(module)
}
pub fn parse_as_module(source: &str, path: ModulePath) -> Result<Module, Error<Rule>> {
let mut pairs = LilaParser::parse(Rule::source_file, &source)?;
fn parse_as_module(
&mut self,
source: &str,
path: ModulePath,
id: SourceId,
) -> anyhow::Result<Module> {
self.source = id;
let mut pairs = LilaParser::parse(Rule::source_file, source)?;
assert!(pairs.len() == 1);
let module = parse_module(pairs.next().unwrap().into_inner().next().unwrap(), path);
let module = self.parse_module(pairs.next().unwrap().into_inner().next().unwrap(), path);
Ok(module)
}
}
pub fn parse_module(pair: Pair<Rule>, path: ModulePath) -> Module {
impl Parser {
fn parse_module(&self, pair: Pair<Rule>, path: ModulePath) -> Module {
assert!(pair.as_rule() == Rule::module_items);
let mut module = Module::new(path);
@ -59,13 +72,13 @@ pub fn parse_module(pair: Pair<Rule>, path: ModulePath) -> Module {
for pair in pairs {
match pair.as_rule() {
Rule::definition => {
let def = parse_definition(pair.into_inner().next().unwrap());
let def = self.parse_definition(pair.into_inner().next().unwrap());
match def {
Definition::FunctionDefinition(func) => module.functions.push(func),
}
}
Rule::use_statement => {
let path = parse_import(pair.into_inner().next().unwrap());
let path = self.parse_import(pair.into_inner().next().unwrap());
module.imports.push(path);
}
_ => panic!("unexpected rule in source_file: {:?}", pair.as_rule()),
@ -75,14 +88,15 @@ pub fn parse_module(pair: Pair<Rule>, path: ModulePath) -> Module {
module
}
fn parse_block(pair: Pair<Rule>) -> Block {
fn parse_block(&self, pair: Pair<Rule>) -> Block {
let mut statements = vec![];
let mut value = None;
let span = self.make_span(&pair);
for pair in pair.into_inner() {
match pair.as_rule() {
Rule::statement => statements.push(parse_statement(pair)),
Rule::expr => value = Some(parse_expression(pair)),
Rule::statement => statements.push(self.parse_statement(pair)),
Rule::expr => value = Some(self.parse_expression(pair)),
_ => panic!("unexpected rule {:?} in block", pair.as_rule()),
}
}
@ -91,75 +105,104 @@ fn parse_block(pair: Pair<Rule>) -> Block {
statements,
value,
typ: Type::Undefined,
span: Some(span),
}
}
fn parse_statement(pair: Pair<Rule>) -> Statement {
fn parse_statement(&self, pair: Pair<Rule>) -> Statement {
let pair = pair.into_inner().next().unwrap();
let span = self.make_span(&pair);
match pair.as_rule() {
Rule::assign_statement => {
let mut pairs = pair.into_inner();
let identifier = pairs.next().unwrap().as_str().to_string();
let expr = parse_expression(pairs.next().unwrap());
Statement::AssignStatement(identifier, Box::new(expr))
let expr = self.parse_expression(pairs.next().unwrap());
Statement::AssignStatement {
lhs: identifier,
rhs: Box::new(expr),
span,
}
}
Rule::declare_statement => {
let mut pairs = pair.into_inner();
let identifier = pairs.next().unwrap().as_str().to_string();
let expr = parse_expression(pairs.next().unwrap());
Statement::DeclareStatement(identifier, Box::new(expr))
let expr = self.parse_expression(pairs.next().unwrap());
Statement::DeclareStatement {
lhs: identifier,
rhs: Box::new(expr),
span,
}
}
Rule::return_statement => {
let expr = if let Some(pair) = pair.into_inner().next() {
Some(parse_expression(pair))
} else {
None
};
Statement::ReturnStatement(expr)
let expr = pair
.into_inner()
.next()
.map(|expr| self.parse_expression(expr));
Statement::ReturnStatement(ReturnStatement { expr, span })
}
Rule::call_statement => {
let call = parse_call(pair.into_inner().next().unwrap());
Statement::CallStatement(Box::new(call))
let call = self.parse_call(pair.into_inner().next().unwrap());
Statement::CallStatement {
call: Box::new(call),
span,
}
}
Rule::use_statement => {
let import = parse_import(pair.into_inner().next().unwrap());
Statement::UseStatement(Box::new(import))
let import = self.parse_import(pair.into_inner().next().unwrap());
Statement::UseStatement {
import: Box::new(import),
span,
}
}
Rule::if_statement => {
let mut pairs = pair.into_inner();
let condition = parse_expression(pairs.next().unwrap());
let block = parse_block(pairs.next().unwrap());
let condition = self.parse_expression(pairs.next().unwrap());
let block = self.parse_block(pairs.next().unwrap());
if pairs.next().is_some() {
todo!("implement if-statements with else branch (and else if)")
}
Statement::IfStatement(Box::new(condition), Box::new(block))
Statement::IfStatement {
condition: Box::new(condition),
then_block: Box::new(block),
span,
}
}
Rule::while_statement => {
let mut pairs = pair.into_inner();
let condition = parse_expression(pairs.next().unwrap());
let block = parse_block(pairs.next().unwrap());
Statement::WhileStatement(Box::new(condition), Box::new(block))
let condition = self.parse_expression(pairs.next().unwrap());
let block = self.parse_block(pairs.next().unwrap());
Statement::WhileStatement {
condition: Box::new(condition),
loop_block: Box::new(block),
span,
}
}
_ => unreachable!("unexpected rule '{:?}' in parse_statement", pair.as_rule()),
}
}
fn parse_import(pair: Pair<Rule>) -> Import {
fn parse_import(&self, pair: Pair<Rule>) -> Import {
Import(pair.as_str().to_string())
}
fn parse_call(pair: Pair<Rule>) -> Call {
fn parse_call(&self, pair: Pair<Rule>) -> Call {
let mut pairs = pair.into_inner();
// TODO: support calls on more than identifiers (needs grammar change)
let callee = Expr::Identifier {
name: pairs.next().unwrap().as_str().to_string(),
let pair = pairs.next().unwrap();
let callee = SExpr {
expr: Expr::Identifier {
name: pair.as_str().to_string(),
typ: Type::Undefined,
},
span: self.make_span(&pair),
};
let args: Vec<Expr> = pairs
let args: Vec<SExpr> = pairs
.next()
.unwrap()
.into_inner()
.map(parse_expression)
.map(|arg| self.parse_expression(arg))
.collect();
Call {
@ -169,13 +212,25 @@ fn parse_call(pair: Pair<Rule>) -> Call {
}
}
fn parse_expression(pair: Pair<Rule>) -> Expr {
fn parse_expression(&self, pair: Pair<Rule>) -> SExpr {
let span = self.make_span(&pair);
let pairs = pair.into_inner();
PRATT_PARSER
.map_primary(|primary| match primary.as_rule() {
Rule::integer_literal => Expr::IntegerLiteral(primary.as_str().parse().unwrap()),
Rule::float_literal => Expr::FloatLiteral(primary.as_str().parse().unwrap()),
Rule::string_literal => Expr::StringLiteral(
let mut map = PRATT_PARSER
.map_primary(|primary| {
let span = self.make_span(&primary);
match primary.as_rule() {
Rule::integer_literal => SExpr {
expr: Expr::IntegerLiteral(primary.as_str().parse().unwrap()),
span,
},
Rule::float_literal => SExpr {
expr: Expr::FloatLiteral(primary.as_str().parse().unwrap()),
span,
},
Rule::string_literal => SExpr {
expr: Expr::StringLiteral(
primary
.into_inner()
.next()
@ -184,34 +239,59 @@ fn parse_expression(pair: Pair<Rule>) -> Expr {
.parse()
.unwrap(),
),
Rule::expr => parse_expression(primary),
Rule::ident => Expr::Identifier {
span,
},
Rule::expr => self.parse_expression(primary),
Rule::ident => SExpr {
expr: Expr::Identifier {
name: primary.as_str().to_string(),
typ: Type::Undefined,
},
Rule::call => Expr::Call(Box::new(parse_call(primary))),
Rule::block => Expr::Block(Box::new(parse_block(primary))),
span,
},
Rule::call => SExpr {
expr: Expr::Call(Box::new(self.parse_call(primary))),
span,
},
Rule::block => SExpr {
expr: Expr::Block(Box::new(self.parse_block(primary))),
span,
},
Rule::if_expr => {
let mut pairs = primary.into_inner();
let condition = parse_expression(pairs.next().unwrap());
let true_block = parse_block(pairs.next().unwrap());
let else_value = parse_expression(pairs.next().unwrap());
Expr::IfExpr {
let condition = self.parse_expression(pairs.next().unwrap());
let true_block = self.parse_block(pairs.next().unwrap());
let else_value = self.parse_expression(pairs.next().unwrap());
SExpr {
expr: Expr::IfExpr {
cond: Box::new(condition),
then_body: Box::new(true_block),
else_body: Box::new(else_value),
typ: Type::Undefined,
},
span,
}
}
Rule::boolean_literal => Expr::BooleanLiteral(match primary.as_str() {
Rule::boolean_literal => SExpr {
expr: Expr::BooleanLiteral(match primary.as_str() {
"true" => true,
"false" => false,
_ => unreachable!(),
}),
span,
},
_ => unreachable!(
"Unexpected rule '{:?}' in primary expression",
primary.as_rule()
),
}
})
.map_infix(|lhs, op, rhs| {
let operator = match op.as_rule() {
@ -226,27 +306,30 @@ fn parse_expression(pair: Pair<Rule>) -> Expr {
Rule::or => BinaryOperator::Or,
_ => unreachable!(),
};
Expr::BinaryExpression {
let expr = Expr::BinaryExpression(BinaryExpression {
lhs: Box::new(lhs),
op: operator,
op_span: self.make_span(&op),
rhs: Box::new(rhs),
typ: Type::Undefined,
}
});
SExpr { expr, span }
})
.map_prefix(|op, inner| {
let operator = match op.as_rule() {
Rule::not => UnaryOperator::Not,
_ => unreachable!(),
};
Expr::UnaryExpression {
let expr = Expr::UnaryExpression {
op: operator,
inner: Box::new(inner),
}
})
.parse(pairs)
};
SExpr { expr, span }
});
map.parse(pairs)
}
fn parse_parameter(pair: Pair<Rule>) -> Parameter {
fn parse_parameter(&self, pair: Pair<Rule>) -> Parameter {
assert!(pair.as_rule() == Rule::parameter);
let mut pair = pair.into_inner();
let name = pair.next().unwrap().as_str().to_string();
@ -254,38 +337,53 @@ fn parse_parameter(pair: Pair<Rule>) -> Parameter {
Parameter { name, typ }
}
fn parse_definition(pair: Pair<Rule>) -> Definition {
fn parse_definition(&self, pair: Pair<Rule>) -> Definition {
match pair.as_rule() {
Rule::func_def => {
let line_col = pair.line_col();
let span = self.make_span(&pair);
let mut pairs = pair.into_inner();
let name = pairs.next().unwrap().as_str().to_string();
let parameters: Vec<Parameter> = pairs
.next()
.unwrap()
.into_inner()
.map(parse_parameter)
.map(|param| self.parse_parameter(param))
.collect();
let pair = pairs.next().unwrap();
// Before the block there is an optional return type
let (return_type, pair) = match pair.as_rule() {
Rule::ident => (Some(Type::from(pair.as_str())), pairs.next().unwrap()),
Rule::block => (None, pair),
let (return_type, return_type_span, pair) = match pair.as_rule() {
Rule::ident => (
Some(Type::from(pair.as_str())),
Some(self.make_span(&pair)),
pairs.next().unwrap(),
),
Rule::block => (None, None, pair),
_ => unreachable!(
"Unexpected rule '{:?}' in function definition, expected return type or block",
pair.as_rule()
),
};
let body = parse_block(pair);
let body = self.parse_block(pair);
let body = Box::new(body);
Definition::FunctionDefinition(FunctionDefinition {
name,
parameters,
return_type,
return_type_span,
span,
body,
location: Location { line_col },
})
}
_ => panic!("unexpected node for definition: {:?}", pair.as_rule()),
}
}
fn make_span(&self, pair: &Pair<Rule>) -> Span {
let span = pair.as_span();
Span {
source: self.source,
start: span.start(),
end: span.end(),
}
}
}

View file

@ -1,4 +1,18 @@
mod backend;
mod tests;
pub use self::backend::pest::{parse_file, parse_module, parse_as_module};
use crate::ast::{Module, ModulePath, SourceId};
pub trait Parser: Default {
fn parse_file(&mut self, path: &std::path::Path, id: SourceId) -> anyhow::Result<Module>;
fn parse_as_module(
&mut self,
source: &str,
path: ModulePath,
id: SourceId,
) -> anyhow::Result<Module>;
}
pub use self::backend::pest::Parser as PestParser;
pub use PestParser as DefaultParser;

View file

@ -1,12 +1,17 @@
#[cfg(test)]
use pretty_assertions::assert_eq;
#[test]
fn test_addition_function() {
use crate::ast::{expr::Expr, *};
use crate::parsing::backend::pest::parse_as_module;
use crate::typing::Type;
use crate::ast::*;
use crate::parsing::*;
use crate::typing::*;
let source = "fn add(a: int, b: int) int { a + b }";
let path = ModulePath::from("test");
let module = parse_as_module(&source, path.clone()).expect("parsing error");
let module = DefaultParser::default()
.parse_as_module(source, path.clone(), 0)
.expect("parsing error");
let expected_module = Module {
file: None,
@ -26,21 +31,61 @@ fn test_addition_function() {
return_type: Some(Type::Int),
body: Box::new(Block {
statements: vec![],
value: Some(Expr::BinaryExpression {
lhs: Box::new(Expr::Identifier {
value: Some(SExpr {
expr: Expr::BinaryExpression(BinaryExpression {
lhs: Box::new(SExpr {
expr: Expr::Identifier {
name: Identifier::from("a"),
typ: Type::Undefined,
},
span: Span {
source: 0,
start: 29,
end: 30,
},
}),
op: BinaryOperator::Add,
rhs: Box::new(Expr::Identifier {
op_span: Span {
source: 0,
start: 31,
end: 32,
},
rhs: Box::new(SExpr {
expr: Expr::Identifier {
name: Identifier::from("b"),
typ: Type::Undefined,
},
span: Span {
source: 0,
start: 33,
end: 34,
},
}),
typ: Type::Undefined,
}),
typ: Type::Undefined,
span: Span {
source: 0,
start: 29,
end: 34,
},
}),
typ: Type::Undefined,
span: Some(Span {
source: 0,
start: 27,
end: source.len(),
}),
}),
span: Span {
source: 0,
start: 0,
end: source.len(),
},
return_type_span: Some(Span {
source: 0,
start: 23,
end: 26,
}),
location: Location { line_col: (1, 1) },
}],
path,
};

22
src/source.rs Normal file
View file

@ -0,0 +1,22 @@
use crate::ast::SourceId;
use ariadne::FileCache;
pub struct SourceCache {
pub paths: Vec<std::path::PathBuf>,
pub file_cache: FileCache,
}
impl ariadne::Cache<SourceId> for SourceCache {
type Storage = String;
fn fetch(
&mut self,
id: &SourceId,
) -> Result<&ariadne::Source<Self::Storage>, Box<dyn std::fmt::Debug + '_>> {
self.file_cache.fetch(&self.paths[*id as usize])
}
fn display<'a>(&self, id: &'a SourceId) -> Option<Box<dyn std::fmt::Display + 'a>> {
Some(Box::new(format!("{}", self.paths[*id as usize].display())))
}
}

View file

@ -1,8 +1,10 @@
use crate::typing::{BinaryOperator, Identifier, ModulePath, Type, TypingContext};
use ariadne::{ColorGenerator, Fmt, Label, Report, ReportKind, Span as _};
use std::fmt::Debug;
use super::UnaryOperator;
use super::{Span, UnaryOperator};
use crate::typing::{BinaryOperator, Identifier, ModulePath, Type};
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct TypeError {
pub file: Option<std::path::PathBuf>,
pub module: ModulePath,
@ -10,72 +12,31 @@ pub struct TypeError {
pub kind: TypeErrorKind,
}
impl std::fmt::Display for TypeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("Error\n")?;
if let Some(path) = &self.file {
f.write_fmt(format_args!(" in file {}\n", path.display()))?;
}
f.write_fmt(format_args!(" in module {}\n", self.module))?;
if let Some(name) = &self.function {
f.write_fmt(format_args!(" in function {}\n", name))?;
}
f.write_fmt(format_args!("{:#?}", self.kind))?;
Ok(())
}
#[derive(PartialEq, Debug)]
pub struct TypeAndSpan {
pub ty: Type,
pub span: Span,
}
#[derive(Default)]
pub struct TypeErrorBuilder {
file: Option<std::path::PathBuf>,
module: Option<ModulePath>,
function: Option<String>,
kind: Option<TypeErrorKind>,
}
impl TypeError {
pub fn builder() -> TypeErrorBuilder {
TypeErrorBuilder::default()
}
}
impl TypeErrorBuilder {
pub fn context(mut self, ctx: &TypingContext) -> Self {
self.file = ctx.file.clone();
self.module = Some(ctx.module.clone());
self.function = ctx.function.clone();
self
}
pub fn kind(mut self, kind: TypeErrorKind) -> Self {
self.kind = Some(kind);
self
}
pub fn build(self) -> TypeError {
TypeError {
file: self.file,
module: self.module.expect("TypeError builder is missing module"),
function: self.function,
kind: self.kind.expect("TypeError builder is missing kind"),
}
}
#[derive(PartialEq, Debug)]
pub struct BinOpAndSpan {
pub op: BinaryOperator,
pub span: Span,
}
#[derive(Debug, PartialEq)]
pub enum TypeErrorKind {
InvalidBinaryOperator {
operator: BinaryOperator,
lht: Type,
rht: Type,
operator: BinOpAndSpan,
lhs: TypeAndSpan,
rhs: TypeAndSpan,
},
BlockTypeDoesNotMatchFunctionType {
block_type: Type,
function_type: Type,
},
ReturnTypeDoesNotMatchFunctionType {
function_type: Type,
return_type: Type,
return_expr: Option<TypeAndSpan>,
return_stmt: TypeAndSpan,
},
UnknownIdentifier {
identifier: String,
@ -86,7 +47,6 @@ pub enum TypeErrorKind {
},
AssignUndeclared,
VariableRedeclaration,
ReturnStatementsMismatch,
UnknownFunctionCalled(Identifier),
WrongFunctionArguments,
ConditionIsNotBool,
@ -96,3 +56,132 @@ pub enum TypeErrorKind {
inner: Type,
},
}
impl TypeError {
pub fn to_report(&self, ast: &crate::ast::Module) -> Report<Span> {
let mut colors = ColorGenerator::new();
let c0 = colors.next();
let c1 = colors.next();
colors.next();
let c2 = colors.next();
match &self.kind {
TypeErrorKind::InvalidBinaryOperator { operator, lhs, rhs } => {
Report::build(ReportKind::Error, 0u32, 0)
.with_message(format!(
"Invalid binary operation {} between {} and {}",
operator.op.to_string().fg(c0),
lhs.ty.to_string().fg(c1),
rhs.ty.to_string().fg(c2),
))
.with_labels([
Label::new(operator.span).with_color(c0),
Label::new(lhs.span)
.with_message(format!("This has type {}", lhs.ty.to_string().fg(c1)))
.with_color(c1)
.with_order(2),
Label::new(rhs.span)
.with_message(format!("This has type {}", rhs.ty.to_string().fg(c2)))
.with_color(c2)
.with_order(1),
])
.finish()
}
TypeErrorKind::BlockTypeDoesNotMatchFunctionType { block_type } => {
let function = ast
.functions
.iter()
.find(|f| f.name == *self.function.as_ref().unwrap())
.unwrap();
let block_color = c0;
let signature_color = c1;
let span = function.body.value.as_ref().unwrap().span;
let report = Report::build(ReportKind::Error, 0u32, 0)
.with_message("Function body does not match the signature")
.with_labels([
Label::new(function.body.span.unwrap())
.with_message("In this function's body")
.with_color(c2),
Label::new(span)
.with_message(format!(
"Returned expression has type {} but the function should return {}",
block_type.to_string().fg(block_color),
function
.return_type
.as_ref()
.unwrap_or(&Type::Unit)
.to_string()
.fg(signature_color)
))
.with_color(block_color),
]);
let report =
report.with_note("The last expression of a function's body is returned");
let report = if function.return_type.is_none() {
report.with_help(
"You may need to add the return type to the function's signature",
)
} else {
report
};
report.finish()
}
TypeErrorKind::ReturnTypeDoesNotMatchFunctionType {
return_expr,
return_stmt,
} => {
let function = ast
.functions
.iter()
.find(|f| f.name == *self.function.as_ref().unwrap())
.unwrap();
let is_bare_return = return_expr.is_none();
let report = Report::build(ReportKind::Error, *return_stmt.span.source(), 0)
.with_message("Return type does not match the function's signature")
.with_label(
Label::new(return_expr.as_ref().unwrap_or(return_stmt).span)
.with_color(c1)
.with_message(if is_bare_return {
format!("Bare return has type {}", Type::Unit.to_string().fg(c1))
} else {
format!(
"This expression has type {}",
return_stmt.ty.to_string().fg(c1)
)
}),
);
let report = if let Some(ret_ty_span) = function.return_type_span {
report.with_label(Label::new(ret_ty_span).with_color(c0).with_message(format!(
"The signature shows {}",
function.return_type.as_ref().unwrap().to_string().fg(c0)
)))
} else {
report
};
let report = if function.return_type.is_none() {
report.with_help(
"You may need to add the return type to the function's signature",
)
} else {
report
};
report.finish()
}
_ => todo!(),
}
}
}

View file

@ -1,11 +1,14 @@
use std::collections::HashMap;
use std::fmt::Display;
use BinaryExpression;
use ReturnStatement;
use crate::ast::ModulePath;
use crate::ast::*;
mod error;
use crate::typing::error::{TypeError, TypeErrorKind};
use crate::typing::error::{TypeAndSpan, TypeError, TypeErrorKind};
#[cfg(test)]
mod tests;
@ -62,11 +65,11 @@ impl From<&str> for Type {
#[derive(Debug, PartialEq, Clone)]
pub struct Signature(Vec<Type>, Type);
impl Into<Type> for Signature {
fn into(self) -> Type {
impl From<Signature> for Type {
fn from(val: Signature) -> Self {
Type::Function {
params: self.0,
returns: Box::new(self.1),
params: val.0,
returns: Box::new(val.1),
}
}
}
@ -79,12 +82,13 @@ impl FunctionDefinition {
}
}
#[derive(Debug, PartialEq)]
pub struct CheckedModule(pub Module);
impl Module {
pub fn type_check(mut self) -> Result<CheckedModule, TypeError> {
pub fn type_check(&mut self) -> Result<(), Vec<TypeError>> {
let mut ctx = TypingContext::new(self.path.clone());
ctx.file = self.file.clone();
ctx.file.clone_from(&self.file);
// Register all function signatures
for func in &self.functions {
@ -95,13 +99,21 @@ impl Module {
// TODO: add signatures of imported functions (even if they have not been checked)
let mut errors = Vec::new();
// Type-check the function bodies and complete all type placeholders
for func in &mut self.functions {
func.typ(&mut ctx)?;
if let Err(e) = func.typ(&mut ctx) {
errors.push(e);
};
ctx.variables.clear();
}
Ok(CheckedModule(self))
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
}
@ -128,6 +140,15 @@ impl TypingContext {
variables: Default::default(),
}
}
pub fn make_error(&self, kind: TypeErrorKind) -> TypeError {
TypeError {
kind,
file: self.file.clone(),
module: self.module.clone(),
function: self.function.clone(),
}
}
}
/// Trait for nodes which have a deducible type.
@ -148,40 +169,13 @@ impl TypeCheck for FunctionDefinition {
let body_type = self.body.typ(ctx)?;
// If the return type is not specified, it is unit.
if self.return_type.is_none() {
self.return_type = Some(Type::Unit)
}
// Check coherence with the body's type.
if *self.return_type.as_ref().unwrap() != body_type {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::BlockTypeDoesNotMatchFunctionType {
if *self.return_type.as_ref().unwrap_or(&Type::Unit) != body_type {
return Err(
ctx.make_error(TypeErrorKind::BlockTypeDoesNotMatchFunctionType {
block_type: body_type.clone(),
function_type: self.return_type.as_ref().unwrap().clone(),
})
.build());
}
// Check coherence with return statements.
for statement in &mut self.body.statements {
if let Statement::ReturnStatement(value) = statement {
let ret_type = match value {
Some(expr) => expr.typ(ctx)?,
None => Type::Unit,
};
if ret_type != *self.return_type.as_ref().unwrap() {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::ReturnTypeDoesNotMatchFunctionType {
function_type: self.return_type.as_ref().unwrap().clone(),
return_type: ret_type.clone(),
})
.build());
}
}
}),
);
}
Ok(self.return_type.clone().unwrap())
@ -190,79 +184,65 @@ impl TypeCheck for FunctionDefinition {
impl TypeCheck for Block {
fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
let mut return_typ: Option<Type> = None;
// Check declarations and assignments.
for statement in &mut self.statements {
match statement {
Statement::DeclareStatement(ident, expr) => {
Statement::DeclareStatement {
lhs: ident,
rhs: expr,
..
} => {
let typ = expr.typ(ctx)?;
if let Some(_typ) = ctx.variables.insert(ident.clone(), typ.clone()) {
// TODO: Shadowing? (illegal for now)
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::VariableRedeclaration)
.build());
return Err(ctx.make_error(TypeErrorKind::VariableRedeclaration));
}
}
Statement::AssignStatement(ident, expr) => {
Statement::AssignStatement {
lhs: ident,
rhs: expr,
..
} => {
let rhs_typ = expr.typ(ctx)?;
let Some(lhs_typ) = ctx.variables.get(ident) else {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::AssignUndeclared)
.build());
return Err(ctx.make_error(TypeErrorKind::AssignUndeclared));
};
// Ensure same type on both sides.
if rhs_typ != *lhs_typ {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::AssignmentMismatch {
return Err(ctx.make_error(TypeErrorKind::AssignmentMismatch {
lht: lhs_typ.clone(),
rht: rhs_typ.clone(),
})
.build());
}));
}
}
Statement::ReturnStatement(maybe_expr) => {
let expr_typ = if let Some(expr) = maybe_expr {
expr.typ(ctx)?
} else {
Type::Unit
};
if let Some(typ) = &return_typ {
if expr_typ != *typ {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::ReturnStatementsMismatch)
.build());
Statement::ReturnStatement(return_stmt) => {
return_stmt.typ(ctx)?;
}
} else {
return_typ = Some(expr_typ.clone());
}
}
Statement::CallStatement(call) => {
Statement::CallStatement { call, span: _ } => {
call.typ(ctx)?;
}
Statement::UseStatement(_path) => {
Statement::UseStatement { .. } => {
// TODO: import the signatures (and types)
todo!()
}
Statement::IfStatement(cond, block) => {
Statement::IfStatement {
condition: cond,
then_block: block,
..
} => {
if cond.typ(ctx)? != Type::Bool {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::ConditionIsNotBool)
.build());
return Err(ctx.make_error(TypeErrorKind::ConditionIsNotBool));
}
block.typ(ctx)?;
}
Statement::WhileStatement(cond, block) => {
Statement::WhileStatement {
condition: cond,
loop_block: block,
span: _,
} => {
if cond.typ(ctx)? != Type::Bool {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::ConditionIsNotBool)
.build());
return Err(ctx.make_error(TypeErrorKind::ConditionIsNotBool));
}
block.typ(ctx)?;
}
@ -277,25 +257,19 @@ impl TypeCheck for Block {
self.typ = Type::Unit;
Ok(Type::Unit)
}
// TODO/FIXME: find a way to return `return_typ` so that the
// top-level block (the function) can check if this return type
// (and eventually those from other block) matches the type of
// the function.
}
}
impl TypeCheck for Call {
fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
match &mut *self.callee {
Expr::Identifier { name, typ } => {
match &mut self.callee.expr {
Expr::Identifier { name, typ, .. } => {
let signature = match ctx.functions.get(name) {
Some(sgn) => sgn.clone(),
None => {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::UnknownFunctionCalled(name.clone()))
.build())
return Err(
ctx.make_error(TypeErrorKind::UnknownFunctionCalled(name.clone()))
)
}
};
@ -315,10 +289,7 @@ impl TypeCheck for Call {
if args_types == *params_types {
Ok(self.typ.clone())
} else {
Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::WrongFunctionArguments)
.build())
Err(ctx.make_error(TypeErrorKind::WrongFunctionArguments))
}
}
_ => unimplemented!("cannot call on expression other than identifier"),
@ -329,36 +300,40 @@ impl TypeCheck for Call {
impl TypeCheck for Expr {
fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
match self {
Expr::Identifier { name, typ } => {
Expr::Identifier { name, typ, .. } => {
if let Some(ty) = ctx.variables.get(name) {
*typ = ty.clone();
Ok(typ.clone())
} else {
Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::UnknownIdentifier {
Err(ctx.make_error(TypeErrorKind::UnknownIdentifier {
identifier: name.clone(),
})
.build())
}))
}
}
Expr::BooleanLiteral(_) => Ok(Type::Bool),
Expr::IntegerLiteral(_) => Ok(Type::Int),
Expr::FloatLiteral(_) => Ok(Type::Float),
Expr::UnaryExpression { op, inner } => {
Expr::BooleanLiteral(..) => Ok(Type::Bool),
Expr::IntegerLiteral(..) => Ok(Type::Int),
Expr::FloatLiteral(..) => Ok(Type::Float),
Expr::UnaryExpression { op, inner, .. } => {
let inner_type = &inner.typ(ctx)?;
match (&op, inner_type) {
(UnaryOperator::Not, Type::Bool) => Ok(Type::Bool),
_ => Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::InvalidUnaryOperator {
_ => Err(ctx.make_error(TypeErrorKind::InvalidUnaryOperator {
operator: *op,
inner: inner_type.clone(),
})
.build()),
})),
}
}
Expr::BinaryExpression { lhs, op, rhs, typ } => {
Expr::BinaryExpression(BinaryExpression {
lhs,
op,
rhs,
typ,
op_span,
}) => {
let operator = error::BinOpAndSpan {
op: op.clone(),
span: *op_span,
};
let ty = match op {
BinaryOperator::Add
| BinaryOperator::Sub
@ -368,32 +343,39 @@ impl TypeCheck for Expr {
| BinaryOperator::Or => {
let left_type = &lhs.typ(ctx)?;
let right_type = &rhs.typ(ctx)?;
match (left_type, right_type) {
(Type::Int, Type::Int) => Ok(Type::Int),
(Type::Float, Type::Float) => Ok(Type::Float),
(Type::Bool, Type::Bool) => Ok(Type::Bool),
(_, _) => Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::InvalidBinaryOperator {
operator: op.clone(),
lht: left_type.clone(),
rht: right_type.clone(),
})
.build()),
(_, _) => Err(ctx.make_error(TypeErrorKind::InvalidBinaryOperator {
operator,
lhs: TypeAndSpan {
ty: left_type.clone(),
span: lhs.span,
},
rhs: TypeAndSpan {
ty: right_type.clone(),
span: rhs.span,
},
})),
}
}
BinaryOperator::Equal | BinaryOperator::NotEqual => {
let lhs_type = lhs.typ(ctx)?;
let rhs_type = rhs.typ(ctx)?;
if lhs_type != rhs_type {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::InvalidBinaryOperator {
operator: op.clone(),
lht: lhs_type.clone(),
rht: rhs_type.clone(),
})
.build());
return Err(ctx.make_error(TypeErrorKind::InvalidBinaryOperator {
operator,
lhs: TypeAndSpan {
ty: lhs_type.clone(),
span: lhs.span,
},
rhs: TypeAndSpan {
ty: rhs_type.clone(),
span: rhs.span,
},
}));
}
Ok(Type::Bool)
}
@ -402,14 +384,17 @@ impl TypeCheck for Expr {
let rhs_type = lhs.typ(ctx)?;
match (&lhs_type, &rhs_type) {
(Type::Int, Type::Int) => Ok(Type::Int),
_ => Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::InvalidBinaryOperator {
operator: op.clone(),
lht: lhs_type.clone(),
rht: rhs_type.clone(),
})
.build()),
_ => Err(ctx.make_error(TypeErrorKind::InvalidBinaryOperator {
operator,
lhs: TypeAndSpan {
ty: lhs_type.clone(),
span: lhs.span,
},
rhs: TypeAndSpan {
ty: rhs_type.clone(),
span: rhs.span,
},
})),
}
}
};
@ -427,18 +412,12 @@ impl TypeCheck for Expr {
typ,
} => {
if cond.typ(ctx)? != Type::Bool {
Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::ConditionIsNotBool)
.build())
Err(ctx.make_error(TypeErrorKind::ConditionIsNotBool))
} else {
let then_body_type = then_body.typ(ctx)?;
let else_type = else_body.typ(ctx)?;
if then_body_type != else_type {
Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::IfElseMismatch)
.build())
Err(ctx.make_error(TypeErrorKind::IfElseMismatch))
} else {
// XXX: opt: return ref to avoid cloning
*typ = then_body_type.clone();
@ -449,3 +428,39 @@ impl TypeCheck for Expr {
}
}
}
impl TypeCheck for ReturnStatement {
fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
let ty = if let Some(expr) = &mut self.expr {
expr.typ(ctx)?
} else {
Type::Unit
};
// Check if the returned type is coherent with the function's signature
let func_type = &ctx.functions.get(ctx.function.as_ref().unwrap()).unwrap().1;
if ty != *func_type {
return Err(
ctx.make_error(TypeErrorKind::ReturnTypeDoesNotMatchFunctionType {
return_expr: self.expr.as_ref().map(|e| TypeAndSpan {
ty: ty.clone(),
span: e.span,
}),
return_stmt: TypeAndSpan {
ty: ty.clone(),
span: self.span,
},
}),
);
};
Ok(ty)
}
}
impl TypeCheck for SExpr {
#[inline]
fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
self.expr.typ(ctx)
}
}

View file

@ -1,33 +1,40 @@
use crate::{
ast::ModulePath,
parsing::parse_as_module,
typing::{
error::{TypeError, TypeErrorKind},
BinaryOperator, Type,
},
parsing::{DefaultParser, Parser},
typing::error::*,
typing::*,
};
#[cfg(test)]
use pretty_assertions::assert_eq;
#[test]
fn addition_int_and_float() {
let source = "fn add(a: int, b: float) int { a + b }";
let mut ast = parse_as_module(source, ModulePath::default()).unwrap();
let mut ast = DefaultParser::default()
.parse_as_module(source, ModulePath::default(), 0)
.unwrap();
let res = ast.type_check();
assert!(res.is_err_and(|e| e.kind
== TypeErrorKind::InvalidBinaryOperator {
operator: BinaryOperator::Add,
lht: Type::Int,
rht: Type::Float
}));
assert!(res.is_err_and(|errors| errors.len() == 1
&& matches!(errors[0].kind, TypeErrorKind::InvalidBinaryOperator { .. })));
}
#[test]
fn return_int_instead_of_float() {
let source = "fn add(a: int, b: int) float { a + b }";
let mut ast = parse_as_module(source, ModulePath::default()).unwrap();
let mut ast = DefaultParser::default()
.parse_as_module(source, ModulePath::default(), 0)
.unwrap();
let res = ast.type_check();
assert!(res.is_err_and(|e| e.kind
== TypeErrorKind::BlockTypeDoesNotMatchFunctionType {
assert_eq!(
res,
Err(vec![TypeError {
file: None,
module: ModulePath::default(),
function: Some("add".to_string()),
kind: TypeErrorKind::BlockTypeDoesNotMatchFunctionType {
block_type: Type::Int,
function_type: Type::Float
}));
}
}])
);
}