basic jit

This commit is contained in:
Romain Paquet 2024-03-08 17:38:23 +01:00
parent 374daaff7f
commit 511be952aa
16 changed files with 971 additions and 495 deletions

View file

@ -5,10 +5,10 @@ edition = "2021"
[dependencies] [dependencies]
clap = { version = "4.3.0", features = ["derive"] } clap = { version = "4.3.0", features = ["derive"] }
cranelift = "0.100.0" cranelift = "0.105.1"
cranelift-jit = "0.96.3" cranelift-jit = "0.105.1"
cranelift-module = "0.96.3" cranelift-module = "0.105.1"
cranelift-native = "0.96.3" cranelift-native = "0.105.1"
lazy_static = "1.4.0" lazy_static = "1.4.0"
pest = "2.6.0" pest = "2.7.4"
pest_derive = "2.6.0" pest_derive = "2.7.4"

34
src/ast/expr.rs Normal file
View file

@ -0,0 +1,34 @@
use crate::ast::*;
/// An expression node of the AST.
///
/// Compound variants carry a `typ` field holding the expression's type;
/// literal variants have no such field.
#[derive(Debug, PartialEq)]
pub enum Expr {
/// Infix operation `lhs op rhs`.
BinaryExpression {
lhs: Box<Expr>,
op: BinaryOperator,
rhs: Box<Expr>,
typ: Type,
},
/// Prefix operation `op inner`. This variant stores no type of its own.
UnaryExpression {
op: UnaryOperator,
inner: Box<Expr>,
},
/// Reference to a name.
Identifier {
name: String,
typ: Type,
},
/// Function call used as an expression.
Call(Box<Call>),
/// Block used as an expression.
Block(Box<Block>),
/// `if`/`else` used as an expression.
///
/// `else_body` is either `Expr::Block` or `Expr::IfExpr` (an `else if` chain).
IfExpr {
cond: Box<Expr>,
then_body: Box<Block>,
else_body: Box<Expr>,
typ: Type,
},
// Literals
UnitLiteral,
BooleanLiteral(bool),
IntegerLiteral(i64),
FloatLiteral(f64),
StringLiteral(String),
}

View file

