restructure parsing and typing modules

* parsing backend submodules
* move typing to its own module
This commit is contained in:
Romain Paquet 2023-07-05 16:14:30 +02:00
parent 43df8c4b0a
commit 99434748fa
16 changed files with 1315 additions and 316 deletions

View file

@ -1,6 +1,6 @@
[package]
name = "kronec"
version = "0.1.0"
version = "0.0.1"
edition = "2021"
[dependencies]

View file

@ -4,10 +4,15 @@ use crate::ast::*;
pub enum Expr {
BinaryExpression(Box<Expr>, BinaryOperator, Box<Expr>),
Identifier(Identifier),
Call(Box<Call>),
// Literals
BooleanLiteral(bool),
IntegerLiteral(i64),
FloatLiteral(f64),
StringLiteral(String),
Call(Box<Call>),
Block(Box<Block>),
/// Last field is either Expr::Block or Expr::IfExpr
IfExpr(Box<Expr>, Box<Block>, Box<Expr>),
}
#[derive(Debug, PartialEq, Clone)]
@ -16,4 +21,7 @@ pub enum BinaryOperator {
Sub,
Mul,
Div,
Modulo,
Equal,
NotEqual,
}

View file

@ -1,17 +1,29 @@
pub mod expr;
pub mod typ;
pub mod module;
use std::path::Path;
pub use crate::ast::expr::{BinaryOperator, Expr};
pub use crate::ast::typ::*;
use crate::ast::module::*;
use crate::typing::Type;
pub type Identifier = String;
// XXX: Is this enum actually useful? Is 3:30 AM btw
#[derive(Debug, PartialEq)]
pub enum Ast {
Module(Module),
}
#[derive(Debug, PartialEq)]
pub enum Definition {
FunctionDefinition(FunctionDefinition),
Expr(Expr),
Module(Vec<Ast>),
Block(Block),
Statement(Statement),
//StructDefinition(StructDefinition),
}
#[derive(Debug, PartialEq)]
pub struct Location {
pub file: Box<Path>,
}
#[derive(Debug, PartialEq)]
@ -20,6 +32,7 @@ pub struct FunctionDefinition {
pub parameters: Vec<Parameter>,
pub return_type: Option<Type>,
pub body: Box<Block>,
pub line_col: (usize, usize),
}
#[derive(Debug, PartialEq)]
@ -30,9 +43,13 @@ pub struct Block {
#[derive(Debug, PartialEq)]
pub enum Statement {
DeclareStatement(Identifier, Expr),
AssignStatement(Identifier, Expr),
ReturnStatement(Option<Expr>),
CallStatement(Call),
UseStatement(ModulePath),
IfStatement(Expr, Block),
WhileStatement(Box<Expr>, Box<Block>),
}
#[derive(Debug, PartialEq)]
@ -41,31 +58,9 @@ pub struct Call {
pub args: Vec<Expr>,
}
pub type Identifier = String;
#[derive(Debug, PartialEq)]
pub struct Parameter {
pub name: Identifier,
pub typ: Type,
}
impl Ast {
/// Type checks the AST and add missing return types.
pub fn check_return_types(&mut self) -> Result<(), TypeError> {
match self {
Ast::Module(defs) => {
for def in defs {
if let Ast::FunctionDefinition { .. } = def {
def.check_return_types()?;
}
}
}
Ast::FunctionDefinition(func) => {
let typ = func.typ(&mut TypeContext::default())?;
func.return_type = Some(typ.clone());
}
_ => unreachable!(),
}
Ok(())
}
}

66
src/ast/module.rs Normal file
View file

@ -0,0 +1,66 @@
use std::path::Path;
use super::Definition;
#[derive(Debug, PartialEq, Clone)]
pub struct ModulePath {
components: Vec<String>,
}
impl std::fmt::Display for ModulePath {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{}", self.components.join("::")))
}
}
impl From<&Path> for ModulePath {
fn from(path: &Path) -> Self {
let meta = std::fs::metadata(path).unwrap();
ModulePath {
components: path
.components()
.map(|component| match component {
std::path::Component::Normal(n) => {
if meta.is_file() {
n.to_str().unwrap().split(".").nth(0).unwrap().to_string()
} else if meta.is_dir() {
n.to_str().unwrap().to_string()
} else {
// XXX: symlinks?
unreachable!()
}
}
_ => unreachable!(),
})
.collect(),
}
}
}
impl From<&str> for ModulePath {
fn from(string: &str) -> Self {
ModulePath {
components: string.split("::").map(|c| c.to_string()).collect(),
}
}
}
type ImportPath = ModulePath;
#[derive(Debug, PartialEq)]
pub struct Module {
pub file: Option<std::path::PathBuf>,
pub path: ModulePath,
pub definitions: Vec<Definition>,
pub imports: Vec<ImportPath>,
}
impl Module {
pub fn new(path: ModulePath) -> Self {
Module {
path,
file: None,
definitions: vec![],
imports: vec![],
}
}
}

View file

