refactor typing and compilation

This commit is contained in:
Romain Paquet 2024-06-25 11:02:05 +02:00
parent e8deab19cc
commit ba838292f6
8 changed files with 175 additions and 131 deletions

View file

@ -1,4 +1,5 @@
use crate::ast::*; use crate::ast::*;
use crate::typing::Type;
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum Expr { pub enum Expr {
@ -32,3 +33,39 @@ pub enum Expr {
FloatLiteral(f64), FloatLiteral(f64),
StringLiteral(String), StringLiteral(String),
} }
impl Block {
#[inline]
pub fn ty(&self) -> Type {
// XXX: Cloning may be expensive -> TypeId?
self.typ.clone()
}
}
impl Expr {
pub fn ty(&self) -> Type {
match self {
Expr::BinaryExpression {
lhs: _,
op: _,
rhs: _,
typ,
} => typ.clone(),
Expr::UnaryExpression { op: _, inner } => inner.ty(), // XXX: problems will arise here
Expr::Identifier { name: _, typ } => typ.clone(),
Expr::Call(call) => call.typ.clone(),
Expr::Block(block) => block.typ.clone(),
Expr::IfExpr {
cond: _,
then_body: _,
else_body: _,
typ,
} => typ.clone(),
Expr::UnitLiteral => Type::Unit,
Expr::BooleanLiteral(_) => Type::Bool,
Expr::IntegerLiteral(_) => Type::Int,
Expr::FloatLiteral(_) => Type::Float,
Expr::StringLiteral(_) => Type::Str,
}
}
}

View file

@ -1,11 +1,10 @@
use crate::typing::Type;
use std::path::Path;
pub mod expr; pub mod expr;
pub mod typed;
pub use expr::Expr; pub use expr::Expr;
use crate::typing::Type;
use std::path::Path;
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub enum BinaryOperator { pub enum BinaryOperator {
// Logic // Logic
@ -38,6 +37,19 @@ pub struct ModulePath {
components: Vec<String>, components: Vec<String>,
} }
impl ModulePath {
pub fn concat(lhs: &ModulePath, rhs: &ModulePath) -> ModulePath {
ModulePath {
components: Vec::from_iter(
lhs.components
.iter()
.chain(rhs.components.iter())
.map(Clone::clone),
),
}
}
}
impl std::fmt::Display for ModulePath { impl std::fmt::Display for ModulePath {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{}", self.components.join("::"))) f.write_fmt(format_args!("{}", self.components.join("::")))
@ -136,6 +148,13 @@ impl Module {
..Default::default() ..Default::default()
} }
} }
pub fn full_func_path(&self, func: usize) -> ModulePath {
ModulePath::concat(
&self.path,
&ModulePath::from(self.functions[func].name.as_str()),
)
}
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]

View file

@ -1,27 +0,0 @@
use crate::ast::{
typed::{Block, Call},
BinaryOperator, UnaryOperator,
};
use crate::typing::Type;
impl Expr {
pub fn typ(&self) -> &Type {
match self {
Expr::BinaryExpression { lhs, op, rhs, typ } => &typ.unwrap(),
Expr::UnaryExpression { op, inner } => inner.typ(), // XXX: problems will arise here
Expr::Variable { name, typ } => &typ.unwrap(),
Expr::Call { call, typ } => &typ.unwrap(),
Expr::Block { block, typ } => &typ.unwrap(),
Expr::IfExpr {
cond,
then_body,
else_body,
typ,
} => &typ.unwrap(),
Expr::UnitLiteral => &Type::Unit,
Expr::BooleanLiteral(_) => &Type::Bool,
Expr::IntegerLiteral(_) => &Type::Int,
Expr::FloatLiteral(_) => &Type::Float,
Expr::StringLiteral(_) => &Type::Str,
}
}
}

View file