@ -1,10 +1,17 @@
use crate::typing::Type;
use std::path::Path; use std::path::Path;
pub mod expr;
pub mod typed; pub mod typed;
pub mod untyped;
pub use expr::Expr;
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub enum BinaryOperator { pub enum BinaryOperator {
// Logic
And,
Or,
// Arithmetic
Add, Add,
Sub, Sub,
Mul, Mul,
@ -14,13 +21,19 @@ pub enum BinaryOperator {
NotEqual, NotEqual,
} }
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Copy, Clone)]
pub enum UnaryOperator { pub enum UnaryOperator {
Not,
} }
pub type Identifier = String; pub type Identifier = String;
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq)]
/// Position of an item in its source text.
pub struct Location {
/// Two-component position; presumably (line, column) as the name suggests — TODO confirm ordering and base.
pub line_col: (usize, usize),
}
#[derive(Debug, PartialEq, Clone, Default)]
pub struct ModulePath { pub struct ModulePath {
components: Vec<String>, components: Vec<String>,
} }
@ -65,3 +78,75 @@ impl From<&str> for ModulePath {
#[derive(Eq, PartialEq, Debug)] #[derive(Eq, PartialEq, Debug)]
pub struct Import(pub String); pub struct Import(pub String);
/// A statement inside a block.
#[derive(Debug, PartialEq)]
pub enum Statement {
/// Introduces a new variable with the given name and initial value.
DeclareStatement(Identifier, Box<Expr>),
/// Assigns a new value to the named variable.
AssignStatement(Identifier, Box<Expr>),
/// Returns from the enclosing function, optionally with a value.
ReturnStatement(Option<Expr>),
/// A call used as a statement.
CallStatement(Box<Call>),
/// An import.
UseStatement(Box<Import>),
/// Conditional with a condition and a body; there is no else branch in this variant.
IfStatement(Box<Expr>, Box<Block>),
/// Loop with a condition and a body.
WhileStatement(Box<Expr>, Box<Block>),
}
/// A sequence of statements optionally followed by a trailing expression.
#[derive(Debug, PartialEq)]
pub struct Block {
/// Statements of the block, in source order.
pub statements: Vec<Statement>,
/// Trailing expression of the block, if any.
pub value: Option<Expr>,
/// Type of the block.
pub typ: Type,
}
impl Block {
    /// Builds a block with no statements and no trailing value, typed as `Type::Unit`.
    pub fn empty() -> Block {
        let statements = Vec::new();
        Self {
            statements,
            value: None,
            typ: Type::Unit,
        }
    }
}
/// A top-level definition. Functions are the only kind at the moment.
#[derive(Debug, PartialEq)]
pub enum Definition {
FunctionDefinition(FunctionDefinition),
}
/// A function: its name, signature, body and source location.
#[derive(Debug, PartialEq)]
pub struct FunctionDefinition {
/// Name of the function.
pub name: Identifier,
/// Parameters, in declaration order.
pub parameters: Vec<Parameter>,
/// Declared return type, if one is recorded.
pub return_type: Option<Type>,
/// Body of the function.
pub body: Box<Block>,
/// Where the definition appears in the source.
pub location: Location,
}
/// A module: a set of function definitions and imports.
#[derive(Debug, PartialEq, Default)]
pub struct Module {
/// File associated with the module, when there is one.
pub file: Option<std::path::PathBuf>,
/// Path identifying the module.
pub path: ModulePath,
/// Functions defined in the module.
pub functions: Vec<FunctionDefinition>,
/// Imports of the module.
pub imports: Vec<Import>,
}
impl Module {
pub fn new(path: ModulePath) -> Self {
Self {
path,
..Default::default()
}
}
}
/// A function call.
#[derive(Debug, PartialEq)]
pub struct Call {
/// Expression designating the function being called.
pub callee: Box<Expr>,
/// Argument expressions, in call order.
pub args: Vec<Expr>,
/// Type of the call expression.
pub typ: Type,
}
/// A function parameter: a name and its declared type.
#[derive(Debug, PartialEq)]
pub struct Parameter {
pub name: Identifier,
pub typ: Type,
}

View file

@ -3,60 +3,20 @@ use crate::ast::{
BinaryOperator, UnaryOperator, BinaryOperator, UnaryOperator,
}; };
use crate::typing::Type; use crate::typing::Type;
#[derive(Debug, PartialEq)]
pub enum Expr {
BinaryExpression {
lhs: Box<Expr>,
op: BinaryOperator,
rhs: Box<Expr>,
typ: Type,
},
UnaryExpression {
op: UnaryOperator,
inner: Box<Expr>,
},
Variable {
name: String,
typ: Type,
},
Call {
call: Box<Call>,
typ: Type,
},
Block {
block: Box<Block>,
typ: Type,
},
/// Last field is either Expr::Block or Expr::IfExpr
IfExpr {
cond: Box<Expr>,
then_body: Box<Block>,
else_body: Box<Expr>,
typ: Type,
},
// Literals
UnitLiteral,
BooleanLiteral(bool),
IntegerLiteral(i64),
FloatLiteral(f64),
StringLiteral(String),
}
impl Expr { impl Expr {
pub fn typ(&self) -> &Type { pub fn typ(&self) -> &Type {
match self { match self {
Expr::BinaryExpression { lhs, op, rhs, typ } => typ, Expr::BinaryExpression { lhs, op, rhs, typ } => &typ.unwrap(),
Expr::UnaryExpression { op, inner } => inner.typ(), // XXX: problems will arise here Expr::UnaryExpression { op, inner } => inner.typ(), // XXX: problems will arise here
Expr::Variable { name, typ } => typ, Expr::Variable { name, typ } => &typ.unwrap(),
Expr::Call { call, typ } => typ, Expr::Call { call, typ } => &typ.unwrap(),
Expr::Block { block, typ } => typ, Expr::Block { block, typ } => &typ.unwrap(),
Expr::IfExpr { Expr::IfExpr {
cond, cond,
then_body, then_body,
else_body, else_body,
typ, typ,
} => typ, } => &typ.unwrap(),
Expr::UnitLiteral => &Type::Unit, Expr::UnitLiteral => &Type::Unit,
Expr::BooleanLiteral(_) => &Type::Bool, Expr::BooleanLiteral(_) => &Type::Bool,
Expr::IntegerLiteral(_) => &Type::Int, Expr::IntegerLiteral(_) => &Type::Int,

View file

@ -1,47 +1,37 @@
pub mod expr; use crate::ast::*;
use crate::typing::Type;
use super::{untyped::Parameter, Identifier, Import};
use expr::Expr;
#[derive(Debug, PartialEq)]
pub enum Statement {
DeclareStatement(Identifier, Box<Expr>),
AssignStatement(Identifier, Box<Expr>),
ReturnStatement(Option<Expr>),
CallStatement(Box<Call>),
UseStatement(Box<Import>),
IfStatement(Box<Expr>, Box<Block>),
WhileStatement(Box<Expr>, Box<Block>),
}
#[derive(Debug, PartialEq)]
pub struct Block {
pub statements: Vec<Statement>,
pub value: Option<Expr>,
typ: Type,
}
impl Block { impl Block {
#[inline] #[inline]
pub fn typ(&self) -> Type { pub fn ty(&self) -> Type {
// XXX: Cloning may be expensive -> TypeId?
self.typ.clone() self.typ.clone()
} }
} }
#[derive(Debug, PartialEq)] impl Expr {
pub struct FunctionDefinition { pub fn ty(&self) -> Type {
pub name: Identifier, match self {
pub parameters: Vec<Parameter>, Expr::BinaryExpression {
pub return_type: Option<Type>, lhs: _,
pub body: Box<Block>, op: _,
pub line_col: (usize, usize), rhs: _,
typ,
} => typ.clone(),
Expr::UnaryExpression { op: _, inner } => inner.ty(), // XXX: problems will arise here
Expr::Identifier { name: _, typ } => typ.clone(),
Expr::Call(call) => call.typ.clone(),
Expr::Block(block) => block.typ.clone(),
Expr::IfExpr {
cond: _,
then_body: _,
else_body: _,
typ,
} => typ.clone(),
Expr::UnitLiteral => Type::Unit,
Expr::BooleanLiteral(_) => Type::Bool,
Expr::IntegerLiteral(_) => Type::Int,
Expr::FloatLiteral(_) => Type::Float,
Expr::StringLiteral(_) => Type::Str,
}
}
} }
#[derive(Debug, PartialEq)]
pub struct Call {
pub callee: Box<Expr>,
pub args: Vec<Expr>,
pub typ: Type,
}

View file

@ -1,22 +0,0 @@
use crate::ast::{
untyped::{Block, Call},
Identifier,
};
use crate::ast::*;
#[derive(Debug, PartialEq)]
pub enum Expr {
UnitLiteral,
BinaryExpression(Box<Expr>, BinaryOperator, Box<Expr>),
Identifier(Identifier),
Call(Box<Call>),
// Literals
BooleanLiteral(bool),
IntegerLiteral(i64),
FloatLiteral(f64),
StringLiteral(String),
Block(Box<Block>),
/// Last field is either Expr::Block or Expr::IfExpr
IfExpr(Box<Expr>, Box<Block>, Box<Expr>),
}

View file

@ -1,61 +0,0 @@
pub mod expr;
pub mod module;
use std::path::Path;
pub use crate::ast::untyped::expr::Expr;
pub use crate::ast::*;
// TODO: remove all usage of 'Type' in the untyped ast
// (for now it is assumed that anything that parses
// is a Type, but the checking should be done in the typing
// phase)
use crate::typing::Type;
#[derive(Debug, PartialEq)]
pub enum Definition {
FunctionDefinition(FunctionDefinition),
//StructDefinition(StructDefinition),
}
#[derive(Debug, PartialEq)]
pub struct Location {
pub file: Box<Path>,
}
#[derive(Debug, PartialEq)]
pub struct FunctionDefinition {
pub name: Identifier,
pub parameters: Vec<Parameter>,
pub return_type: Option<Type>,
pub body: Box<Block>,
pub line_col: (usize, usize),
}
#[derive(Debug, PartialEq)]
pub struct Block {
pub statements: Vec<Statement>,
pub value: Option<Expr>,
}
#[derive(Debug, PartialEq)]
pub enum Statement {
DeclareStatement(Identifier, Expr),
AssignStatement(Identifier, Expr),
ReturnStatement(Option<Expr>),
CallStatement(Call),
UseStatement(Import),
IfStatement(Expr, Block),
WhileStatement(Box<Expr>, Box<Block>),
}
#[derive(Debug, PartialEq)]
pub struct Call {
pub callee: Box<Expr>,
pub args: Vec<Expr>,
}
#[derive(Debug, PartialEq)]
pub struct Parameter {
pub name: Identifier,
pub typ: Type,
}

View file

@ -1,20 +0,0 @@
use super::{Definition, ModulePath, Import};
#[derive(Debug, PartialEq)]
pub struct Module {
pub file: Option<std::path::PathBuf>,
pub path: ModulePath,
pub definitions: Vec<Definition>,
pub imports: Vec<Import>,
}
impl Module {
pub fn new(path: ModulePath) -> Self {
Module {
path,
file: None,
definitions: vec![],
imports: vec![],
}
}
}

465
src/jit/mod.rs Normal file
View file

@ -0,0 +1,465 @@
use crate::{
ast::{
self, BinaryOperator, ModulePath, UnaryOperator,
{expr::Expr, FunctionDefinition, Statement},
},
parsing,
typing::Type,
};
use cranelift::{codegen::ir::UserFuncName, prelude::*};
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{DataDescription, FuncId, FuncOrDataId, Linkage, Module};
use std::{collections::HashMap, ops::Deref};
/// JIT compiler state: the Cranelift contexts and the module that owns the
/// compiled functions.
pub struct JIT {
/// The function builder context, which is reused across multiple
/// FunctionBuilder instances.
builder_context: FunctionBuilderContext,
/// The main Cranelift context, which holds the state for codegen. Cranelift
/// separates this from `Module` to allow for parallel compilation, with a
/// context per thread, though this isn't used that way here.
ctx: codegen::Context,
/// The data description, which is to data objects what `ctx` is to functions.
data_desc: DataDescription,
/// The module, with the jit backend, which manages the JIT'd functions.
module: JITModule,
}
impl Default for JIT {
    /// Builds a JIT targeting the host machine.
    ///
    /// # Panics
    ///
    /// Panics when the host is not supported by Cranelift, or when a codegen
    /// flag or the ISA cannot be set up.
    fn default() -> Self {
        // Codegen flags: non-colocated libcalls, no position-independent code.
        let mut flags = codegen::settings::builder();
        for (flag, setting) in [("use_colocated_libcalls", "false"), ("is_pic", "false")] {
            flags.set(flag, setting).unwrap();
        }

        let native_isa = match cranelift_native::builder() {
            Ok(isa_builder) => isa_builder,
            Err(msg) => panic!("host machine is not supported: {}", msg),
        };
        let isa = native_isa.finish(settings::Flags::new(flags)).unwrap();

        let module = JITModule::new(JITBuilder::with_isa(
            isa,
            cranelift_module::default_libcall_names(),
        ));
        let ctx = module.make_context();

        Self {
            builder_context: FunctionBuilderContext::new(),
            ctx,
            data_desc: DataDescription::new(),
            module,
        }
    }
}
impl JIT {
/// Compile source code into machine code.
pub fn compile(&mut self, input: &str, dump_clir: bool) -> Result<*const u8, String> {
// Parse the source code into an AST
let Ok(mut ast) = parsing::parse_as_module(input, ModulePath::from("globalmodule")) else {
return Err("Parsing error".to_string());
};
if let Err(e) = ast.type_check() {
return Err(e.to_string());
};
// Translate the AST into Cranelift IR
self.translate(&ast, dump_clir)?;
// Finalize the functions which we just defined, which resolves any
// outstanding relocations (patching in addresses, now that they're
// available).
self.module.finalize_definitions().unwrap();
// We can now retrieve a pointer to the machine code.
if let Some(FuncOrDataId::Func(main_id)) = self.module.get_name("main") {
let code = self.module.get_finalized_function(main_id);
Ok(code)
} else {
Err("no main function".into())
}
}
/// Translate language AST into Cranelift IR.
fn translate(&mut self, ast: &ast::Module, dump_clir: bool) -> Result<(), String> {
let mut signatures: Vec<Signature> = Vec::with_capacity(ast.functions.len());
let mut func_ids: Vec<FuncId> = Vec::with_capacity(ast.functions.len());
// Declare functions
for func in &ast.functions {
// Create the signature
let mut sig = self.module.make_signature();
for param in &func.parameters {
assert!(param.typ != Type::Unit);
sig.params.append(&mut Vec::from(&param.typ));
}
if let Some(return_type) = &func.return_type {
if *return_type != Type::Unit {
sig.returns = return_type.into();
}
};
let id: FuncId = self
.module
.declare_function(&func.name, Linkage::Export, &sig)
.map_err(|e| e.to_string())?;
signatures.push(sig);
func_ids.push(id);
}
// Translate functions
for (i, func) in ast.functions.iter().enumerate() {
self.ctx.func.signature = signatures[i].clone();
self.ctx.func.name = UserFuncName::user(0, func_ids[i].as_u32());
self.translate_function(func)?;
self.module
.define_function(func_ids[i], &mut self.ctx)
.unwrap();
if dump_clir {
println!("// {}", func.name);
println!("{}", self.ctx.func.display());
}
self.module.clear_context(&mut self.ctx);
}
Ok(())
}
fn translate_function(&mut self, function: &FunctionDefinition) -> Result<(), String> {
// Create the builder to build a function.
let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
// Create the entry block, to start emitting code in.
let entry_block = builder.create_block();
// Since this is the entry block, add block parameters corresponding to
// the function's parameters.
builder.append_block_params_for_function_params(entry_block);
// Tell the builder to emit code in this block.
builder.switch_to_block(entry_block);
// And, tell the builder that this block will have no further
// predecessors. Since it's the entry block, it won't have any
// predecessors.
builder.seal_block(entry_block);
// Walk the AST and declare all implicitly-declared variables.
let mut variables = HashMap::<String, Variable>::default(); // TODO: actually do this
// Add a variable for each parameter.
let param_values: Box<[Value]> = builder.block_params(entry_block).into();
assert!(param_values.len() == function.parameters.len());
for (i, param) in function.parameters.iter().enumerate() {
let var = Variable::from_u32(variables.len() as u32);
variables.insert(param.name.clone(), var);
let value = param_values[i];
builder.declare_var(var, param.typ.clone().into());
builder.def_var(var, value);
}
// Now translate the statements of the function body.
let mut translator = FunctionTranslator {
builder,
variables,
module: &mut self.module,
};
for stmt in &function.body.statements {
translator.translate_statement(stmt);
}
// Emit the final return instruction.
if let Some(return_expr) = &function.body.value {
let return_value = translator.translate_expr(&return_expr);
translator.builder.ins().return_(&[return_value]);
} else {
translator.builder.ins().return_(&[]);
}
// Tell the builder we're done with this function.
translator.builder.finalize();
Ok(())
}
}
impl From<crate::typing::Type> for types::Type {
fn from(value: crate::typing::Type) -> Self {
match value {
Type::Bool => types::I8,
Type::Int => types::I32,
Type::Float => types::F32,
Type::Unit => unreachable!(),
Type::Str => todo!(),
Type::Custom(_) => todo!(),
Type::Function {
params: _,
returns: _,
} => todo!(),
Type::Undefined => unreachable!(),
}
}
}
impl From<&Type> for Vec<AbiParam> {
    /// Produces the ABI parameter list (a single entry) for a scalar language type.
    ///
    /// # Panics
    ///
    /// Panics for every type other than `Bool`, `Int` and `Float`.
    fn from(value: &Type) -> Self {
        let clif_type = match value {
            Type::Bool => types::I8,
            Type::Int => types::I32,
            Type::Float => types::F32,
            _ => unimplemented!(),
        };
        vec![AbiParam::new(clif_type)]
    }
}
/// A collection of state used for translating from AST nodes
/// into Cranelift IR, for a single function.
struct FunctionTranslator<'a> {
/// Builder emitting instructions into the function being translated.
builder: FunctionBuilder<'a>,
/// Maps the names of parameters and declared variables to Cranelift variables.
variables: HashMap<String, Variable>,
/// Module used to look up the functions referenced by calls.
module: &'a mut JITModule,
}
impl<'a> FunctionTranslator<'a> {
fn translate_statement(&mut self, stmt: &Statement) -> Option<Value> {
match stmt {
Statement::AssignStatement(name, expr) => {
// `def_var` is used to write the value of a variable. Note that
// variables can have multiple definitions. Cranelift will
// convert them into SSA form for itself automatically.
let new_value = self.translate_expr(expr);
let variable = self.variables.get(name).unwrap();
self.builder.def_var(*variable, new_value);
Some(new_value)
}
Statement::DeclareStatement(name, expr) => {
let value = self.translate_expr(expr);
let variable = Variable::from_u32(self.variables.len() as u32);
self.builder.declare_var(variable, expr.ty().into());
self.builder.def_var(variable, value);
self.variables.insert(name.clone(), variable);
Some(value)
}
Statement::ReturnStatement(maybe_expr) => {
// TODO: investigate tail call
let values = if let Some(expr) = maybe_expr {
vec![self.translate_expr(expr)]
} else {
// XXX: urgh
Vec::with_capacity(0)
};
// XXX: Should we pass multiple values ?
self.builder.ins().return_(&values);
None
}
Statement::CallStatement(call) => self.translate_call(call),
Statement::UseStatement(_) => todo!(),
Statement::IfStatement(cond, then_body) => {
let condition_value = self.translate_expr(cond);
let then_block = self.builder.create_block();
let merge_block = self.builder.create_block();
self.builder
.ins()
.brif(condition_value, then_block, &[], merge_block, &[]);
self.builder.switch_to_block(then_block);
self.builder.seal_block(then_block);
self.translate_block(then_body);
self.builder.ins().jump(merge_block, &[]);
self.builder.switch_to_block(merge_block);
self.builder.seal_block(merge_block);
None
}
Statement::WhileStatement(_, _) => todo!(),
}
}
fn translate_expr(&mut self, expr: &Expr) -> Value {
match expr {
Expr::UnitLiteral => unreachable!(),
Expr::BooleanLiteral(imm) => self.builder.ins().iconst(types::I8, i64::from(*imm)),
Expr::IntegerLiteral(imm) => self.builder.ins().iconst(types::I32, *imm),
Expr::FloatLiteral(imm) => self.builder.ins().f64const(*imm),
Expr::StringLiteral(_) => todo!(),
Expr::BinaryExpression {
lhs,
op,
rhs,
typ: _,
} => {
let lhs_value = self.translate_expr(lhs);
let rhs_value = self.translate_expr(rhs);
match (lhs.ty(), lhs.ty()) {
(Type::Int, Type::Int) => match op {
BinaryOperator::Add => self.builder.ins().iadd(lhs_value, rhs_value),
BinaryOperator::Sub => self.builder.ins().isub(lhs_value, rhs_value),
BinaryOperator::Mul => self.builder.ins().imul(lhs_value, rhs_value),
// TODO: investigate division (case rhs <= 0)
BinaryOperator::Div => self.builder.ins().udiv(lhs_value, rhs_value),
BinaryOperator::Modulo => todo!(),
BinaryOperator::Equal => {
self.builder.ins().icmp(IntCC::Equal, lhs_value, rhs_value)
}
BinaryOperator::NotEqual => {
self.builder
.ins()
.icmp(IntCC::NotEqual, lhs_value, rhs_value)
}
_ => unreachable!(),
},
(Type::Bool, Type::Bool) => match op {
// XXX: Is min and max ok or should it be something else?
BinaryOperator::And => self.builder.ins().umin(lhs_value, rhs_value),
BinaryOperator::Or => self.builder.ins().umax(lhs_value, rhs_value),
_ => unreachable!(),
},
_ => unimplemented!(),
}
}
Expr::IfExpr {
cond,
then_body,
else_body,
typ,
} => {
let condition_value = self.translate_expr(cond);
let then_block = self.builder.create_block();
let else_block = self.builder.create_block();
let merge_block = self.builder.create_block();
// If-else constructs in the language have a return value.
// In traditional SSA form, this would produce a PHI between
// the then and else bodies. Cranelift uses block parameters,
// so set up a parameter in the merge block, and we'll pass
// the return values to it from the branches.
self.builder
.append_block_param(merge_block, typ.clone().into());
// Test the if condition and conditionally branch.
self.builder
.ins()
.brif(condition_value, then_block, &[], else_block, &[]);
self.builder.switch_to_block(then_block);
self.builder.seal_block(then_block);
for stmt in &then_body.statements {
self.translate_statement(&stmt);
}
let then_return_value = match &then_body.value {
Some(val) => vec![self.translate_expr(val)],
None => Vec::with_capacity(0),
};
// Jump to the merge block, passing it the block return value.
self.builder.ins().jump(merge_block, &then_return_value);
self.builder.switch_to_block(else_block);
self.builder.seal_block(else_block);
// XXX: the else can be just an expression: do we always need to
// make a second branch in that case? Or leave it to cranelift?
let else_return_value = match **else_body {
Expr::UnitLiteral => Vec::with_capacity(0),
_ => vec![self.translate_expr(else_body)],
};
// Jump to the merge block, passing it the block return value.
self.builder.ins().jump(merge_block, &else_return_value);
// Switch to the merge block for subsequent statements.
self.builder.switch_to_block(merge_block);
// We've now seen all the predecessors of the merge block.
self.builder.seal_block(merge_block);
// Read the value of the if-else by reading the merge block
// parameter.
let phi = self.builder.block_params(merge_block)[0];
phi
}
Expr::UnaryExpression { op, inner } => {
let inner_value = self.translate_expr(inner);
match op {
// XXX: This should not be a literal translation
UnaryOperator::Not => {
let one = self.builder.ins().iconst(types::I8, 1);
self.builder.ins().isub(one, inner_value)
}
}
}
Expr::Identifier { name, typ: _ } => {
self.builder.use_var(*self.variables.get(name).unwrap())
}
Expr::Call(call) => self.translate_call(call).unwrap(),
Expr::Block(block) => self.translate_block(block).unwrap(),
}
}
fn translate_block(&mut self, block: &ast::Block) -> Option<Value> {
for stmt in &block.statements {
self.translate_statement(stmt);
}
if let Some(block_value) = &block.value {
Some(self.translate_expr(block_value))
} else {
None
}
}
fn translate_call(&mut self, call: &ast::Call) -> Option<Value> {
match call.callee.deref() {
Expr::Identifier { name, typ: _ } => {
let func_ref = if let Some(func_or_data_id) = self.module.get_name(name.as_ref()) {
if let FuncOrDataId::Func(func_id) = func_or_data_id {
self.module.declare_func_in_func(func_id, self.builder.func)
} else {
panic!()
}
} else {
todo!()
};
let args: Vec<Value> = call.args.iter().map(|a| self.translate_expr(a)).collect();
// TODO: handle the return value of the function
let call_inst = self.builder.ins().call(func_ref, &args);
let results = self.builder.inst_results(call_inst);
Some(results[0])
}
_ => unimplemented!(),
}
}
}

View file

@ -1,4 +0,0 @@
mod ast;
mod typing;
mod jit;
mod parsing;

View file

@ -1,10 +1,11 @@
mod ast; pub mod ast;
mod parsing; pub mod jit;
mod typing; pub mod parsing;
pub mod typing;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use crate::ast::untyped::module::Module; use crate::ast::Module;
/// Experimental compiler for lila /// Experimental compiler for lila
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -18,53 +19,80 @@ struct Cli {
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]
enum Commands { enum Commands {
Parse { Parse {
/// Path to the source files /// Paths to the source files
files: Vec<String>, files: Vec<String>,
/// Dump the AST to stdout /// Dump the AST to stdout
#[arg(long)] #[arg(long)]
dump_ast: bool, dump_ast: bool,
/// Add missing return types in the AST
#[arg(long)]
type_check: bool,
}, },
TypeCheck {
/// Paths to the source files
files: Vec<String>,
/// Dump the AST to stdout
#[arg(long)]
dump_ast: bool,
},
Compile {
/// Paths to the source files
files: Vec<String>,
/// Dump the CLIR to stdout
#[arg(long)]
dump_clir: bool,
},
}
/// Parses every file in `files` into a `Module`, preserving the input order.
///
/// # Panics
///
/// Panics with the debug-formatted parser error as soon as one file fails to
/// parse.
fn parse(files: &[String]) -> Vec<Module> {
    files
        .iter()
        .map(std::path::Path::new)
        .map(|path| match parsing::parse_file(&path) {
            Ok(module) => module,
            Err(e) => panic!("Parsing error: {:#?}", e),
        })
        .collect()
}
fn check(modules: &mut Vec<Module>) {
for module in modules {
if let Err(e) = module.type_check() {
eprintln!("{}", e);
return;
}
}
} }
fn main() { fn main() {
let cli = Cli::parse(); let cli = Cli::parse();
match &cli.command { match &cli.command {
Commands::Parse { Commands::Parse { files, dump_ast } => {
files, let modules = parse(files);
dump_ast,
type_check,
} => {
let paths = files.iter().map(std::path::Path::new);
let modules: Vec<Module> = paths
.map(|path| match parsing::parse_file(&path) {
Ok(module) => module,
Err(e) => panic!("Parsing error: {:#?}", e),
})
.collect();
if *type_check {
for module in &modules {
if let Err(e) = module.type_check() {
eprintln!("{}", e);
return;
}
}
}
if *dump_ast { if *dump_ast {
for module in &modules { for module in &modules {
println!("{:#?}", &module); println!("{:#?}", &module);
} }
return;
} }
println!("Parsing OK"); println!("Parsing OK");
} }
Commands::TypeCheck { files, dump_ast } => {
let mut modules = parse(files);
check(&mut modules);
if *dump_ast {
for module in &modules {
println!("{:#?}", &module);
}
}
}
Commands::Compile { files, dump_clir } => {
let mut jit = jit::JIT::default();
for file in files {
match jit.compile(std::fs::read_to_string(file).unwrap().as_str(), *dump_clir) {
Err(e) => eprintln!("{}", e),
Ok(_code) => println!("Compiled {}", file),
}
}
}
} }
} }

View file

@ -32,14 +32,20 @@ parameters = {
parameter = { ident ~ ":" ~ typ } parameter = { ident ~ ":" ~ typ }
// Operators // Operators
infix = _{ add | subtract | multiply | divide | not_equal | equal | modulo } infix = _{ arithmetic_operator | logical_operator }
add = { "+" }
subtract = { "-" } arithmetic_operator = _{ add | subtract | multiply | divide | not_equal | equal | modulo }
multiply = { "*" } add = { "+" }
divide = { "/" } subtract = { "-" }
modulo = { "%" } multiply = { "*" }
equal = { "==" } divide = { "/" }
not_equal = { "!=" } modulo = { "%" }
equal = { "==" }
not_equal = { "!=" }
logical_operator = _{ and | or }
and = { "&&" }
or = { "||" }
prefix = _{ not } prefix = _{ not }
not = { "!" } not = { "!" }
@ -49,6 +55,7 @@ expr = { prefix? ~ atom ~ (infix ~ prefix? ~ atom)* }
atom = _{ call | if_expr | block | literal | ident | "(" ~ expr ~ ")" } atom = _{ call | if_expr | block | literal | ident | "(" ~ expr ~ ")" }
block = { "{" ~ statement* ~ expr? ~ "}" } block = { "{" ~ statement* ~ expr? ~ "}" }
if_expr = { "if" ~ expr ~ block ~ "else" ~ (block | if_expr) } if_expr = { "if" ~ expr ~ block ~ "else" ~ (block | if_expr) }
//tuple = { "(" ~ (expr ~ ",")+ ~ expr ~ ")" }
ident = @{ (ASCII_ALPHANUMERIC | "_")+ } ident = @{ (ASCII_ALPHANUMERIC | "_")+ }
typ = _{ ident } typ = _{ ident }

View file

@ -6,8 +6,8 @@ use pest::iterators::Pair;
use pest::pratt_parser::PrattParser; use pest::pratt_parser::PrattParser;
use pest::Parser; use pest::Parser;
use crate::ast::untyped::module::Module; use crate::ast::Module;
use crate::ast::untyped::*; use crate::ast::*;
use crate::ast::{Import, ModulePath}; use crate::ast::{Import, ModulePath};
use crate::typing::Type; use crate::typing::Type;
@ -23,6 +23,9 @@ lazy_static::lazy_static! {
// Precedence is defined lowest to highest // Precedence is defined lowest to highest
PrattParser::new() PrattParser::new()
.op(Op::infix(and, Left))
.op(Op::infix(or, Left))
.op(Op::prefix(not))
.op(Op::infix(equal, Left) | Op::infix(not_equal, Left)) .op(Op::infix(equal, Left) | Op::infix(not_equal, Left))
.op(Op::infix(add, Left) | Op::infix(subtract, Left)) .op(Op::infix(add, Left) | Op::infix(subtract, Left))
.op(Op::infix(modulo, Left)) .op(Op::infix(modulo, Left))
@ -57,7 +60,9 @@ pub fn parse_module(pair: Pair<Rule>, path: ModulePath) -> Module {
match pair.as_rule() { match pair.as_rule() {
Rule::definition => { Rule::definition => {
let def = parse_definition(pair.into_inner().next().unwrap()); let def = parse_definition(pair.into_inner().next().unwrap());
module.definitions.push(def); match def {
Definition::FunctionDefinition(func) => module.functions.push(func),
}
} }
Rule::use_statement => { Rule::use_statement => {
let path = parse_import(pair.into_inner().next().unwrap()); let path = parse_import(pair.into_inner().next().unwrap());
@ -82,7 +87,11 @@ fn parse_block(pair: Pair<Rule>) -> Block {
} }
} }
Block { statements, value } Block {
statements,
value,
typ: Type::Undefined,
}
} }
fn parse_statement(pair: Pair<Rule>) -> Statement { fn parse_statement(pair: Pair<Rule>) -> Statement {
@ -92,13 +101,13 @@ fn parse_statement(pair: Pair<Rule>) -> Statement {
let mut pairs = pair.into_inner(); let mut pairs = pair.into_inner();
let identifier = pairs.next().unwrap().as_str().to_string(); let identifier = pairs.next().unwrap().as_str().to_string();
let expr = parse_expression(pairs.next().unwrap()); let expr = parse_expression(pairs.next().unwrap());
Statement::AssignStatement(identifier, expr) Statement::AssignStatement(identifier, Box::new(expr))
} }
Rule::declare_statement => { Rule::declare_statement => {
let mut pairs = pair.into_inner(); let mut pairs = pair.into_inner();
let identifier = pairs.next().unwrap().as_str().to_string(); let identifier = pairs.next().unwrap().as_str().to_string();
let expr = parse_expression(pairs.next().unwrap()); let expr = parse_expression(pairs.next().unwrap());
Statement::DeclareStatement(identifier, expr) Statement::DeclareStatement(identifier, Box::new(expr))
} }
Rule::return_statement => { Rule::return_statement => {
let expr = if let Some(pair) = pair.into_inner().next() { let expr = if let Some(pair) = pair.into_inner().next() {
@ -110,17 +119,20 @@ fn parse_statement(pair: Pair<Rule>) -> Statement {
} }
Rule::call_statement => { Rule::call_statement => {
let call = parse_call(pair.into_inner().next().unwrap()); let call = parse_call(pair.into_inner().next().unwrap());
Statement::CallStatement(call) Statement::CallStatement(Box::new(call))
} }
Rule::use_statement => { Rule::use_statement => {
let import = parse_import(pair.into_inner().next().unwrap()); let import = parse_import(pair.into_inner().next().unwrap());
Statement::UseStatement(import) Statement::UseStatement(Box::new(import))
} }
Rule::if_statement => { Rule::if_statement => {
let mut pairs = pair.into_inner(); let mut pairs = pair.into_inner();
let condition = parse_expression(pairs.next().unwrap()); let condition = parse_expression(pairs.next().unwrap());
let block = parse_block(pairs.next().unwrap()); let block = parse_block(pairs.next().unwrap());
Statement::IfStatement(condition, block) if pairs.next().is_some() {
todo!("implement if-statements with else branch (and else if)")
}
Statement::IfStatement(Box::new(condition), Box::new(block))
} }
Rule::while_statement => { Rule::while_statement => {
let mut pairs = pair.into_inner(); let mut pairs = pair.into_inner();
@ -132,8 +144,6 @@ fn parse_statement(pair: Pair<Rule>) -> Statement {
} }
} }
type ImportPath = ModulePath;
fn parse_import(pair: Pair<Rule>) -> Import { fn parse_import(pair: Pair<Rule>) -> Import {
Import(pair.as_str().to_string()) Import(pair.as_str().to_string())
} }
@ -141,16 +151,21 @@ fn parse_import(pair: Pair<Rule>) -> Import {
fn parse_call(pair: Pair<Rule>) -> Call { fn parse_call(pair: Pair<Rule>) -> Call {
let mut pairs = pair.into_inner(); let mut pairs = pair.into_inner();
// TODO: support calls on more than identifiers (needs grammar change) // TODO: support calls on more than identifiers (needs grammar change)
let callee = Expr::Identifier(pairs.next().unwrap().as_str().to_string()); let callee = Expr::Identifier {
name: pairs.next().unwrap().as_str().to_string(),
typ: Type::Undefined,
};
let args: Vec<Expr> = pairs let args: Vec<Expr> = pairs
.next() .next()
.unwrap() .unwrap()
.into_inner() .into_inner()
.map(parse_expression) .map(parse_expression)
.collect(); .collect();
Call { Call {
callee: Box::new(callee), callee: Box::new(callee),
args, args,
typ: Type::Undefined,
} }
} }
@ -170,7 +185,10 @@ fn parse_expression(pair: Pair<Rule>) -> Expr {
.unwrap(), .unwrap(),
), ),
Rule::expr => parse_expression(primary), Rule::expr => parse_expression(primary),
Rule::ident => Expr::Identifier(primary.as_str().to_string()), Rule::ident => Expr::Identifier {
name: primary.as_str().to_string(),
typ: Type::Undefined,
},
Rule::call => Expr::Call(Box::new(parse_call(primary))), Rule::call => Expr::Call(Box::new(parse_call(primary))),
Rule::block => Expr::Block(Box::new(parse_block(primary))), Rule::block => Expr::Block(Box::new(parse_block(primary))),
Rule::if_expr => { Rule::if_expr => {
@ -178,11 +196,12 @@ fn parse_expression(pair: Pair<Rule>) -> Expr {
let condition = parse_expression(pairs.next().unwrap()); let condition = parse_expression(pairs.next().unwrap());
let true_block = parse_block(pairs.next().unwrap()); let true_block = parse_block(pairs.next().unwrap());
let else_value = parse_expression(pairs.next().unwrap()); let else_value = parse_expression(pairs.next().unwrap());
Expr::IfExpr( Expr::IfExpr {
Box::new(condition), cond: Box::new(condition),
Box::new(true_block), then_body: Box::new(true_block),
Box::new(else_value), else_body: Box::new(else_value),
) typ: Type::Undefined,
}
} }
Rule::boolean_literal => Expr::BooleanLiteral(match primary.as_str() { Rule::boolean_literal => Expr::BooleanLiteral(match primary.as_str() {
"true" => true, "true" => true,
@ -203,9 +222,26 @@ fn parse_expression(pair: Pair<Rule>) -> Expr {
Rule::modulo => BinaryOperator::Modulo, Rule::modulo => BinaryOperator::Modulo,
Rule::equal => BinaryOperator::Equal, Rule::equal => BinaryOperator::Equal,
Rule::not_equal => BinaryOperator::NotEqual, Rule::not_equal => BinaryOperator::NotEqual,
Rule::and => BinaryOperator::And,
Rule::or => BinaryOperator::Or,
_ => unreachable!(), _ => unreachable!(),
}; };
Expr::BinaryExpression(Box::new(lhs), operator, Box::new(rhs)) Expr::BinaryExpression {
lhs: Box::new(lhs),
op: operator,
rhs: Box::new(rhs),
typ: Type::Undefined,
}
})
.map_prefix(|op, inner| {
let operator = match op.as_rule() {
Rule::not => UnaryOperator::Not,
_ => unreachable!(),
};
Expr::UnaryExpression {
op: operator,
inner: Box::new(inner),
}
}) })
.parse(pairs) .parse(pairs)
} }
@ -247,7 +283,7 @@ fn parse_definition(pair: Pair<Rule>) -> Definition {
parameters, parameters,
return_type, return_type,
body, body,
line_col, location: Location { line_col },
}) })
} }
_ => panic!("unexpected node for definition: {:?}", pair.as_rule()), _ => panic!("unexpected node for definition: {:?}", pair.as_rule()),

View file

@ -1,12 +1,8 @@
#[test] #[test]
fn test_addition_function() { fn test_addition_function() {
use crate::ast::{expr::Expr, *};
use crate::parsing::backend::pest::parse_as_module; use crate::parsing::backend::pest::parse_as_module;
use crate::{ use crate::typing::Type;
ast::untyped::module::Module,
ast::untyped::*,
ast::ModulePath,
typing::Type,
};
let source = "fn add(a: int, b: int) int { a + b }"; let source = "fn add(a: int, b: int) int { a + b }";
let path = ModulePath::from("test"); let path = ModulePath::from("test");
@ -15,7 +11,7 @@ fn test_addition_function() {
let expected_module = Module { let expected_module = Module {
file: None, file: None,
imports: vec![], imports: vec![],
definitions: vec![Definition::FunctionDefinition(FunctionDefinition { functions: vec![FunctionDefinition {
name: Identifier::from("add"), name: Identifier::from("add"),
parameters: vec![ parameters: vec![
Parameter { Parameter {
@ -30,14 +26,22 @@ fn test_addition_function() {
return_type: Some(Type::Int), return_type: Some(Type::Int),
body: Box::new(Block { body: Box::new(Block {
statements: vec![], statements: vec![],
value: Some(Expr::BinaryExpression( value: Some(Expr::BinaryExpression {
Box::new(Expr::Identifier(Identifier::from("a"))), lhs: Box::new(Expr::Identifier {
BinaryOperator::Add, name: Identifier::from("a"),
Box::new(Expr::Identifier(Identifier::from("b"))), typ: Type::Undefined,
)), }),
op: BinaryOperator::Add,
rhs: Box::new(Expr::Identifier {
name: Identifier::from("b"),
typ: Type::Undefined,
}),
typ: Type::Undefined,
}),
typ: Type::Undefined,
}), }),
line_col: (1, 1), location: Location { line_col: (1, 1) },
})], }],
path, path,
}; };

View file

@ -1,4 +1,6 @@
use crate::typing::{BinaryOperator, Identifier, ModulePath, Type, TypeContext}; use crate::typing::{BinaryOperator, Identifier, ModulePath, Type, TypingContext};
use super::UnaryOperator;
#[derive(Debug)] #[derive(Debug)]
pub struct TypeError { pub struct TypeError {
@ -38,7 +40,7 @@ impl TypeError {
} }
impl TypeErrorBuilder { impl TypeErrorBuilder {
pub fn context(mut self, ctx: &TypeContext) -> Self { pub fn context(mut self, ctx: &TypingContext) -> Self {
self.file = ctx.file.clone(); self.file = ctx.file.clone();
self.module = Some(ctx.module.clone()); self.module = Some(ctx.module.clone());
self.function = ctx.function.clone(); self.function = ctx.function.clone();
@ -89,4 +91,8 @@ pub enum TypeErrorKind {
WrongFunctionArguments, WrongFunctionArguments,
ConditionIsNotBool, ConditionIsNotBool,
IfElseMismatch, IfElseMismatch,
InvalidUnaryOperator {
operator: UnaryOperator,
inner: Type,
},
} }

View file

@ -1,47 +1,88 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Display;
use crate::ast::untyped::*;
use crate::ast::untyped::module::Module;
use crate::ast::ModulePath; use crate::ast::ModulePath;
use crate::ast::*;
mod error; mod error;
use crate::typing::error::{TypeError, TypeErrorKind}; use crate::typing::error::{TypeError, TypeErrorKind};
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub enum Type { pub enum Type {
/// Not a real type, used for parsing pass
Undefined,
Bool, Bool,
Int, Int,
Float, Float,
Unit, Unit,
Str, Str,
Function {
params: Vec<Type>,
returns: Box<Type>,
},
Custom(Identifier), Custom(Identifier),
} }
/// Textual rendering of a [`Type`].
///
/// Built-in types print a fixed name (`Type::Undefined` prints `UNDEFINED`),
/// `Type::Custom` prints its identifier verbatim, and `Type::Function` prints
/// as `Fn(A, B) -> R`: parameters are separated by `", "` with no trailing
/// separator, so a function without parameters prints as `Fn() -> R`.
impl Display for Type {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Type::Undefined => f.write_str("UNDEFINED"),
            Type::Bool => f.write_str("Bool"),
            Type::Int => f.write_str("Int"),
            Type::Float => f.write_str("Float"),
            Type::Unit => f.write_str("Unit"),
            Type::Str => f.write_str("Str"),
            Type::Custom(identifier) => f.write_str(identifier),
            Type::Function { params, returns } => {
                f.write_str("Fn(")?;
                for (index, param) in params.iter().enumerate() {
                    // Emit the separator only *between* parameters.
                    if index > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{}", param)?;
                }
                write!(f, ") -> {}", returns)
            }
        }
    }
}
impl From<&str> for Type { impl From<&str> for Type {
fn from(value: &str) -> Self { fn from(value: &str) -> Self {
match value { match value {
"int" => Type::Int, "int" => Type::Int,
"float" => Type::Float, "float" => Type::Float,
"bool" => Type::Bool,
_ => Type::Custom(Identifier::from(value)), _ => Type::Custom(Identifier::from(value)),
} }
} }
} }
impl untyped::FunctionDefinition { #[derive(Debug, PartialEq, Clone)]
fn signature(&self) -> (Vec<Type>, Type) { pub struct Signature(Vec<Type>, Type);
impl Into<Type> for Signature {
fn into(self) -> Type {
Type::Function {
params: self.0,
returns: Box::new(self.1),
}
}
}
impl FunctionDefinition {
fn signature(&self) -> Signature {
let return_type = self.return_type.clone().unwrap_or(Type::Unit); let return_type = self.return_type.clone().unwrap_or(Type::Unit);
let params_types = self.parameters.iter().map(|p| p.typ.clone()).collect(); let params_types = self.parameters.iter().map(|p| p.typ.clone()).collect();
(params_types, return_type) Signature(params_types, return_type)
} }
} }
impl Module { impl Module {
pub fn type_check(&self) -> Result<(), TypeError> { pub fn type_check(&mut self) -> Result<(), TypeError> {
let mut ctx = TypeContext::new(self.path.clone()); let mut ctx = TypingContext::new(self.path.clone());
ctx.file = self.file.clone(); ctx.file = self.file.clone();
// Register all function signatures // Register all function signatures
for Definition::FunctionDefinition(func) in &self.definitions { for func in &self.functions {
if let Some(_previous) = ctx.functions.insert(func.name.clone(), func.signature()) { if let Some(_previous) = ctx.functions.insert(func.name.clone(), func.signature()) {
todo!("handle redefinition of function or identical function names across different files"); todo!("handle redefinition of function or identical function names across different files");
} }
@ -49,8 +90,8 @@ impl Module {
// TODO: add signatures of imported functions (even if they have not been checked) // TODO: add signatures of imported functions (even if they have not been checked)
// Type-check the function bodies // Type-check the function bodies and complete all type placeholders
for Definition::FunctionDefinition(func) in &self.definitions { for func in &mut self.functions {
func.typ(&mut ctx)?; func.typ(&mut ctx)?;
ctx.variables.clear(); ctx.variables.clear();
} }
@ -59,18 +100,20 @@ impl Module {
} }
} }
pub struct TypeContext { pub struct TypingContext {
pub file: Option<std::path::PathBuf>, pub file: Option<std::path::PathBuf>,
pub module: ModulePath, pub module: ModulePath,
pub function: Option<Identifier>, pub function: Option<Identifier>,
pub functions: HashMap<Identifier, (Vec<Type>, Type)>, pub functions: HashMap<Identifier, Signature>,
pub variables: HashMap<Identifier, Type>, pub variables: HashMap<Identifier, Type>,
} }
impl TypeContext { impl TypingContext {
pub fn new(path: ModulePath) -> Self { pub fn new(path: ModulePath) -> Self {
let builtin_functions = let builtin_functions = HashMap::from([(
HashMap::from([(String::from("println"), (vec![Type::Str], Type::Unit))]); String::from("println"),
Signature(vec![Type::Str], Type::Unit),
)]);
Self { Self {
file: None, file: None,
@ -84,70 +127,72 @@ impl TypeContext {
/// Trait for nodes which have a deducible type. /// Trait for nodes which have a deducible type.
pub trait TypeCheck { pub trait TypeCheck {
/// Try to resolve the type of the node. /// Try to resolve the type of the node and complete its type placeholders.
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError>; fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError>;
} }
impl TypeCheck for FunctionDefinition { impl TypeCheck for FunctionDefinition {
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError> { fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
ctx.function = Some(self.name.clone()); ctx.function = Some(self.name.clone());
for param in &self.parameters { for param in &self.parameters {
// XXX: Parameter types should be checked
// when they are not builtin
ctx.variables.insert(param.name.clone(), param.typ.clone()); ctx.variables.insert(param.name.clone(), param.typ.clone());
} }
let body_type = &self.body.typ(ctx)?; let body_type = self.body.typ(ctx)?;
// If the return type is not specified, it is unit. // If the return type is not specified, it is unit.
let func_return_type = match &self.return_type { if self.return_type.is_none() {
Some(typ) => typ, self.return_type = Some(Type::Unit)
None => &Type::Unit, }
};
// Check coherence with the body's type. // Check coherence with the body's type.
if *func_return_type != *body_type { if *self.return_type.as_ref().unwrap() != body_type {
return Err(TypeError::builder() return Err(TypeError::builder()
.context(ctx) .context(ctx)
.kind(TypeErrorKind::BlockTypeDoesNotMatchFunctionType { .kind(TypeErrorKind::BlockTypeDoesNotMatchFunctionType {
block_type: body_type.clone(), block_type: body_type.clone(),
function_type: func_return_type.clone(), function_type: self.return_type.as_ref().unwrap().clone(),
}) })
.build()); .build());
} }
// Check coherence with return statements. // Check coherence with return statements.
for statement in &self.body.statements {
for statement in &mut self.body.statements {
if let Statement::ReturnStatement(value) = statement { if let Statement::ReturnStatement(value) = statement {
let ret_type = match value { let ret_type = match value {
Some(expr) => expr.typ(ctx)?, Some(expr) => expr.typ(ctx)?,
None => Type::Unit, None => Type::Unit,
}; };
if ret_type != *func_return_type { if ret_type != *self.return_type.as_ref().unwrap() {
return Err(TypeError::builder() return Err(TypeError::builder()
.context(ctx) .context(ctx)
.kind(TypeErrorKind::ReturnTypeDoesNotMatchFunctionType { .kind(TypeErrorKind::ReturnTypeDoesNotMatchFunctionType {
function_type: func_return_type.clone(), function_type: self.return_type.as_ref().unwrap().clone(),
return_type: ret_type, return_type: ret_type.clone(),
}) })
.build()); .build());
} }
} }
} }
Ok(func_return_type.clone()) Ok(self.return_type.clone().unwrap())
} }
} }
impl TypeCheck for Block { impl TypeCheck for Block {
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError> { fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
let mut return_typ: Option<Type> = None; let mut return_typ: Option<Type> = None;
// Check declarations and assignments. // Check declarations and assignments.
for statement in &self.statements { for statement in &mut self.statements {
match statement { match statement {
Statement::DeclareStatement(ident, expr) => { Statement::DeclareStatement(ident, expr) => {
let typ = expr.typ(ctx)?; let typ = expr.typ(ctx)?;
if let Some(_typ) = ctx.variables.insert(ident.clone(), typ) { if let Some(_typ) = ctx.variables.insert(ident.clone(), typ.clone()) {
// TODO: Shadowing? (illegal for now) // TODO: Shadowing? (illegal for now)
return Err(TypeError::builder() return Err(TypeError::builder()
.context(ctx) .context(ctx)
@ -159,9 +204,9 @@ impl TypeCheck for Block {
let rhs_typ = expr.typ(ctx)?; let rhs_typ = expr.typ(ctx)?;
let Some(lhs_typ) = ctx.variables.get(ident) else { let Some(lhs_typ) = ctx.variables.get(ident) else {
return Err(TypeError::builder() return Err(TypeError::builder()
.context(ctx) .context(ctx)
.kind(TypeErrorKind::AssignUndeclared) .kind(TypeErrorKind::AssignUndeclared)
.build()); .build());
}; };
// Ensure same type on both sides. // Ensure same type on both sides.
@ -189,7 +234,7 @@ impl TypeCheck for Block {
.build()); .build());
} }
} else { } else {
return_typ = Some(expr_typ); return_typ = Some(expr_typ.clone());
} }
} }
Statement::CallStatement(call) => { Statement::CallStatement(call) => {
@ -220,9 +265,11 @@ impl TypeCheck for Block {
} }
// Check if there is an expression at the end of the block. // Check if there is an expression at the end of the block.
if let Some(expr) = &self.value { if let Some(expr) = &mut self.value {
expr.typ(ctx) self.typ = expr.typ(ctx)?.clone();
Ok(self.typ.clone())
} else { } else {
self.typ = Type::Unit;
Ok(Type::Unit) Ok(Type::Unit)
} }
@ -234,29 +281,34 @@ impl TypeCheck for Block {
} }
impl TypeCheck for Call { impl TypeCheck for Call {
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError> { fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
match &*self.callee { match &mut *self.callee {
Expr::Identifier(ident) => { Expr::Identifier { name, typ } => {
let signature = match ctx.functions.get(ident) { let signature = match ctx.functions.get(name) {
Some(sgn) => sgn.clone(), Some(sgn) => sgn.clone(),
None => { None => {
return Err(TypeError::builder() return Err(TypeError::builder()
.context(ctx) .context(ctx)
.kind(TypeErrorKind::UnknownFunctionCalled(ident.clone())) .kind(TypeErrorKind::UnknownFunctionCalled(name.clone()))
.build()) .build())
} }
}; };
let (params_types, func_type) = signature;
*typ = signature.clone().into();
let Signature(params_types, func_type) = signature;
self.typ = func_type.clone();
// Collect arg types. // Collect arg types.
let mut args_types: Vec<Type> = vec![]; let mut args_types: Vec<Type> = vec![];
for arg in &self.args { for arg in &mut self.args {
let typ = arg.typ(ctx)?; let arg_typ = arg.typ(ctx)?;
args_types.push(typ.clone()); args_types.push(arg_typ.clone());
} }
if args_types == *params_types { if args_types == *params_types {
Ok(func_type.clone()) Ok(self.typ.clone())
} else { } else {
Err(TypeError::builder() Err(TypeError::builder()
.context(ctx) .context(ctx)
@ -270,16 +322,17 @@ impl TypeCheck for Call {
} }
impl TypeCheck for Expr { impl TypeCheck for Expr {
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError> { fn typ(&mut self, ctx: &mut TypingContext) -> Result<Type, TypeError> {
match self { match self {
Expr::Identifier(identifier) => { Expr::Identifier { name, typ } => {
if let Some(typ) = ctx.variables.get(identifier) { if let Some(ty) = ctx.variables.get(name) {
*typ = ty.clone();
Ok(typ.clone()) Ok(typ.clone())
} else { } else {
Err(TypeError::builder() Err(TypeError::builder()
.context(ctx) .context(ctx)
.kind(TypeErrorKind::UnknownIdentifier { .kind(TypeErrorKind::UnknownIdentifier {
identifier: identifier.clone(), identifier: name.clone(),
}) })
.build()) .build())
} }
@ -287,192 +340,107 @@ impl TypeCheck for Expr {
Expr::BooleanLiteral(_) => Ok(Type::Bool), Expr::BooleanLiteral(_) => Ok(Type::Bool),
Expr::IntegerLiteral(_) => Ok(Type::Int), Expr::IntegerLiteral(_) => Ok(Type::Int),
Expr::FloatLiteral(_) => Ok(Type::Float), Expr::FloatLiteral(_) => Ok(Type::Float),
Expr::BinaryExpression(lhs, op, rhs) => match op { Expr::UnaryExpression { op, inner } => {
BinaryOperator::Add let inner_type = &inner.typ(ctx)?;
| BinaryOperator::Sub match (&op, inner_type) {
| BinaryOperator::Mul (UnaryOperator::Not, Type::Bool) => Ok(Type::Bool),
| BinaryOperator::Div => { _ => Err(TypeError::builder()
let left_type = &lhs.typ(ctx)?; .context(ctx)
let right_type = &rhs.typ(ctx)?; .kind(TypeErrorKind::InvalidUnaryOperator {
match (left_type, right_type) { operator: *op,
(Type::Int, Type::Int) => Ok(Type::Int), inner: inner_type.clone(),
(Type::Float, Type::Float) => Ok(Type::Float), })
(_, _) => Err(TypeError::builder() .build()),
.context(ctx)
.kind(TypeErrorKind::InvalidBinaryOperator {
operator: op.clone(),
lht: left_type.clone(),
rht: right_type.clone(),
})
.build()),
}
} }
BinaryOperator::Equal | BinaryOperator::NotEqual => { }
let lhs_type = lhs.typ(ctx)?; Expr::BinaryExpression { lhs, op, rhs, typ } => {
let rhs_type = rhs.typ(ctx)?; let ty = match op {
if lhs_type != rhs_type { BinaryOperator::Add
return Err(TypeError::builder() | BinaryOperator::Sub
.context(ctx) | BinaryOperator::Mul
.kind(TypeErrorKind::InvalidBinaryOperator { | BinaryOperator::Div
operator: op.clone(), | BinaryOperator::And
lht: lhs_type.clone(), | BinaryOperator::Or => {
rht: rhs_type.clone(), let left_type = &lhs.typ(ctx)?;
}) let right_type = &rhs.typ(ctx)?;
.build()); match (left_type, right_type) {
(Type::Int, Type::Int) => Ok(Type::Int),
(Type::Float, Type::Float) => Ok(Type::Float),
(Type::Bool, Type::Bool) => Ok(Type::Bool),
(_, _) => Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::InvalidBinaryOperator {
operator: op.clone(),
lht: left_type.clone(),
rht: right_type.clone(),
})
.build()),
}
} }
Ok(Type::Bool) BinaryOperator::Equal | BinaryOperator::NotEqual => {
} let lhs_type = lhs.typ(ctx)?;
BinaryOperator::Modulo => { let rhs_type = rhs.typ(ctx)?;
let lhs_type = lhs.typ(ctx)?; if lhs_type != rhs_type {
let rhs_type = lhs.typ(ctx)?; return Err(TypeError::builder()
match (&lhs_type, &rhs_type) { .context(ctx)
(Type::Int, Type::Int) => Ok(Type::Int), .kind(TypeErrorKind::InvalidBinaryOperator {
_ => Err(TypeError::builder() operator: op.clone(),
.context(ctx) lht: lhs_type.clone(),
.kind(TypeErrorKind::InvalidBinaryOperator { rht: rhs_type.clone(),
operator: op.clone(), })
lht: lhs_type.clone(), .build());
rht: rhs_type.clone(), }
}) Ok(Type::Bool)
.build()),
} }
} BinaryOperator::Modulo => {
}, let lhs_type = lhs.typ(ctx)?;
let rhs_type = lhs.typ(ctx)?;
match (&lhs_type, &rhs_type) {
(Type::Int, Type::Int) => Ok(Type::Int),
_ => Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::InvalidBinaryOperator {
operator: op.clone(),
lht: lhs_type.clone(),
rht: rhs_type.clone(),
})
.build()),
}
}
};
*typ = ty?;
Ok(typ.clone())
}
Expr::StringLiteral(_) => Ok(Type::Str), Expr::StringLiteral(_) => Ok(Type::Str),
Expr::UnitLiteral => Ok(Type::Unit), Expr::UnitLiteral => Ok(Type::Unit),
Expr::Call(call) => call.typ(ctx), Expr::Call(call) => call.typ(ctx),
Expr::Block(block) => block.typ(ctx), Expr::Block(block) => block.typ(ctx),
Expr::IfExpr(cond, true_block, else_value) => { Expr::IfExpr {
cond,
then_body,
else_body,
typ,
} => {
if cond.typ(ctx)? != Type::Bool { if cond.typ(ctx)? != Type::Bool {
Err(TypeError::builder() Err(TypeError::builder()
.context(ctx) .context(ctx)
.kind(TypeErrorKind::ConditionIsNotBool) .kind(TypeErrorKind::ConditionIsNotBool)
.build()) .build())
} else { } else {
let true_block_type = true_block.typ(ctx)?; let then_body_type = then_body.typ(ctx)?;
let else_type = else_value.typ(ctx)?; let else_type = else_body.typ(ctx)?;
if true_block_type != else_type { if then_body_type != else_type {
Err(TypeError::builder() Err(TypeError::builder()
.context(ctx) .context(ctx)
.kind(TypeErrorKind::IfElseMismatch) .kind(TypeErrorKind::IfElseMismatch)
.build()) .build())
} else { } else {
Ok(true_block_type.clone()) // XXX: opt: return ref to avoid cloning
*typ = then_body_type.clone();
Ok(then_body_type)
} }
} }
} }
} }
} }
} }
/// Pairs a type-checkable node with the type recorded for it.
struct Typed<T: TypeCheck> {
/// The wrapped node.
inner: T,
/// The type stored alongside `inner`.
typ: Type,
}
/// Conversion of a node into its [`Typed`] wrapper.
trait IntoTyped<T: TypeCheck> {
    /// Consumes `self` and returns it wrapped in a [`Typed`] together with
    /// its type, using (and possibly updating) the typing context `ctx`.
    ///
    /// # Errors
    ///
    /// Returns a [`TypeError`] when the implementation rejects the node.
    fn into_typed(self, ctx: &mut TypeContext) -> Result<Typed<T>, TypeError>;
}
/// Type-checks a [`Block`] and wraps it, together with its value type, in a
/// [`Typed`].
///
/// Every statement is checked in order against `ctx`; the first failure is
/// returned as a [`TypeError`]. The resulting type is that of the trailing
/// value expression, or `Type::Unit` when the block has none.
///
/// TODO/FIXME: the type shared by the block's `return` statements is tracked
/// but not yet handed back, so the top-level block (the function) cannot
/// compare it — and eventually those of nested blocks — with the function's
/// declared type.
impl IntoTyped<Block> for Block {
    fn into_typed(self, ctx: &mut TypeContext) -> Result<Typed<Block>, TypeError> {
        /// Error raised when an `if`/`while` condition is not a boolean.
        fn condition_is_not_bool(ctx: &TypeContext) -> TypeError {
            TypeError::builder()
                .context(ctx)
                .kind(TypeErrorKind::ConditionIsNotBool)
                .build()
        }

        // Type of the first `return` seen; later ones must agree with it.
        let mut return_typ: Option<Type> = None;

        for statement in &self.statements {
            match statement {
                Statement::DeclareStatement(ident, expr) => {
                    let declared = expr.typ(ctx)?;
                    let previous = ctx.variables.insert(ident.clone(), declared);
                    // XXX: Shadowing? (illegal for now)
                    if previous.is_some() {
                        return Err(TypeError::builder()
                            .context(ctx)
                            .kind(TypeErrorKind::VariableRedeclaration)
                            .build());
                    }
                }
                Statement::AssignStatement(ident, expr) => {
                    let assigned = expr.typ(ctx)?;
                    match ctx.variables.get(ident) {
                        None => {
                            return Err(TypeError::builder()
                                .context(ctx)
                                .kind(TypeErrorKind::AssignUndeclared)
                                .build());
                        }
                        // Both sides of an assignment must have the same type.
                        Some(declared) if assigned != *declared => {
                            return Err(TypeError::builder()
                                .context(ctx)
                                .kind(TypeErrorKind::AssignmentMismatch {
                                    lht: declared.clone(),
                                    rht: assigned.clone(),
                                })
                                .build());
                        }
                        Some(_) => {}
                    }
                }
                Statement::ReturnStatement(maybe_expr) => {
                    // A bare `return` yields unit.
                    let found = match maybe_expr {
                        Some(expr) => expr.typ(ctx)?,
                        None => Type::Unit,
                    };
                    match &return_typ {
                        Some(expected) if found != *expected => {
                            return Err(TypeError::builder()
                                .context(ctx)
                                .kind(TypeErrorKind::ReturnStatementsMismatch)
                                .build());
                        }
                        Some(_) => {}
                        None => return_typ = Some(found),
                    }
                }
                Statement::CallStatement(call) => {
                    call.typ(ctx)?;
                }
                Statement::UseStatement(_path) => {
                    // TODO: import the signatures (and types)
                }
                Statement::IfStatement(cond, block) => {
                    if cond.typ(ctx)? != Type::Bool {
                        return Err(condition_is_not_bool(ctx));
                    }
                    block.typ(ctx)?;
                }
                Statement::WhileStatement(cond, block) => {
                    if cond.typ(ctx)? != Type::Bool {
                        return Err(condition_is_not_bool(ctx));
                    }
                    block.typ(ctx)?;
                }
            }
        }

        // The block's own type comes from its trailing expression, if any.
        let typ = match &self.value {
            Some(expr) => expr.typ(ctx)?,
            None => Type::Unit,
        };

        Ok(Typed { inner: self, typ })
    }
}