From ba838292f66922e7ac7b7269a5334f8c3fbfdb3e Mon Sep 17 00:00:00 2001 From: Romain Paquet Date: Tue, 25 Jun 2024 11:02:05 +0200 Subject: [PATCH] refactor typing and compilation --- src/ast/expr.rs | 37 ++++++++++++ src/ast/mod.rs | 27 +++++++-- src/ast/typed/expr.rs | 27 --------- src/ast/typed/mod.rs | 37 ------------ src/jit/mod.rs | 132 +++++++++++++++++++++++++----------------- src/main.rs | 27 +++++++-- src/typing/error.rs | 10 ++-- src/typing/mod.rs | 9 ++- 8 files changed, 175 insertions(+), 131 deletions(-) delete mode 100644 src/ast/typed/expr.rs delete mode 100644 src/ast/typed/mod.rs diff --git a/src/ast/expr.rs b/src/ast/expr.rs index 5f776cb..90ffc48 100644 --- a/src/ast/expr.rs +++ b/src/ast/expr.rs @@ -1,4 +1,5 @@ use crate::ast::*; +use crate::typing::Type; #[derive(Debug, PartialEq)] pub enum Expr { @@ -32,3 +33,39 @@ pub enum Expr { FloatLiteral(f64), StringLiteral(String), } + +impl Block { + #[inline] + pub fn ty(&self) -> Type { + // XXX: Cloning may be expensive -> TypeId? 
+ self.typ.clone() + } +} + +impl Expr { + pub fn ty(&self) -> Type { + match self { + Expr::BinaryExpression { + lhs: _, + op: _, + rhs: _, + typ, + } => typ.clone(), + Expr::UnaryExpression { op: _, inner } => inner.ty(), // XXX: problems will arise here + Expr::Identifier { name: _, typ } => typ.clone(), + Expr::Call(call) => call.typ.clone(), + Expr::Block(block) => block.typ.clone(), + Expr::IfExpr { + cond: _, + then_body: _, + else_body: _, + typ, + } => typ.clone(), + Expr::UnitLiteral => Type::Unit, + Expr::BooleanLiteral(_) => Type::Bool, + Expr::IntegerLiteral(_) => Type::Int, + Expr::FloatLiteral(_) => Type::Float, + Expr::StringLiteral(_) => Type::Str, + } + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 8348a29..3355c10 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,11 +1,10 @@ -use crate::typing::Type; -use std::path::Path; - pub mod expr; -pub mod typed; pub use expr::Expr; +use crate::typing::Type; +use std::path::Path; + #[derive(Debug, PartialEq, Clone)] pub enum BinaryOperator { // Logic @@ -38,6 +37,19 @@ pub struct ModulePath { components: Vec, } +impl ModulePath { + pub fn concat(lhs: &ModulePath, rhs: &ModulePath) -> ModulePath { + ModulePath { + components: Vec::from_iter( + lhs.components + .iter() + .chain(rhs.components.iter()) + .map(Clone::clone), + ), + } + } +} + impl std::fmt::Display for ModulePath { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!("{}", self.components.join("::"))) @@ -136,6 +148,13 @@ impl Module { ..Default::default() } } + + pub fn full_func_path(&self, func: usize) -> ModulePath { + ModulePath::concat( + &self.path, + &ModulePath::from(self.functions[func].name.as_str()), + ) + } } #[derive(Debug, PartialEq)] diff --git a/src/ast/typed/expr.rs b/src/ast/typed/expr.rs deleted file mode 100644 index 736bc22..0000000 --- a/src/ast/typed/expr.rs +++ /dev/null @@ -1,27 +0,0 @@ -use crate::ast::{ - typed::{Block, Call}, - BinaryOperator, UnaryOperator, 
-}; -use crate::typing::Type; -impl Expr { - pub fn typ(&self) -> &Type { - match self { - Expr::BinaryExpression { lhs, op, rhs, typ } => &typ.unwrap(), - Expr::UnaryExpression { op, inner } => inner.typ(), // XXX: problems will arise here - Expr::Variable { name, typ } => &typ.unwrap(), - Expr::Call { call, typ } => &typ.unwrap(), - Expr::Block { block, typ } => &typ.unwrap(), - Expr::IfExpr { - cond, - then_body, - else_body, - typ, - } => &typ.unwrap(), - Expr::UnitLiteral => &Type::Unit, - Expr::BooleanLiteral(_) => &Type::Bool, - Expr::IntegerLiteral(_) => &Type::Int, - Expr::FloatLiteral(_) => &Type::Float, - Expr::StringLiteral(_) => &Type::Str, - } - } -} diff --git a/src/ast/typed/mod.rs b/src/ast/typed/mod.rs deleted file mode 100644 index 033129a..0000000 --- a/src/ast/typed/mod.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::ast::*; - -impl Block { - #[inline] - pub fn ty(&self) -> Type { - // XXX: Cloning may be expensive -> TypeId? - self.typ.clone() - } -} - -impl Expr { - pub fn ty(&self) -> Type { - match self { - Expr::BinaryExpression { - lhs: _, - op: _, - rhs: _, - typ, - } => typ.clone(), - Expr::UnaryExpression { op: _, inner } => inner.ty(), // XXX: problems will arise here - Expr::Identifier { name: _, typ } => typ.clone(), - Expr::Call(call) => call.typ.clone(), - Expr::Block(block) => block.typ.clone(), - Expr::IfExpr { - cond: _, - then_body: _, - else_body: _, - typ, - } => typ.clone(), - Expr::UnitLiteral => Type::Unit, - Expr::BooleanLiteral(_) => Type::Bool, - Expr::IntegerLiteral(_) => Type::Int, - Expr::FloatLiteral(_) => Type::Float, - Expr::StringLiteral(_) => Type::Str, - } - } -} diff --git a/src/jit/mod.rs b/src/jit/mod.rs index 0a9a9b4..eff74bc 100644 --- a/src/jit/mod.rs +++ b/src/jit/mod.rs @@ -4,12 +4,12 @@ use crate::{ {expr::Expr, FunctionDefinition, Statement}, }, parsing, - typing::Type, + typing::{CheckedModule, Type}, }; use cranelift::{codegen::ir::UserFuncName, prelude::*}; use cranelift_jit::{JITBuilder, 
JITModule}; use cranelift_module::{DataDescription, FuncId, FuncOrDataId, Linkage, Module}; -use std::{collections::HashMap, ops::Deref}; +use std::{collections::HashMap, fs, ops::Deref}; /// The basic JIT class. pub struct JIT { @@ -27,6 +27,9 @@ pub struct JIT { /// The module, with the jit backend, which manages the JIT'd functions. module: JITModule, + + /// Whether to print CLIR during compilation + pub dump_clir: bool, } impl Default for JIT { @@ -47,29 +50,30 @@ impl Default for JIT { let module = JITModule::new(builder); + let mut ctx = module.make_context(); + ctx.set_disasm(true); + Self { builder_context: FunctionBuilderContext::new(), - ctx: module.make_context(), + ctx, data_desc: DataDescription::new(), module, + dump_clir: false, } } } impl JIT { /// Compile source code into machine code. - pub fn compile(&mut self, input: &str, dump_clir: bool) -> Result<*const u8, String> { + pub fn compile(&mut self, input: &str, namespace: ModulePath) -> Result<*const u8, String> { // Parse the source code into an AST - let Ok(mut ast) = parsing::parse_as_module(input, ModulePath::from("globalmodule")) else { - return Err("Parsing error".to_string()); - }; - - if let Err(e) = ast.type_check() { - return Err(e.to_string()); - }; + let ast = parsing::parse_as_module(input, namespace) + .map_err(|x| format!("Parsing error: {x}"))? + .type_check() + .map_err(|x| format!("Typing error: {x}"))?; // Translate the AST into Cranelift IR - self.translate(&ast, dump_clir)?; + self.translate(&ast)?; // Finalize the functions which we just defined, which resolves any // outstanding relocations (patching in addresses, now that they're @@ -85,8 +89,20 @@ impl JIT { } } + pub fn compile_file(&mut self, path: &str) -> Result<*const u8, String> { + self.compile( + fs::read_to_string(path) + .map_err(|x| format!("Cannot open {}: {}", path, x))? + .as_str(), + AsRef::::as_ref(path).into(), + ) + } + /// Translate language AST into Cranelift IR. 
- fn translate(&mut self, ast: &ast::Module, dump_clir: bool) -> Result<(), String> { + fn translate(&mut self, ast: &CheckedModule) -> Result<(), String> { + // Dump contract-holding wrapper type + let ast = &ast.0; + let mut signatures: Vec = Vec::with_capacity(ast.functions.len()); let mut func_ids: Vec = Vec::with_capacity(ast.functions.len()); @@ -96,7 +112,7 @@ impl JIT { let mut sig = self.module.make_signature(); for param in &func.parameters { - assert!(param.typ != Type::Unit); + assert_ne!(param.typ, Type::Unit); sig.params.append(&mut Vec::from(&param.typ)); } @@ -124,9 +140,9 @@ impl JIT { .define_function(func_ids[i], &mut self.ctx) .unwrap(); - if dump_clir { - println!("// {}", func.name); - println!("{}", self.ctx.func.display()); + if self.dump_clir { + println!("// {}", ast.full_func_path(i)); + println!("{}", self.ctx.func); } self.module.clear_context(&mut self.ctx); @@ -155,25 +171,28 @@ impl JIT { builder.seal_block(entry_block); // Walk the AST and declare all implicitly-declared variables. - let mut variables = HashMap::::default(); // TODO: actually do this + let variables = HashMap::::default(); // TODO: actually do this - // Add a variable for each parameter. - let param_values: Box<[Value]> = builder.block_params(entry_block).into(); - assert!(param_values.len() == function.parameters.len()); - for (i, param) in function.parameters.iter().enumerate() { - let var = Variable::from_u32(variables.len() as u32); - variables.insert(param.name.clone(), var); - let value = param_values[i]; - builder.declare_var(var, param.typ.clone().into()); - builder.def_var(var, value); - } - - // Now translate the statements of the function body. let mut translator = FunctionTranslator { builder, variables, module: &mut self.module, + data_desc: &mut self.data_desc, }; + + // Add a variable for each parameter. 
+ let param_values: Box<[Value]> = translator.builder.block_params(entry_block).into(); + assert_eq!(param_values.len(), function.parameters.len()); + for (i, param) in function.parameters.iter().enumerate() { + let var = Variable::from_u32(translator.variables.len() as u32); + translator.variables.insert(param.name.clone(), var); + let value = param_values[i]; + let typ = translator.translate_type(&param.typ); + translator.builder.declare_var(var, typ); + translator.builder.def_var(var, value); + } + + // Now translate the statements of the function body. for stmt in &function.body.statements { translator.translate_statement(stmt); } @@ -192,24 +211,6 @@ impl JIT { } } -impl From for types::Type { - fn from(value: crate::typing::Type) -> Self { - match value { - Type::Bool => types::I8, - Type::Int => types::I32, - Type::Float => types::F32, - Type::Unit => unreachable!(), - Type::Str => todo!(), - Type::Custom(_) => todo!(), - Type::Function { - params: _, - returns: _, - } => todo!(), - Type::Undefined => unreachable!(), - } - } -} - impl From<&Type> for Vec { fn from(value: &Type) -> Self { match value { @@ -227,6 +228,7 @@ struct FunctionTranslator<'a> { builder: FunctionBuilder<'a>, variables: HashMap, module: &'a mut JITModule, + data_desc: &'a mut DataDescription, } impl<'a> FunctionTranslator<'a> { @@ -245,7 +247,8 @@ impl<'a> FunctionTranslator<'a> { Statement::DeclareStatement(name, expr) => { let value = self.translate_expr(expr); let variable = Variable::from_u32(self.variables.len() as u32); - self.builder.declare_var(variable, expr.ty().into()); + self.builder + .declare_var(variable, self.translate_type(&expr.ty())); self.builder.def_var(variable, value); self.variables.insert(name.clone(), variable); Some(value) @@ -303,7 +306,17 @@ impl<'a> FunctionTranslator<'a> { Expr::IntegerLiteral(imm) => self.builder.ins().iconst(types::I32, *imm), Expr::FloatLiteral(imm) => self.builder.ins().f64const(*imm), - Expr::StringLiteral(_) => todo!(), 
Expr::StringLiteral(s) => { + let id = self.module.declare_anonymous_data(false, false).unwrap(); + let bytes: Box<[u8]> = s.as_bytes().into(); + self.data_desc.define(bytes); + self.module.define_data(id, self.data_desc).unwrap(); + let gv = self.module.declare_data_in_func(id, self.builder.func); + self.data_desc.clear(); + self.builder + .ins() + .global_value(self.module.isa().pointer_type(), gv) + } Expr::BinaryExpression { lhs, @@ -361,7 +374,7 @@ impl<'a> FunctionTranslator<'a> { // so set up a parameter in the merge block, and we'll pass // the return values to it from the branches. self.builder - .append_block_param(merge_block, typ.clone().into()); + .append_block_param(merge_block, self.translate_type(typ)); // Test the if condition and conditionally branch. self.builder @@ -454,7 +467,6 @@ impl<'a> FunctionTranslator<'a> { let args: Vec = call.args.iter().map(|a| self.translate_expr(a)).collect(); - // TODO: handle the return value of the function let call_inst = self.builder.ins().call(func_ref, &args); let results = self.builder.inst_results(call_inst); Some(results[0]) @@ -462,4 +474,20 @@ impl<'a> FunctionTranslator<'a> { _ => unimplemented!(), } } + + fn translate_type(&self, value: &crate::typing::Type) -> types::Type { + match value { + Type::Bool => types::I8, + Type::Int => types::I32, + Type::Float => types::F32, + Type::Unit => unreachable!(), + Type::Str => self.module.isa().pointer_type(), + Type::Custom(_) => todo!(), + Type::Function { + params: _, + returns: _, + } => todo!(), + Type::Undefined => unreachable!(), + } + } } diff --git a/src/main.rs b/src/main.rs index ad868fc..befe80b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -42,6 +42,13 @@ enum Commands { #[arg(long)] dump_clir: bool, }, + Run { + /// Paths to the source files + files: Vec, + + #[arg(long)] + dump_clir: bool, + }, } fn parse(files: &Vec) -> Vec { @@ -55,7 +62,7 @@ fn parse(files: &Vec) -> Vec { } fn check(modules: &mut Vec) { - for module in modules { + while let 
Some(module) = modules.pop() { if let Err(e) = module.type_check() { eprintln!("{}", e); return; @@ -85,12 +92,24 @@ fn main() { } } } - Commands::Compile { files, dump_clir } => { + Commands::Compile { files, dump_clir } | Commands::Run { files, dump_clir } => { let mut jit = jit::JIT::default(); + jit.dump_clir = *dump_clir; for file in files { - match jit.compile(std::fs::read_to_string(file).unwrap().as_str(), *dump_clir) { + match jit.compile_file(file) { Err(e) => eprintln!("{}", e), - Ok(_code) => println!("Compiled {}", file), + Ok(code) => { + println!("Compiled {}", file); + + if let Commands::Run { .. } = cli.command { + let ret = unsafe { + let code_fn: unsafe extern "sysv64" fn() -> i32 = + std::mem::transmute(code); + code_fn() + }; + println!("Main returned {}", ret); + } + } } } } diff --git a/src/typing/error.rs b/src/typing/error.rs index 13026c3..7eea1a5 100644 --- a/src/typing/error.rs +++ b/src/typing/error.rs @@ -4,10 +4,10 @@ use super::UnaryOperator; #[derive(Debug)] pub struct TypeError { - file: Option, - module: ModulePath, - function: Option, - kind: TypeErrorKind, + pub file: Option, + pub module: ModulePath, + pub function: Option, + pub kind: TypeErrorKind, } impl std::fmt::Display for TypeError { @@ -62,7 +62,7 @@ impl TypeErrorBuilder { } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum TypeErrorKind { InvalidBinaryOperator { operator: BinaryOperator, diff --git a/src/typing/mod.rs b/src/typing/mod.rs index 15785a0..14acbf1 100644 --- a/src/typing/mod.rs +++ b/src/typing/mod.rs @@ -7,6 +7,9 @@ use crate::ast::*; mod error; use crate::typing::error::{TypeError, TypeErrorKind}; +#[cfg(test)] +mod tests; + #[derive(Debug, PartialEq, Clone)] pub enum Type { /// Not a real type, used for parsing pass @@ -76,8 +79,10 @@ impl FunctionDefinition { } } +pub struct CheckedModule(pub Module); + impl Module { - pub fn type_check(&mut self) -> Result<(), TypeError> { + pub fn type_check(mut self) -> Result { let mut ctx = 
TypingContext::new(self.path.clone()); ctx.file = self.file.clone(); @@ -96,7 +101,7 @@ impl Module { ctx.variables.clear(); } - Ok(()) + Ok(CheckedModule(self)) } }