@ -1,158 +0,0 @@
use std::collections::HashMap;
use crate::ast::*;
#[derive(Debug, PartialEq, Clone)]
pub enum Type {
Int,
Float,
Unit,
Str,
Custom(Identifier),
}
impl From<&str> for Type {
fn from(value: &str) -> Self {
match value {
"int" => Type::Int,
"float" => Type::Float,
_ => Type::Custom(Identifier::from(value)),
}
}
}
#[derive(Debug)]
pub enum TypeError {
InvalidBinaryOperator {
operator: BinaryOperator,
lht: Type,
rht: Type,
},
BlockTypeDoesNotMatchFunctionType {
function_name: String,
function_type: Type,
block_type: Type,
},
ReturnTypeDoesNotMatchFunctionType {
function_name: String,
function_type: Type,
ret_type: Type,
},
UnknownIdentifier {
identifier: String,
},
}
#[derive(Default)]
pub struct TypeContext {
pub function: Option<Identifier>,
pub variables: HashMap<Identifier, Type>,
}
/// Trait for nodes which have a deducible type.
pub trait Typ {
/// Try to resolve the type of the node.
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError>;
}
impl Typ for FunctionDefinition {
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError> {
let func = self;
let mut ctx = TypeContext {
function: Some(func.name.clone()),
..Default::default()
};
for param in &func.parameters {
ctx.variables.insert(param.name.clone(), param.typ.clone());
}
let body_type = &func.body.typ(&mut ctx)?;
// If the return type is not specified, it is unit.
let func_return_type = match &func.return_type {
Some(typ) => typ,
None => &Type::Unit,
};
// Check coherence with the body's type.
if *func_return_type != *body_type {
return Err(TypeError::BlockTypeDoesNotMatchFunctionType {
function_name: func.name.clone(),
function_type: func_return_type.clone(),
block_type: body_type.clone(),
})
}
// Check coherence with return statements.
for statement in &func.body.statements {
if let Statement::ReturnStatement(value) = statement {
let ret_type = match value {
Some(expr) => expr.typ(&mut ctx)?,
None => Type::Unit,
};
if ret_type != *func_return_type {
return Err(TypeError::ReturnTypeDoesNotMatchFunctionType {
function_name: func.name.clone(),
function_type: func_return_type.clone(),
ret_type,
})
}
}
}
Ok(func_return_type.clone())
}
}
impl Typ for Block {
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError> {
// Check if there is an expression at the end of the block.
if let Some(expr) = &self.value {
expr.typ(ctx)
} else {
Ok(Type::Unit)
}
}
}
impl Typ for Expr {
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError> {
match self {
Expr::Identifier(identifier) => {
if let Some(typ) = ctx.variables.get(identifier) {
Ok(typ.clone())
} else {
Err(TypeError::UnknownIdentifier {
identifier: identifier.clone(),
})
}
}
Expr::IntegerLiteral(_) => Ok(Type::Int),
Expr::FloatLiteral(_) => Ok(Type::Float),
Expr::BinaryExpression(lhs, op, rhs) => match op {
BinaryOperator::Add
| BinaryOperator::Sub
| BinaryOperator::Mul
| BinaryOperator::Div => {
let left_type = &lhs.typ(ctx)?;
let right_type = &rhs.typ(ctx)?;
match (left_type, right_type) {
(Type::Int, Type::Int) => Ok(Type::Int),
(Type::Float, Type::Int | Type::Float) => Ok(Type::Float),
(Type::Int, Type::Float) => Ok(Type::Float),
(_, _) => Err(TypeError::InvalidBinaryOperator {
operator: op.clone(),
lht: left_type.clone(),
rht: right_type.clone(),
}),
}
}
},
Expr::StringLiteral(_) => Ok(Type::Str),
Expr::Call(call) => {
todo!("resolve call type using ctx");
}
}
}
}

View file