@ -1,37 +0,0 @@
use crate::ast::*;
impl Block {
#[inline]
pub fn ty(&self) -> Type {
// XXX: Cloning may be expensive -> TypeId?
self.typ.clone()
}
}
impl Expr {
pub fn ty(&self) -> Type {
match self {
Expr::BinaryExpression {
lhs: _,
op: _,
rhs: _,
typ,
} => typ.clone(),
Expr::UnaryExpression { op: _, inner } => inner.ty(), // XXX: problems will arise here
Expr::Identifier { name: _, typ } => typ.clone(),
Expr::Call(call) => call.typ.clone(),
Expr::Block(block) => block.typ.clone(),
Expr::IfExpr {
cond: _,
then_body: _,
else_body: _,
typ,
} => typ.clone(),
Expr::UnitLiteral => Type::Unit,
Expr::BooleanLiteral(_) => Type::Bool,
Expr::IntegerLiteral(_) => Type::Int,
Expr::FloatLiteral(_) => Type::Float,
Expr::StringLiteral(_) => Type::Str,
}
}
}

View file

@ -4,12 +4,12 @@ use crate::{
{expr::Expr, FunctionDefinition, Statement}, {expr::Expr, FunctionDefinition, Statement},
}, },
parsing, parsing,
typing::Type, typing::{CheckedModule, Type},
}; };
use cranelift::{codegen::ir::UserFuncName, prelude::*}; use cranelift::{codegen::ir::UserFuncName, prelude::*};
use cranelift_jit::{JITBuilder, JITModule}; use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{DataDescription, FuncId, FuncOrDataId, Linkage, Module}; use cranelift_module::{DataDescription, FuncId, FuncOrDataId, Linkage, Module};
use std::{collections::HashMap, ops::Deref}; use std::{collections::HashMap, fs, ops::Deref};
/// The basic JIT class. /// The basic JIT class.
pub struct JIT { pub struct JIT {
@ -27,6 +27,9 @@ pub struct JIT {
/// The module, with the jit backend, which manages the JIT'd functions. /// The module, with the jit backend, which manages the JIT'd functions.
module: JITModule, module: JITModule,
/// Whether to print CLIR during compilation
pub dump_clir: bool,
} }
impl Default for JIT { impl Default for JIT {
@ -47,29 +50,30 @@ impl Default for JIT {
let module = JITModule::new(builder); let module = JITModule::new(builder);
let mut ctx = module.make_context();
ctx.set_disasm(true);
Self { Self {
builder_context: FunctionBuilderContext::new(), builder_context: FunctionBuilderContext::new(),
ctx: module.make_context(), ctx,
data_desc: DataDescription::new(), data_desc: DataDescription::new(),
module, module,
dump_clir: false,
} }
} }
} }
impl JIT { impl JIT {
/// Compile source code into machine code. /// Compile source code into machine code.
pub fn compile(&mut self, input: &str, dump_clir: bool) -> Result<*const u8, String> { pub fn compile(&mut self, input: &str, namespace: ModulePath) -> Result<*const u8, String> {
// Parse the source code into an AST // Parse the source code into an AST
let Ok(mut ast) = parsing::parse_as_module(input, ModulePath::from("globalmodule")) else { let ast = parsing::parse_as_module(input, namespace)
return Err("Parsing error".to_string()); .map_err(|x| format!("Parsing error: {x}"))?
}; .type_check()
.map_err(|x| format!("Typing error: {x}"))?;
if let Err(e) = ast.type_check() {
return Err(e.to_string());
};
// Translate the AST into Cranelift IR // Translate the AST into Cranelift IR
self.translate(&ast, dump_clir)?; self.translate(&ast)?;
// Finalize the functions which we just defined, which resolves any // Finalize the functions which we just defined, which resolves any
// outstanding relocations (patching in addresses, now that they're // outstanding relocations (patching in addresses, now that they're
@ -85,8 +89,20 @@ impl JIT {
} }
} }
pub fn compile_file(&mut self, path: &str) -> Result<*const u8, String> {
self.compile(
fs::read_to_string(path)
.map_err(|x| format!("Cannot open {}: {}", path, x))?
.as_str(),
AsRef::<std::path::Path>::as_ref(path).into(),
)
}
/// Translate language AST into Cranelift IR. /// Translate language AST into Cranelift IR.
fn translate(&mut self, ast: &ast::Module, dump_clir: bool) -> Result<(), String> { fn translate(&mut self, ast: &CheckedModule) -> Result<(), String> {
// Dump contract-holding wrapper type
let ast = &ast.0;
let mut signatures: Vec<Signature> = Vec::with_capacity(ast.functions.len()); let mut signatures: Vec<Signature> = Vec::with_capacity(ast.functions.len());
let mut func_ids: Vec<FuncId> = Vec::with_capacity(ast.functions.len()); let mut func_ids: Vec<FuncId> = Vec::with_capacity(ast.functions.len());
@ -96,7 +112,7 @@ impl JIT {
let mut sig = self.module.make_signature(); let mut sig = self.module.make_signature();
for param in &func.parameters { for param in &func.parameters {
assert!(param.typ != Type::Unit); assert_ne!(param.typ, Type::Unit);
sig.params.append(&mut Vec::from(&param.typ)); sig.params.append(&mut Vec::from(&param.typ));
} }
@ -124,9 +140,9 @@ impl JIT {
.define_function(func_ids[i], &mut self.ctx) .define_function(func_ids[i], &mut self.ctx)
.unwrap(); .unwrap();
if dump_clir { if self.dump_clir {
println!("// {}", func.name); println!("// {}", ast.full_func_path(i));
println!("{}", self.ctx.func.display()); println!("{}", self.ctx.func);
} }
self.module.clear_context(&mut self.ctx); self.module.clear_context(&mut self.ctx);
@ -155,25 +171,28 @@ impl JIT {
builder.seal_block(entry_block); builder.seal_block(entry_block);
// Walk the AST and declare all implicitly-declared variables. // Walk the AST and declare all implicitly-declared variables.
let mut variables = HashMap::<String, Variable>::default(); // TODO: actually do this let variables = HashMap::<String, Variable>::default(); // TODO: actually do this
// Add a variable for each parameter.
let param_values: Box<[Value]> = builder.block_params(entry_block).into();
assert!(param_values.len() == function.parameters.len());
for (i, param) in function.parameters.iter().enumerate() {
let var = Variable::from_u32(variables.len() as u32);
variables.insert(param.name.clone(), var);
let value = param_values[i];
builder.declare_var(var, param.typ.clone().into());
builder.def_var(var, value);
}
// Now translate the statements of the function body.
let mut translator = FunctionTranslator { let mut translator = FunctionTranslator {
builder, builder,
variables, variables,
module: &mut self.module, module: &mut self.module,
data_desc: &mut self.data_desc,
}; };
// Add a variable for each parameter.
let param_values: Box<[Value]> = translator.builder.block_params(entry_block).into();
assert_eq!(param_values.len(), function.parameters.len());
for (i, param) in function.parameters.iter().enumerate() {
let var = Variable::from_u32(translator.variables.len() as u32);
translator.variables.insert(param.name.clone(), var);
let value = param_values[i];
let typ = translator.translate_type(&param.typ);
translator.builder.declare_var(var, typ);
translator.builder.def_var(var, value);
}
// Now translate the statements of the function body.
for stmt in &function.body.statements { for stmt in &function.body.statements {
translator.translate_statement(stmt); translator.translate_statement(stmt);
} }
@ -192,24 +211,6 @@ impl JIT {
} }
} }
impl From<crate::typing::Type> for types::Type {
fn from(value: crate::typing::Type) -> Self {
match value {
Type::Bool => types::I8,
Type::Int => types::I32,
Type::Float => types::F32,
Type::Unit => unreachable!(),
Type::Str => todo!(),
Type::Custom(_) => todo!(),
Type::Function {
params: _,
returns: _,
} => todo!(),
Type::Undefined => unreachable!(),
}
}
}
impl From<&Type> for Vec<AbiParam> { impl From<&Type> for Vec<AbiParam> {
fn from(value: &Type) -> Self { fn from(value: &Type) -> Self {
match value { match value {
@ -227,6 +228,7 @@ struct FunctionTranslator<'a> {
builder: FunctionBuilder<'a>, builder: FunctionBuilder<'a>,
variables: HashMap<String, Variable>, variables: HashMap<String, Variable>,
module: &'a mut JITModule, module: &'a mut JITModule,
data_desc: &'a mut DataDescription,
} }
impl<'a> FunctionTranslator<'a> { impl<'a> FunctionTranslator<'a> {
@ -245,7 +247,8 @@ impl<'a> FunctionTranslator<'a> {
Statement::DeclareStatement(name, expr) => { Statement::DeclareStatement(name, expr) => {
let value = self.translate_expr(expr); let value = self.translate_expr(expr);
let variable = Variable::from_u32(self.variables.len() as u32); let variable = Variable::from_u32(self.variables.len() as u32);
self.builder.declare_var(variable, expr.ty().into()); self.builder
.declare_var(variable, self.translate_type(&expr.ty()));
self.builder.def_var(variable, value); self.builder.def_var(variable, value);
self.variables.insert(name.clone(), variable); self.variables.insert(name.clone(), variable);
Some(value) Some(value)
@ -303,7 +306,17 @@ impl<'a> FunctionTranslator<'a> {
Expr::IntegerLiteral(imm) => self.builder.ins().iconst(types::I32, *imm), Expr::IntegerLiteral(imm) => self.builder.ins().iconst(types::I32, *imm),
Expr::FloatLiteral(imm) => self.builder.ins().f64const(*imm), Expr::FloatLiteral(imm) => self.builder.ins().f64const(*imm),
Expr::StringLiteral(_) => todo!(), Expr::StringLiteral(s) => {
let id = self.module.declare_anonymous_data(false, false).unwrap();
let bytes: Box<[u8]> = s.as_bytes().into();
self.data_desc.define(bytes);
self.module.define_data(id, self.data_desc).unwrap();
let gv = self.module.declare_data_in_func(id, self.builder.func);
self.data_desc.clear();
self.builder
.ins()
.global_value(self.module.isa().pointer_type(), gv)
}
Expr::BinaryExpression { Expr::BinaryExpression {
lhs, lhs,
@ -361,7 +374,7 @@ impl<'a> FunctionTranslator<'a> {
// so set up a parameter in the merge block, and we'll pass // so set up a parameter in the merge block, and we'll pass
// the return values to it from the branches. // the return values to it from the branches.
self.builder self.builder
.append_block_param(merge_block, typ.clone().into()); .append_block_param(merge_block, self.translate_type(typ));
// Test the if condition and conditionally branch. // Test the if condition and conditionally branch.
self.builder self.builder
@ -454,7 +467,6 @@ impl<'a> FunctionTranslator<'a> {
let args: Vec<Value> = call.args.iter().map(|a| self.translate_expr(a)).collect(); let args: Vec<Value> = call.args.iter().map(|a| self.translate_expr(a)).collect();
// TODO: handle the return value of the function
let call_inst = self.builder.ins().call(func_ref, &args); let call_inst = self.builder.ins().call(func_ref, &args);
let results = self.builder.inst_results(call_inst); let results = self.builder.inst_results(call_inst);
Some(results[0]) Some(results[0])
@ -462,4 +474,20 @@ impl<'a> FunctionTranslator<'a> {
_ => unimplemented!(), _ => unimplemented!(),
} }
} }
fn translate_type(&self, value: &crate::typing::Type) -> types::Type {
match value {
Type::Bool => types::I8,
Type::Int => types::I32,
Type::Float => types::F32,
Type::Unit => unreachable!(),
Type::Str => self.module.isa().pointer_type(),
Type::Custom(_) => todo!(),
Type::Function {
params: _,
returns: _,
} => todo!(),
Type::Undefined => unreachable!(),
}
}
} }

View file

@ -42,6 +42,13 @@ enum Commands {
#[arg(long)] #[arg(long)]
dump_clir: bool, dump_clir: bool,
}, },
Run {
/// Paths to the source files
files: Vec<String>,
#[arg(long)]
dump_clir: bool,
},
} }
fn parse(files: &Vec<String>) -> Vec<Module> { fn parse(files: &Vec<String>) -> Vec<Module> {
@ -55,7 +62,7 @@ fn parse(files: &Vec<String>) -> Vec<Module> {
} }
fn check(modules: &mut Vec<Module>) { fn check(modules: &mut Vec<Module>) {
for module in modules { while let Some(module) = modules.pop() {
if let Err(e) = module.type_check() { if let Err(e) = module.type_check() {
eprintln!("{}", e); eprintln!("{}", e);
return; return;
@ -85,12 +92,24 @@ fn main() {
} }
} }
} }
Commands::Compile { files, dump_clir } => { Commands::Compile { files, dump_clir } | Commands::Run { files, dump_clir } => {
let mut jit = jit::JIT::default(); let mut jit = jit::JIT::default();
jit.dump_clir = *dump_clir;
for file in files { for file in files {
match jit.compile(std::fs::read_to_string(file).unwrap().as_str(), *dump_clir) { match jit.compile_file(file) {
Err(e) => eprintln!("{}", e), Err(e) => eprintln!("{}", e),
Ok(_code) => println!("Compiled {}", file), Ok(code) => {
println!("Compiled {}", file);
if let Commands::Run { .. } = cli.command {
let ret = unsafe {
let code_fn: unsafe extern "sysv64" fn() -> i32 =
std::mem::transmute(code);
code_fn()
};
println!("Main returned {}", ret);
}
}
} }
} }
} }

View file

@ -4,10 +4,10 @@ use super::UnaryOperator;
#[derive(Debug)] #[derive(Debug)]
pub struct TypeError { pub struct TypeError {
file: Option<std::path::PathBuf>, pub file: Option<std::path::PathBuf>,
module: ModulePath, pub module: ModulePath,
function: Option<String>, pub function: Option<String>,
kind: TypeErrorKind, pub kind: TypeErrorKind,
} }
impl std::fmt::Display for TypeError { impl std::fmt::Display for TypeError {
@ -62,7 +62,7 @@ impl TypeErrorBuilder {
} }
} }
#[derive(Debug)] #[derive(Debug, PartialEq)]
pub enum TypeErrorKind { pub enum TypeErrorKind {
InvalidBinaryOperator { InvalidBinaryOperator {
operator: BinaryOperator, operator: BinaryOperator,

View file

@ -7,6 +7,9 @@ use crate::ast::*;
mod error; mod error;
use crate::typing::error::{TypeError, TypeErrorKind}; use crate::typing::error::{TypeError, TypeErrorKind};
#[cfg(test)]
mod tests;
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub enum Type { pub enum Type {
/// Not a real type, used for parsing pass /// Not a real type, used for parsing pass
@ -76,8 +79,10 @@ impl FunctionDefinition {
} }
} }
pub struct CheckedModule(pub Module);
impl Module { impl Module {
pub fn type_check(&mut self) -> Result<(), TypeError> { pub fn type_check(mut self) -> Result<CheckedModule, TypeError> {
let mut ctx = TypingContext::new(self.path.clone()); let mut ctx = TypingContext::new(self.path.clone());
ctx.file = self.file.clone(); ctx.file = self.file.clone();
@ -96,7 +101,7 @@ impl Module {
ctx.variables.clear(); ctx.variables.clear();
} }
Ok(()) Ok(CheckedModule(self))
} }
} }