Remove module_level_declaration ast node.

This commit is contained in:
Jesse Brault 2026-03-11 12:14:50 -05:00
parent 80b6b96aeb
commit 9790ec6ca6
6 changed files with 156 additions and 121 deletions

View File

@ -1,6 +1,8 @@
use crate::ast::field::Field;
use crate::ast::function::Function;
use crate::diagnostic::Diagnostic;
use crate::source_range::SourceRange;
use crate::symbol_table::SymbolTable;
use std::rc::Rc;
pub struct Class {
@ -24,4 +26,19 @@ impl Class {
functions,
}
}
pub fn gather_declared_names(
&mut self,
symbol_table: &mut SymbolTable,
) -> Result<(), Vec<Diagnostic>> {
todo!()
}
pub fn check_name_usages(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
todo!()
}
pub fn type_check(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
todo!()
}
}

View File

@ -1,19 +1,39 @@
use crate::ast::module_level_declaration::ModuleLevelDeclaration;
use crate::ast::class::Class;
use crate::ast::extern_function::ExternFunction;
use crate::ast::function::Function;
use crate::diagnostic::Diagnostic;
use crate::ir::ir_function::IrFunction;
use crate::symbol_table::SymbolTable;
pub struct CompilationUnit {
declarations: Vec<ModuleLevelDeclaration>,
functions: Vec<Function>,
extern_functions: Vec<ExternFunction>,
classes: Vec<Class>,
}
impl CompilationUnit {
pub fn new(declarations: Vec<ModuleLevelDeclaration>) -> Self {
Self { declarations }
pub fn new(
functions: Vec<Function>,
extern_functions: Vec<ExternFunction>,
classes: Vec<Class>,
) -> Self {
Self {
functions,
extern_functions,
classes,
}
}
pub fn declarations(&self) -> &[ModuleLevelDeclaration] {
&self.declarations
pub fn functions(&self) -> &[Function] {
&self.functions
}
pub fn extern_functions(&self) -> &[ExternFunction] {
&self.extern_functions
}
pub fn classes(&self) -> &[Class] {
&self.classes
}
pub fn gather_declared_names(
@ -22,13 +42,27 @@ impl CompilationUnit {
) -> Result<(), Vec<Diagnostic>> {
symbol_table.push_scope("compilation_unit_scope");
let diagnostics: Vec<Diagnostic> = self
.declarations
let mut diagnostics: Vec<Diagnostic> = vec![];
self.functions
.iter_mut()
.map(|declaration| declaration.gather_declared_names(symbol_table))
.map(|f| f.gather_declared_names(symbol_table))
.filter_map(Result::err)
.flatten()
.collect();
.for_each(|diagnostic| diagnostics.push(diagnostic));
self.extern_functions
.iter_mut()
.map(|f| f.gather_declared_names(symbol_table))
.filter_map(Result::err)
.flatten()
.for_each(|diagnostic| diagnostics.push(diagnostic));
self.classes
.iter_mut()
.map(|c| c.gather_declared_names(symbol_table))
.filter_map(Result::err)
.flatten()
.for_each(|diagnostic| diagnostics.push(diagnostic));
symbol_table.pop_scope();
@ -40,13 +74,29 @@ impl CompilationUnit {
}
pub fn check_name_usages(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
let diagnostics: Vec<Diagnostic> = self
.declarations
let mut diagnostics: Vec<Diagnostic> = vec![];
self.functions
.iter_mut()
.map(|declaration| declaration.check_name_usages(symbol_table))
.map(|f| f.check_name_usages(symbol_table))
.filter_map(Result::err)
.flatten()
.collect();
.for_each(|diagnostic| diagnostics.push(diagnostic));
self.extern_functions
.iter_mut()
.map(|f| f.check_name_usages(symbol_table))
.filter_map(Result::err)
.flatten()
.for_each(|diagnostic| diagnostics.push(diagnostic));
self.classes
.iter_mut()
.map(|c| c.check_name_usages(symbol_table))
.filter_map(Result::err)
.flatten()
.for_each(|diagnostic| diagnostics.push(diagnostic));
if diagnostics.is_empty() {
Ok(())
} else {
@ -55,13 +105,29 @@ impl CompilationUnit {
}
pub fn type_check(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
let diagnostics: Vec<Diagnostic> = self
.declarations
let mut diagnostics: Vec<Diagnostic> = vec![];
self.functions
.iter_mut()
.map(|declaration| declaration.type_check(symbol_table))
.map(|f| f.type_check(symbol_table))
.filter_map(Result::err)
.flatten()
.collect();
.for_each(|diagnostic| diagnostics.push(diagnostic));
self.extern_functions
.iter_mut()
.map(|f| f.type_check(symbol_table))
.filter_map(Result::err)
.flatten()
.for_each(|diagnostic| diagnostics.push(diagnostic));
self.classes
.iter_mut()
.map(|c| c.type_check(symbol_table))
.filter_map(Result::err)
.flatten()
.for_each(|diagnostic| diagnostics.push(diagnostic));
if diagnostics.is_empty() {
Ok(())
} else {
@ -70,12 +136,9 @@ impl CompilationUnit {
}
pub fn to_ir(&self, symbol_table: &SymbolTable) -> Vec<IrFunction> {
let mut ir_functions = vec![];
for declaration in &self.declarations {
if let ModuleLevelDeclaration::Function(function) = declaration {
ir_functions.push(function.to_ir(symbol_table));
}
}
ir_functions
self.functions
.iter()
.map(|f| f.to_ir(symbol_table))
.collect()
}
}

View File

@ -13,7 +13,6 @@ pub mod identifier;
pub mod integer_literal;
pub mod ir_builder;
pub mod let_statement;
pub mod module_level_declaration;
pub mod negative_expression;
pub mod parameter;
pub mod statement;

View File

@ -1,54 +0,0 @@
use crate::ast::class::Class;
use crate::ast::extern_function::ExternFunction;
use crate::ast::function::Function;
use crate::diagnostic::Diagnostic;
use crate::symbol_table::SymbolTable;
pub enum ModuleLevelDeclaration {
Function(Function),
ExternFunction(ExternFunction),
Class(Class),
}
impl ModuleLevelDeclaration {
pub fn gather_declared_names(
&mut self,
symbol_table: &mut SymbolTable,
) -> Result<(), Vec<Diagnostic>> {
match self {
ModuleLevelDeclaration::Function(function) => {
function.gather_declared_names(symbol_table)
}
ModuleLevelDeclaration::ExternFunction(extern_function) => {
extern_function.gather_declared_names(symbol_table)
}
ModuleLevelDeclaration::Class(class) => {
todo!()
}
}
}
pub fn check_name_usages(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
match self {
ModuleLevelDeclaration::Function(function) => function.check_name_usages(symbol_table),
ModuleLevelDeclaration::ExternFunction(extern_function) => {
extern_function.check_name_usages(symbol_table)
}
ModuleLevelDeclaration::Class(class) => {
todo!()
}
}
}
pub fn type_check(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
match self {
ModuleLevelDeclaration::Function(function) => function.type_check(symbol_table),
ModuleLevelDeclaration::ExternFunction(extern_function) => {
extern_function.type_check(symbol_table)
}
ModuleLevelDeclaration::Class(class) => {
todo!()
}
}
}
}

View File

@ -121,7 +121,6 @@ impl Display for IrBlock {
#[cfg(test)]
mod tests {
use crate::ast::module_level_declaration::ModuleLevelDeclaration;
use crate::parser::parse_compilation_unit;
use crate::symbol_table::SymbolTable;
@ -139,22 +138,24 @@ mod tests {
)
.unwrap();
let mut symbol_table = SymbolTable::new();
compilation_unit.gather_declared_names(&mut symbol_table);
compilation_unit.check_name_usages(&mut symbol_table);
compilation_unit.type_check(&mut symbol_table);
compilation_unit
.gather_declared_names(&mut symbol_table)
.expect("gather failed");
compilation_unit
.check_name_usages(&mut symbol_table)
.expect("name check failed");
compilation_unit
.type_check(&mut symbol_table)
.expect("type check failed");
let main = compilation_unit
.declarations()
.functions()
.iter()
.find(|d| matches!(d, ModuleLevelDeclaration::Function(_)))
.find(|f| f.declared_name() == "main")
.unwrap();
if let ModuleLevelDeclaration::Function(main) = main {
let mut main_ir = main.to_ir(&symbol_table);
let (register_assignments, _) = main_ir.assign_registers(2);
assert_eq!(register_assignments.len(), 4);
} else {
unreachable!()
}
let mut main_ir = main.to_ir(&symbol_table);
let (register_assignments, _) = main_ir.assign_registers(2);
assert_eq!(register_assignments.len(), 4);
}
}

View File

@ -11,7 +11,6 @@ use crate::ast::function::Function;
use crate::ast::identifier::Identifier;
use crate::ast::integer_literal::IntegerLiteral;
use crate::ast::let_statement::LetStatement;
use crate::ast::module_level_declaration::ModuleLevelDeclaration;
use crate::ast::negative_expression::NegativeExpression;
use crate::ast::parameter::Parameter;
use crate::ast::statement::Statement;
@ -189,16 +188,24 @@ impl<'a> Parser<'a> {
}
pub fn compilation_unit(&mut self) -> Result<CompilationUnit, Vec<Diagnostic>> {
let mut declarations = vec![];
let mut functions: Vec<Function> = vec![];
let mut extern_functions: Vec<ExternFunction> = vec![];
let mut classes: Vec<Class> = vec![];
let mut diagnostics = vec![];
self.advance(); // get started
while self.current.is_some() {
let current = self.get_current();
match current.kind() {
TokenKind::Fn | TokenKind::Extern | TokenKind::Class => {
let declaration_result = self.module_level_declaration();
match declaration_result {
Ok(declaration) => declarations.push(declaration),
match self.module_level_declaration(
&mut functions,
&mut extern_functions,
&mut classes,
) {
Ok(_) => {}
Err(mut declaration_diagnostics) => {
diagnostics.append(&mut declaration_diagnostics)
}
@ -219,19 +226,27 @@ impl<'a> Parser<'a> {
}
}
if diagnostics.is_empty() {
Ok(CompilationUnit::new(declarations))
Ok(CompilationUnit::new(functions, extern_functions, classes))
} else {
Err(diagnostics)
}
}
fn module_level_declaration(&mut self) -> Result<ModuleLevelDeclaration, Vec<Diagnostic>> {
fn module_level_declaration(
&mut self,
functions: &mut Vec<Function>,
extern_functions: &mut Vec<ExternFunction>,
classes: &mut Vec<Class>,
) -> Result<(), Vec<Diagnostic>> {
let current = self.get_current();
match current.kind() {
TokenKind::Fn => {
let function_result = self.function();
match function_result {
Ok(function) => Ok(ModuleLevelDeclaration::Function(function)),
Ok(function) => {
functions.push(function);
Ok(())
}
Err(function_diagnostics) => Err(function_diagnostics),
}
}
@ -239,13 +254,17 @@ impl<'a> Parser<'a> {
let extern_function_result = self.extern_function();
match extern_function_result {
Ok(extern_function) => {
Ok(ModuleLevelDeclaration::ExternFunction(extern_function))
extern_functions.push(extern_function);
Ok(())
}
Err(extern_function_diagnostics) => Err(extern_function_diagnostics),
}
}
TokenKind::Class => match self.class() {
Ok(class) => Ok(ModuleLevelDeclaration::Class(class)),
Ok(class) => {
classes.push(class);
Ok(())
}
Err(class_diagnostics) => Err(class_diagnostics),
},
_ => unreachable!(),
@ -865,29 +884,19 @@ mod concrete_tests {
compilation_unit: &'a CompilationUnit,
function_name: &str,
) -> &'a Function {
let declarations = compilation_unit.declarations();
for declaration in declarations {
match declaration {
ModuleLevelDeclaration::Function(function) => {
if function.declared_name() == function_name {
return function;
}
}
_ => {}
}
}
panic!("Function {} not found", function_name)
compilation_unit
.functions()
.iter()
.find(|f| f.declared_name() == function_name)
.unwrap()
}
#[test]
fn parses_extern_fn() {
let compilation_unit = assert_compilation_unit("extern fn println() -> Void");
let declarations = compilation_unit.declarations();
assert_eq!(declarations.len(), 1);
let extern_function = match &declarations[0] {
ModuleLevelDeclaration::ExternFunction(extern_function) => extern_function,
_ => panic!(),
};
let extern_functions = compilation_unit.extern_functions();
assert_eq!(extern_functions.len(), 1);
let extern_function = &extern_functions[0];
assert_eq!(extern_function.declared_name(), "println");
}