@ -1,8 +1,10 @@
mod ast;
mod parsing;
mod typing;
use clap::{Parser, Subcommand};
use std::fs;
use crate::ast::module::Module;
/// Experimental compiler for krone
#[derive(Parser, Debug)]
@ -16,8 +18,8 @@ struct Cli {
#[derive(Subcommand, Debug)]
enum Commands {
Parse {
/// Path to the source file
file: String,
/// Path to the source files
files: Vec<String>,
/// Dump the AST to stdout
#[arg(long)]
@ -25,7 +27,7 @@ enum Commands {
/// Add missing return types in the AST
#[arg(long)]
complete_ast: bool,
type_check: bool,
},
}
@ -34,25 +36,31 @@ fn main() {
match &cli.command {
Commands::Parse {
file,
files,
dump_ast,
complete_ast,
type_check,
} => {
let source = fs::read_to_string(&file).expect("could not read the source file");
let mut ast = match parsing::parse(&source) {
Ok(ast) => ast,
let paths = files.iter().map(std::path::Path::new);
let modules: Vec<Module> = paths
.map(|path| match parsing::parse_file(&path) {
Ok(module) => module,
Err(e) => panic!("Parsing error: {:#?}", e),
};
})
.collect();
if *complete_ast {
if let Err(e) = ast.check_return_types() {
eprintln!("{:#?}", e);
if *type_check {
for module in &modules {
if let Err(e) = module.type_check() {
eprintln!("{}", e);
return;
}
}
}
if *dump_ast {
println!("{:#?}", &ast);
for module in &modules {
println!("{:#?}", &module);
}
return;
}

View file

@ -0,0 +1,157 @@
use std::iter::Peekable;
use std::str::Chars;
#[derive(Debug)]
pub enum Token {
LeftBracket,
RightBracket,
If,
Else,
Identifier(String),
LeftParenthesis,
RightParenthesis,
Func,
Colon,
While,
Set,
LineComment,
Mul,
Sub,
Add,
Slash,
Modulo,
NotEqual,
Equal,
DoubleEquals,
Exclamation,
NumberLiteral,
}
#[derive(Debug)]
pub enum TokenError {
InvalidToken,
}
pub struct Lexer {
line: usize,
column: usize,
}
impl Lexer {
pub fn new() -> Self {
Self { line: 1, column: 1 }
}
pub fn tokenize(&mut self, input: String) -> Result<Vec<Token>, TokenError> {
let mut tokens: Vec<Token> = Vec::new();
let mut chars = input.chars().peekable();
while let Some(tok_or_err) = self.get_next_token(&mut chars) {
match tok_or_err {
Ok(token) => tokens.push(token),
Err(err) => return Err(err),
};
}
Ok(tokens)
}
fn get_next_token(&mut self, chars: &mut Peekable<Chars>) -> Option<Result<Token, TokenError>> {
if let Some(ch) = chars.next() {
let tok_or_err = match ch {
'(' => Ok(Token::LeftParenthesis),
')' => Ok(Token::RightParenthesis),
'{' => Ok(Token::LeftBracket),
'}' => Ok(Token::RightBracket),
'+' => Ok(Token::Add),
'-' => Ok(Token::Sub),
'*' => Ok(Token::Mul),
'%' => Ok(Token::Modulo),
'/' => {
if let Some('/') = chars.peek() {
chars.next();
let comment = chars.take_while(|c| c != &'\n');
self.column += comment.count() + 1;
Ok(Token::LineComment)
} else {
Ok(Token::Slash)
}
}
'=' => {
if let Some(ch2) = chars.peek() {
match ch2 {
'=' => {
chars.next();
self.column += 1;
Ok(Token::DoubleEquals)
}
' ' => Ok(Token::Equal),
_ => Err(TokenError::InvalidToken),
}
} else {
Ok(Token::Equal)
}
}
'!' => {
if let Some(ch2) = chars.next() {
match ch2 {
'=' => {
self.column += 1;
Ok(Token::NotEqual)
}
_ => Err(TokenError::InvalidToken),
}
} else {
Ok(Token::Exclamation)
}
}
'a'..='z' | 'A'..='Z' => {
let mut word = String::from(ch);
while let Some(ch2) = chars.peek() {
if ch2.is_alphanumeric() {
if let Some(ch2) = chars.next() {
word.push(ch2);
}
} else {
break;
}
}
self.column += word.len();
match word.as_str() {
"func" => Ok(Token::Func),
"if" => Ok(Token::If),
"else" => Ok(Token::Else),
"set" => Ok(Token::Set),
"while" => Ok(Token::While),
_ => Ok(Token::Identifier(word)),
}
}
'0'..='9' | '.' => {
let word = chars
.take_while(|c| c.is_numeric() || c == &'.')
.collect::<String>();
self.column += word.len();
// XXX: handle syntax error in number literals
Ok(Token::NumberLiteral)
}
':' => Ok(Token::Colon),
'\n' => {
self.line += 1;
self.column = 1;
return self.get_next_token(chars);
}
' ' => {
self.column += 1;
return self.get_next_token(chars);
}
'\t' => {
self.column += 8;
return self.get_next_token(chars);
}
_ => Err(TokenError::InvalidToken),
};
self.column += 1;
Some(tok_or_err)
} else {
None
}
}
}

View file

@ -0,0 +1,149 @@
// In progress parser from scratch
use crate::lex::Token;
use std::cell::RefCell;
use std::rc::Rc;
#[derive(Debug)]
pub enum NodeType {
Document, // This is the root node
LineComment,
FunctionDefinition,
FunctionParam,
VariableName(String),
Type(String),
}
use NodeType::*;
pub struct Node {
kind: NodeType,
parent: Option<Rc<RefCell<Node>>>,
children: Vec<Box<Node>>,
}
impl Node {
fn new() -> Self {
Node::default()
}
fn with_kind(&mut self, kind: NodeType) -> &mut Self {
self.kind = kind;
self
}
fn with_children(&mut self, children: Vec<Node>) -> &mut Self {
for child in children {
self.push_child(child);
}
self
}
fn push_child(&mut self, mut child: Node) {
child.parent = Some(Rc::new(RefCell::new(*self)));
self.children.push(Box::new(child));
}
pub fn print_tree(&self) {
self.print_tree_rec(0);
}
fn print_tree_rec(&self, indent: u8) {
for _ in 1..=indent {
print!(" ");
}
println!("{:?}", self.kind);
for child in &self.children {
child.print_tree_rec(indent + 2);
}
}
}
impl Default for Node {
fn default() -> Self {
Node {
kind: Document,
parent: None,
children: Vec::new(),
}
}
}
impl From<NodeType> for Node {
fn from(value: NodeType) -> Self {
Node {
kind: value,
..Node::default()
}
}
}
#[derive(Debug)]
pub enum SyntaxError {
FuncExpectedIdentifier,
FuncExpectedLeftParenthesisAfterIdentifier,
UnexpectedToken,
}
pub struct Parser {}
impl Parser {
pub fn new() -> Self {
Parser {}
}
pub fn parse_tokens(&mut self, tokens: Vec<Token>) -> Result<Node, SyntaxError> {
let mut tokens = tokens.iter().peekable();
let mut root_node = Node::new();
while let Some(token) = tokens.next() {
let node_or_err = match token {
Token::LineComment => Ok(Node {
kind: LineComment,
..Node::default()
}),
Token::Func => {
let identifier = if let Some(ident) = tokens.next() {
match ident {
Token::Identifier(id) => Some(id),
_ => return Err(SyntaxError::FuncExpectedIdentifier),
}
} else {
None
};
if let Some(Token::LeftParenthesis) = tokens.next() {
} else {
return Err(SyntaxError::FuncExpectedLeftParenthesisAfterIdentifier);
};
let mut params: Vec<Node> = Vec::new();
while let Some(Token::Identifier(_)) = tokens.peek() {
if let Some(Token::Identifier(param_name)) = tokens.next() {
if let Some(Token::Colon) = tokens.next() {
if let Some(Token::Identifier(type_name)) = tokens.next() {
let mut node =
Node::new().with_kind(FunctionParam).with_children(vec![
VariableName(param_name.into()).into(),
Type(type_name.into()).into(),
]);
params.push(*node);
}
}
}
}
let node = Node::from(NodeType::FunctionDefinition).with_children(params);
Ok(*node)
}
_ => Err(SyntaxError::UnexpectedToken),
};
if let Ok(node) = node_or_err {
root_node.push_child(node);
} else {
};
}
Ok(root_node)
}
}

View file

@ -0,0 +1 @@
pub mod pest;

View file

@ -0,0 +1,70 @@
// This file is just a little test of pest.rs
source_file = { SOI ~ module_items ~ EOI }
module_items = { (use_statement | definition)* }
// Statements
statement = { assign_statement | declare_statement | return_statement | call_statement | use_statement | while_statement | if_statement }
declare_statement = { ident ~ "=" ~ expr ~ ";" }
assign_statement = { "set" ~ ident ~ "=" ~ expr ~ ";" }
return_statement = { "return" ~ expr? ~ ";" }
call_statement = { call ~ ";" }
use_statement = { "use" ~ import_path ~ ";" }
while_statement = { "while" ~ expr ~ block ~ ";" }
if_statement = { if_branch ~ ("else" ~ (if_branch | block))? ~ ";" }
if_branch = _{ "if" ~ expr ~ block }
// Module paths
import_path = { ident ~ ("::" ~ ident)* }
// Function call
call = { ident ~ "(" ~ args ~ ")" }
args = { (expr ~ ",")* ~ expr? }
definition = { func_def }
// Function definition
func_def = { "fn" ~ ident ~ "(" ~ parameters ~ ")" ~ typ? ~ block }
parameters = {
(parameter ~ ",")* ~ (parameter)?
}
parameter = { ident ~ ":" ~ typ }
// Operators
infix = _{ add | subtract | multiply | divide | not_equal | equal | modulo }
add = { "+" }
subtract = { "-" }
multiply = { "*" }
divide = { "/" }
modulo = { "%" }
equal = { "==" }
not_equal = { "!=" }
prefix = _{ not }
not = { "!" }
// Expressions
expr = { prefix? ~ atom ~ (infix ~ prefix? ~ atom)* }
atom = _{ call | if_expr | block | literal | ident | "(" ~ expr ~ ")" }
block = { "{" ~ statement* ~ expr? ~ "}" }
if_expr = { "if" ~ expr ~ block ~ "else" ~ (block | if_expr) }
ident = @{ (ASCII_ALPHANUMERIC | "_")+ }
typ = _{ ident }
// Literals
literal = _{ boolean_literal | float_literal | integer_literal | string_literal }
boolean_literal = @{ "true" | "false" }
string_literal = ${ "\"" ~ string_content ~ "\"" }
string_content = @{ char* }
char = {
!("\"" | "\\") ~ ANY
| "\\" ~ ("\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t")
| "\\" ~ ("u" ~ ASCII_HEX_DIGIT{4})
}
integer_literal = @{ ASCII_DIGIT+ }
float_literal = @{ ("0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT*) ~ "." ~ ASCII_DIGIT* }
WHITESPACE = _{ " " | "\n" | "\t" }
COMMENT = _{ "//" ~ (!NEWLINE ~ ANY)* }

View file

@ -1,15 +1,20 @@
use lazy_static;
use std::fs;
use std::path::Path;
use pest::error::Error;
use pest::iterators::Pair;
use pest::pratt_parser::PrattParser;
use pest::Parser;
use crate::ast::module::{Module, ModulePath};
use crate::ast::*;
use crate::typing::Type;
#[derive(pest_derive::Parser)]
#[grammar = "parsing/grammar.pest"]
#[grammar = "parsing/backend/pest/grammar.pest"]
struct KrParser;
use lazy_static;
lazy_static::lazy_static! {
static ref PRATT_PARSER: PrattParser<Rule> = {
use pest::pratt_parser::{Assoc::*, Op};
@ -17,36 +22,51 @@ lazy_static::lazy_static! {
// Precedence is defined lowest to highest
PrattParser::new()
// Addition and subtract have equal precedence
.op(Op::infix(equal, Left) | Op::infix(not_equal, Left))
.op(Op::infix(add, Left) | Op::infix(subtract, Left))
.op(Op::infix(modulo, Left))
.op(Op::infix(multiply, Left) | Op::infix(divide, Left))
};
}
pub fn parse(source: &str) -> Result<Ast, Error<Rule>> {
let mut definitions: Vec<Ast> = vec![];
pub fn parse_file(path: &Path) -> Result<Module, Error<Rule>> {
let source = fs::read_to_string(&path).expect("could not read source file");
let module_path = ModulePath::from(path);
let mut module = parse_as_module(&source, module_path)?;
module.file = Some(path.to_owned());
Ok(module)
}
pub fn parse_as_module(source: &str, path: ModulePath) -> Result<Module, Error<Rule>> {
let mut pairs = KrParser::parse(Rule::source_file, &source)?;
assert!(pairs.len() == 1);
let module = parse_module(pairs.next().unwrap().into_inner().next().unwrap(), path);
Ok(module)
}
pub fn parse_module(pair: Pair<Rule>, path: ModulePath) -> Module {
assert!(pair.as_rule() == Rule::module_items);
let mut module = Module::new(path);
let pairs = KrParser::parse(Rule::source_file, source)?;
for pair in pairs {
match pair.as_rule() {
Rule::source_file => {
let pairs = pair.into_inner();
for pair in pairs {
match pair.as_rule() {
Rule::definition => {
let definition = parse_definition(pair.into_inner().next().unwrap());
definitions.push(definition);
let def = parse_definition(pair.into_inner().next().unwrap());
module.definitions.push(def);
}
Rule::use_statement => {
let path = parse_import_path(pair.into_inner().next().unwrap());
module.imports.push(path);
}
Rule::EOI => {}
_ => panic!("unexpected rule in source_file: {:?}", pair.as_rule()),
}
}
}
_ => eprintln!("unexpected top-level rule {:?}", pair.as_rule()),
}
}
Ok(Ast::Module(definitions))
module
}
fn parse_block(pair: Pair<Rule>) -> Block {
@ -73,6 +93,12 @@ fn parse_statement(pair: Pair<Rule>) -> Statement {
let expr = parse_expression(pairs.next().unwrap());
Statement::AssignStatement(identifier, expr)
}
Rule::declare_statement => {
let mut pairs = pair.into_inner();
let identifier = pairs.next().unwrap().as_str().to_string();
let expr = parse_expression(pairs.next().unwrap());
Statement::DeclareStatement(identifier, expr)
}
Rule::return_statement => {
let expr = if let Some(pair) = pair.into_inner().next() {
Some(parse_expression(pair))
@ -85,10 +111,32 @@ fn parse_statement(pair: Pair<Rule>) -> Statement {
let call = parse_call(pair.into_inner().next().unwrap());
Statement::CallStatement(call)
}
Rule::use_statement => {
let path = parse_import_path(pair.into_inner().next().unwrap());
Statement::UseStatement(path)
}
Rule::if_statement => {
let mut pairs = pair.into_inner();
let condition = parse_expression(pairs.next().unwrap());
let block = parse_block(pairs.next().unwrap());
Statement::IfStatement(condition, block)
}
Rule::while_statement => {
let mut pairs = pair.into_inner();
let condition = parse_expression(pairs.next().unwrap());
let block = parse_block(pairs.next().unwrap());
Statement::WhileStatement(Box::new(condition), Box::new(block))
}
_ => unreachable!("unexpected rule '{:?}' in parse_statement", pair.as_rule()),
}
}
type ImportPath = ModulePath;
fn parse_import_path(pair: Pair<Rule>) -> ImportPath {
ModulePath::from(pair.as_str())
}
fn parse_call(pair: Pair<Rule>) -> Call {
let mut pairs = pair.into_inner();
// TODO: support calls on more than identifiers (needs grammar change)
@ -117,9 +165,26 @@ fn parse_expression(pair: Pair<Rule>) -> Expr {
.parse()
.unwrap(),
),
Rule::ident => Expr::Identifier(primary.as_str().to_string()),
Rule::expr => parse_expression(primary),
Rule::ident => Expr::Identifier(primary.as_str().to_string()),
Rule::call => Expr::Call(Box::new(parse_call(primary))),
Rule::block => Expr::Block(Box::new(parse_block(primary))),
Rule::if_expr => {
let mut pairs = primary.into_inner();
let condition = parse_expression(pairs.next().unwrap());
let true_block = parse_block(pairs.next().unwrap());
let else_value = parse_expression(pairs.next().unwrap());
Expr::IfExpr(
Box::new(condition),
Box::new(true_block),
Box::new(else_value),
)
}
Rule::boolean_literal => Expr::BooleanLiteral(match primary.as_str() {
"true" => true,
"false" => false,
_ => unreachable!(),
}),
_ => unreachable!(
"Unexpected rule '{:?}' in primary expression",
primary.as_rule()
@ -131,6 +196,9 @@ fn parse_expression(pair: Pair<Rule>) -> Expr {
Rule::subtract => BinaryOperator::Sub,
Rule::multiply => BinaryOperator::Mul,
Rule::divide => BinaryOperator::Div,
Rule::modulo => BinaryOperator::Modulo,
Rule::equal => BinaryOperator::Equal,
Rule::not_equal => BinaryOperator::NotEqual,
_ => unreachable!(),
};
Expr::BinaryExpression(Box::new(lhs), operator, Box::new(rhs))
@ -141,14 +209,15 @@ fn parse_expression(pair: Pair<Rule>) -> Expr {
fn parse_parameter(pair: Pair<Rule>) -> Parameter {
assert!(pair.as_rule() == Rule::parameter);
let mut pair = pair.into_inner();
let name: String = pair.next().unwrap().as_str().to_string();
let name = pair.next().unwrap().as_str().to_string();
let typ = Type::from(pair.next().unwrap().as_str());
Parameter { name, typ }
}
fn parse_definition(pair: Pair<Rule>) -> Ast {
fn parse_definition(pair: Pair<Rule>) -> Definition {
match pair.as_rule() {
Rule::func_def => {
let line_col = pair.line_col();
let mut pairs = pair.into_inner();
let name = pairs.next().unwrap().as_str().to_string();
let parameters: Vec<Parameter> = pairs
@ -169,11 +238,12 @@ fn parse_definition(pair: Pair<Rule>) -> Ast {
};
let body = parse_block(pair);
let body = Box::new(body);
Ast::FunctionDefinition(FunctionDefinition {
Definition::FunctionDefinition(FunctionDefinition {
name,
parameters,
return_type,
body,
line_col,
})
}
_ => panic!("unexpected node for definition: {:?}", pair.as_rule()),

View file

@ -0,0 +1,222 @@
use tree_sitter::{self, Language, Parser, TreeCursor};
enum Ast {
FuncDef(FuncDef),
Expr(Expr),
Module(Vec<Ast>),
Block(Vec<Statement>, Option<Expr>),
Statement(Statement),
}
enum BinaryOperator {
Add,
Sub,
Mul,
Div,
}
enum Expr {
BinaryExpression(Box<Expr>, BinaryOperator, Box<Expr>),
}
enum Statement {
AssignStatement(Identifier, Expr),
}
type Identifier = String;
type Type = String;
struct Parameter {
name: Identifier,
typ: Type,
}
struct FuncDef {
name: Identifier,
parameters: Vec<Parameter>,
return_type: Option<Type>,
body: Box<Ast>,
}
#[derive(Debug)]
struct AstError {
message: String,
}
impl AstError {
fn new(message: &str) -> Self {
AstError {
message: message.into(),
}
}
}
extern "C" {
fn tree_sitter_krone() -> Language;
}
struct TreeCursorChildrenIter<'a, A: AsRef<[u8]>> {
source: A,
cursor: &'a mut TreeCursor<'a>,
on_child: bool,
}
impl<'a, A: AsRef<[u8]>> Iterator for TreeCursorChildrenIter<'a, A> {
type Item = Result<Ast, AstError>;
fn next(&mut self) -> Option<Self::Item> {
if self.on_child {
if self.cursor.goto_next_sibling() {
Some(parse_from_cursor(&self.source, self.cursor))
} else {
self.cursor.goto_parent();
None
}
} else {
if self.cursor.goto_first_child() {
self.on_child = true;
Some(parse_from_cursor(&self.source, self.cursor))
} else {
None
}
}
}
}
fn iter_children<'a, A: AsRef<[u8]>>(
source: A,
cursor: &'a mut TreeCursor<'a>,
) -> TreeCursorChildrenIter<'a, A> {
TreeCursorChildrenIter {
source,
cursor,
on_child: false,
}
}
fn parse_from_cursor<'a>(
source: impl AsRef<[u8]>,
cursor: &'a mut TreeCursor<'a>,
) -> Result<Ast, AstError> {
match cursor.node().kind() {
"block" => {
let mut statements = Vec::new();
let mut value = None;
for child in iter_children(source, cursor) {
match child.unwrap() {
Ast::Statement(statement) => {
if value.is_none() {
statements.push(statement);
} else {
return Err(AstError::new(
"cannot have a statement after an expression in a block",
));
// perhaps there is a missing semicolon ;
}
}
Ast::Expr(expr) => value = Some(expr),
_ => return Err(AstError::new("invalid node type")),
};
}
let block = Ast::Block(statements, value);
Ok(block)
}
"function_definition" => {
// 1: name
assert!(cursor.goto_first_child());
assert!(cursor.field_name() == Some("name"));
let name: String = cursor
.node()
.utf8_text(source.as_ref())
.expect("utf8 error")
.into();
// 2: parameters
assert!(cursor.goto_next_sibling());
assert!(cursor.field_name() == Some("parameters"));
let mut parameters = Vec::new();
if cursor.goto_first_child() {
loop {
let param = cursor.node();
assert!(cursor.goto_first_child());
let name = cursor
.node()
.utf8_text(source.as_ref())
.expect("utf8 error")
.into();
assert!(cursor.goto_next_sibling());
let typ = cursor
.node()
.utf8_text(source.as_ref())
.expect("utf8 error")
.into();
cursor.goto_parent();
parameters.push(Parameter { name, typ });
if !cursor.goto_next_sibling() {
break;
}
}
cursor.goto_parent();
}
// 3: return type
assert!(cursor.goto_next_sibling());
assert!(cursor.field_name() == Some("return_type"));
let return_type = Some(
cursor
.node()
.utf8_text(source.as_ref())
.expect("utf8 error")
.into(),
);
// 4: body
assert!(cursor.goto_next_sibling());
assert!(cursor.field_name() == Some("body"));
let body = parse_from_cursor(source, cursor).unwrap();
let body = Box::new(body);
Ok(Ast::FuncDef(FuncDef {
name,
parameters,
return_type,
body,
}))
}
_ => panic!("unexpected node kind: {}", cursor.node().kind()),
}
}
fn parse_with_tree_sitter(source: impl AsRef<[u8]>) -> Result<Ast, AstError> {
let mut parser = Parser::new();
let language = unsafe { tree_sitter_krone() };
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = tree.walk();
let node = cursor.node();
assert!(node.kind() == "source_file");
let mut top_level_nodes = Vec::new();
for node in iter_children(source, &mut cursor) {
let node = node.unwrap();
match node {
Ast::FuncDef(_) => top_level_nodes.push(node),
_ => panic!("unexpected top-level node type"),
};
}
Ok(Ast::Module(top_level_nodes))
}

View file

@ -1,52 +0,0 @@
// This file is just a little test of pest.rs
source_file = { SOI ~ definition* ~ EOI }
statement = { assign_statement | return_statement | call_statement }
assign_statement = { "set" ~ ident ~ "=" ~ expr ~ ";" }
return_statement = { "return" ~ expr? ~ ";" }
call_statement = { call ~ ";" }
// Function calls
call = { ident ~ "(" ~ args ~ ")" }
args = { (expr ~ ",")* ~ expr? }
definition = { func_def }
func_def = { "fn" ~ ident ~ "(" ~ parameters ~ ")" ~ typ? ~ block }
parameters = {
(parameter ~ ",")* ~ (parameter)?
}
parameter = { ident ~ ":" ~ typ }
block = { "{" ~ statement* ~ expr? ~ "}" }
// Operators
infix = _{ add | subtract | multiply | divide }
add = { "+" }
subtract = { "-" }
multiply = { "*" }
divide = { "/" }
prefix = _{ not }
not = { "!" }
expr = { prefix? ~ atom ~ (infix ~ prefix? ~ atom)* }
atom = _{ call | ident | literal | "(" ~ expr ~ ")" }
ident = @{ (ASCII_ALPHA | "_")+ }
typ = _{ ident }
// Literals
literal = _{ float_literal | integer_literal | string_literal }
string_literal = ${ "\"" ~ string_content ~ "\"" }
string_content = @{ char* }
char = {
!("\"" | "\\") ~ ANY
| "\\" ~ ("\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t")
| "\\" ~ ("u" ~ ASCII_HEX_DIGIT{4})
}
integer_literal = @{ ASCII_DIGIT+ }
float_literal = @{ ("0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT*) ~ "." ~ ASCII_DIGIT* }
WHITESPACE = _{ " " | "\n" | "\t" }

View file

@ -1,38 +1,5 @@
pub mod pest;
mod backend;
pub use self::pest::parse;
pub use self::backend::pest::{parse_file, parse_module};
mod tests {
#[test]
fn test_addition_function() {
use crate::ast::*;
use crate::parsing::pest::parse;
let source = "fn add(a: int, b: int) int { a + b }";
let ast = Ast::FunctionDefinition(FunctionDefinition {
name: Identifier::from("add"),
parameters: vec![
Parameter {
name: Identifier::from("a"),
typ: Type::Int,
},
Parameter {
name: Identifier::from("b"),
typ: Type::Int,
},
],
return_type: Some(Type::Int),
body: Box::new(Block {
statements: vec![],
value: Some(Expr::BinaryExpression(
Box::new(Expr::Identifier(Identifier::from("a"))),
BinaryOperator::Add,
Box::new(Expr::Identifier(Identifier::from("b"))),
)),
}),
});
assert_eq!(parse(source).unwrap(), Ast::Module(vec![ast]));
}
}
mod tests;

44
src/parsing/tests.rs Normal file
View file

@ -0,0 +1,44 @@
#[test]
fn test_addition_function() {
use crate::parsing::backend::pest::parse_as_module;
use crate::{
ast::module::{Module, ModulePath},
ast::*,
typing::Type,
};
let source = "fn add(a: int, b: int) int { a + b }";
let path = ModulePath::from("test");
let module = parse_as_module(&source, path.clone()).expect("parsing error");
let expected_module = Module {
file: None,
imports: vec![],
definitions: vec![Definition::FunctionDefinition(FunctionDefinition {
name: Identifier::from("add"),
parameters: vec![
Parameter {
name: Identifier::from("a"),
typ: Type::Int,
},
Parameter {
name: Identifier::from("b"),
typ: Type::Int,
},
],
return_type: Some(Type::Int),
body: Box::new(Block {
statements: vec![],
value: Some(Expr::BinaryExpression(
Box::new(Expr::Identifier(Identifier::from("a"))),
BinaryOperator::Add,
Box::new(Expr::Identifier(Identifier::from("b"))),
)),
}),
line_col: (1, 1),
})],
path,
};
assert_eq!(module, expected_module);
}

452
src/typing/mod.rs Normal file
View file

@ -0,0 +1,452 @@
use std::collections::HashMap;
use crate::ast::{
module::{Module, ModulePath},
*,
};
#[derive(Debug, PartialEq, Clone)]
pub enum Type {
Bool,
Int,
Float,
Unit,
Str,
Custom(Identifier),
}
impl From<&str> for Type {
fn from(value: &str) -> Self {
match value {
"int" => Type::Int,
"float" => Type::Float,
_ => Type::Custom(Identifier::from(value)),
}
}
}
impl FunctionDefinition {
fn signature(&self) -> (Vec<Type>, Type) {
let return_type = self.return_type.unwrap_or(Type::Unit);
let params_types = self.parameters.iter().map(|p| p.typ).collect();
(params_types, return_type)
}
}
impl Module {
pub fn type_check(&self) -> Result<(), TypeError> {
let mut ctx = TypeContext::new(self.path);
ctx.file = self.file.clone();
// Register all function signatures
for Definition::FunctionDefinition(func) in &self.definitions {
if let Some(previous) = ctx.functions.insert(func.name.clone(), func.signature()) {
todo!("handle redefinition of function or identical function names across different files");
}
}
// TODO: add signatures of imported functions (even if they have not been checked)
// Type-check the function bodies
for Definition::FunctionDefinition(func) in &self.definitions {
func.typ(&mut ctx)?;
}
Ok(())
}
}
#[derive(Debug)]
pub struct TypeError {
file: Option<std::path::PathBuf>,
module: ModulePath,
function: Option<String>,
kind: TypeErrorKind,
}
impl std::fmt::Display for TypeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("Error\n")?;
if let Some(path) = &self.file {
f.write_fmt(format_args!(" in file {}\n", path.display()))?;
}
f.write_fmt(format_args!(" in module {}\n", self.module))?;
if let Some(name) = &self.function {
f.write_fmt(format_args!(" in function {}\n", name))?;
}
f.write_fmt(format_args!("{:#?}", self.kind))?;
Ok(())
}
}
#[derive(Default)]
struct TypeErrorBuilder {
file: Option<std::path::PathBuf>,
module: Option<ModulePath>,
function: Option<String>,
kind: Option<TypeErrorKind>,
}
impl TypeError {
fn builder() -> TypeErrorBuilder {
TypeErrorBuilder::default()
}
}
impl TypeErrorBuilder {
fn context(mut self, ctx: &TypeContext) -> Self {
self.file = ctx.file.clone();
self.module = Some(ctx.module.clone());
self.function = ctx.function.clone();
self
}
fn kind(mut self, kind: TypeErrorKind) -> Self {
self.kind = Some(kind);
self
}
fn build(self) -> TypeError {
TypeError {
file: self.file,
module: self.module.expect("TypeError builder is missing module"),
function: self.function,
kind: self.kind.expect("TypeError builder is missing kind"),
}
}
}
#[derive(Debug)]
pub enum TypeErrorKind {
InvalidBinaryOperator {
operator: BinaryOperator,
lht: Type,
rht: Type,
},
BlockTypeDoesNotMatchFunctionType {
block_type: Type,
function_type: Type,
},
ReturnTypeDoesNotMatchFunctionType {
function_type: Type,
return_type: Type,
},
UnknownIdentifier {
identifier: String,
},
AssignmentMismatch {
lht: Type,
rht: Type,
},
AssignUndeclared,
VariableRedeclaration,
ReturnStatementsMismatch,
UnknownFunctionCalled(Identifier),
WrongFunctionArguments,
ConditionIsNotBool,
IfElseMismatch,
}
pub struct TypeContext {
pub file: Option<std::path::PathBuf>,
pub module: ModulePath,
pub function: Option<Identifier>,
pub functions: HashMap<Identifier, (Vec<Type>, Type)>,
pub variables: HashMap<Identifier, Type>,
}
impl TypeContext {
pub fn new(path: ModulePath) -> Self {
TypeContext {
file: None,
module: path,
function: None,
functions: Default::default(),
variables: Default::default(),
}
}
}
/// Trait for nodes which have a deducible type.
pub trait Typ {
/// Try to resolve the type of the node.
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError>;
}
impl Typ for FunctionDefinition {
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError> {
let func = self;
ctx.function = Some(func.name.clone());
for param in &func.parameters {
ctx.variables.insert(param.name.clone(), param.typ.clone());
}
let body_type = &func.body.typ(ctx)?;
// If the return type is not specified, it is unit.
let func_return_type = match &func.return_type {
Some(typ) => typ,
None => &Type::Unit,
};
// Check coherence with the body's type.
if *func_return_type != *body_type {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::BlockTypeDoesNotMatchFunctionType {
block_type: body_type.clone(),
function_type: func_return_type.clone(),
})
.build());
}
// Check coherence with return statements.
for statement in &func.body.statements {
if let Statement::ReturnStatement(value) = statement {
let ret_type = match value {
Some(expr) => expr.typ(ctx)?,
None => Type::Unit,
};
if ret_type != *func_return_type {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::ReturnTypeDoesNotMatchFunctionType {
function_type: func_return_type.clone(),
return_type: ret_type,
})
.build());
}
}
}
Ok(func_return_type.clone())
}
}
impl Typ for Block {
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError> {
let mut return_typ: Option<Type> = None;
// Check declarations and assignments.
for statement in &self.statements {
match statement {
Statement::DeclareStatement(ident, expr) => {
let typ = expr.typ(ctx)?;
if let Some(_typ) = ctx.variables.insert(ident.clone(), typ) {
// TODO: Shadowing? (illegal for now)
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::VariableRedeclaration)
.build());
}
}
Statement::AssignStatement(ident, expr) => {
let rhs_typ = expr.typ(ctx)?;
let Some(lhs_typ) = ctx.variables.get(ident) else {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::AssignUndeclared)
.build());
};
// Ensure same type on both sides.
if rhs_typ != *lhs_typ {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::AssignmentMismatch {
lht: lhs_typ.clone(),
rht: rhs_typ.clone(),
})
.build());
}
}
Statement::ReturnStatement(maybe_expr) => {
let expr_typ = if let Some(expr) = maybe_expr {
expr.typ(ctx)?
} else {
Type::Unit
};
if let Some(typ) = &return_typ {
if expr_typ != *typ {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::ReturnStatementsMismatch)
.build());
}
} else {
return_typ = Some(expr_typ);
}
}
Statement::CallStatement(call) => {
call.typ(ctx)?;
}
Statement::UseStatement(_path) => {
// TODO: import the signatures (and types)
}
Statement::IfStatement(cond, block) => {
if cond.typ(ctx)? != Type::Bool {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::ConditionIsNotBool)
.build());
}
block.typ(ctx)?;
}
Statement::WhileStatement(cond, block) => {
if cond.typ(ctx)? != Type::Bool {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::ConditionIsNotBool)
.build());
}
block.typ(ctx)?;
}
}
}
// Check if there is an expression at the end of the block.
if let Some(expr) = &self.value {
expr.typ(ctx)
} else {
Ok(Type::Unit)
}
// TODO/FIXME: find a way to return `return_typ` so that the
// top-level block (the function) can check if this return type
// (and eventually those from other block) matches the type of
// the function.
}
}
impl Typ for Call {
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError> {
match &self.callee {
Expr::Identifier(ident) => {
let signature = match ctx.functions.get(ident) {
Some(sgn) => sgn.clone(),
None => {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::UnknownFunctionCalled(ident.clone()))
.build())
}
};
let (params_types, func_type) = signature;
// Collect arg types.
let mut args_types: Vec<Type> = vec![];
for arg in &self.args {
let typ = arg.typ(ctx)?;
args_types.push(typ.clone());
}
if args_types == *params_types {
Ok(func_type.clone())
} else {
Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::WrongFunctionArguments)
.build())
}
}
_ => unimplemented!("cannot call on expression other than identifier"),
}
}
}
impl Typ for Expr {
fn typ(&self, ctx: &mut TypeContext) -> Result<Type, TypeError> {
match self {
Expr::Identifier(identifier) => {
if let Some(typ) = ctx.variables.get(identifier) {
Ok(typ.clone())
} else {
Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::UnknownIdentifier {
identifier: identifier.clone(),
})
.build())
}
}
Expr::BooleanLiteral(_) => Ok(Type::Bool),
Expr::IntegerLiteral(_) => Ok(Type::Int),
Expr::FloatLiteral(_) => Ok(Type::Float),
Expr::BinaryExpression(lhs, op, rhs) => match op {
BinaryOperator::Add
| BinaryOperator::Sub
| BinaryOperator::Mul
| BinaryOperator::Div => {
let left_type = &lhs.typ(ctx)?;
let right_type = &rhs.typ(ctx)?;
match (left_type, right_type) {
(Type::Int, Type::Int) => Ok(Type::Int),
(Type::Float, Type::Float) => Ok(Type::Float),
(_, _) => Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::InvalidBinaryOperator {
operator: op.clone(),
lht: left_type.clone(),
rht: right_type.clone(),
})
.build()),
}
}
BinaryOperator::Equal | BinaryOperator::NotEqual => {
let lhs_type = lhs.typ(ctx)?;
let rhs_type = rhs.typ(ctx)?;
if lhs_type != rhs_type {
return Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::InvalidBinaryOperator {
operator: op.clone(),
lht: lhs_type.clone(),
rht: rhs_type.clone(),
})
.build());
}
Ok(Type::Bool)
}
BinaryOperator::Modulo => {
let lhs_type = lhs.typ(ctx)?;
let rhs_type = lhs.typ(ctx)?;
match (&lhs_type, &rhs_type) {
(Type::Int, Type::Int) => Ok(Type::Int),
_ => Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::InvalidBinaryOperator {
operator: op.clone(),
lht: lhs_type.clone(),
rht: rhs_type.clone(),
})
.build()),
}
}
},
Expr::StringLiteral(_) => Ok(Type::Str),
Expr::Call(call) => call.typ(ctx),
Expr::Block(block) => block.typ(ctx),
Expr::IfExpr(cond, true_block, else_value) => {
if cond.typ(ctx)? != Type::Bool {
Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::ConditionIsNotBool)
.build())
} else {
let true_block_type = true_block.typ(ctx)?;
let else_type = else_value.typ(ctx)?;
if true_block_type != else_type {
Err(TypeError::builder()
.context(ctx)
.kind(TypeErrorKind::IfElseMismatch)
.build())
} else {
Ok(true_block_type.clone())
}
}
}
}
}
}