Add return-type checking and fix string literal type-info bug.

This commit is contained in:
Jesse Brault 2026-03-10 12:42:53 -05:00
parent 705436ba61
commit 7de866cf9d
6 changed files with 98 additions and 8 deletions

View File

@ -3,7 +3,10 @@ use crate::ast::ir_builder::IrBuilder;
use crate::diagnostic::Diagnostic; use crate::diagnostic::Diagnostic;
use crate::ir::ir_return::IrReturn; use crate::ir::ir_return::IrReturn;
use crate::ir::ir_statement::IrStatement; use crate::ir::ir_statement::IrStatement;
use crate::symbol::FunctionSymbol;
use crate::symbol_table::SymbolTable; use crate::symbol_table::SymbolTable;
use std::cell::RefCell;
use std::rc::Rc;
pub struct ExpressionStatement { pub struct ExpressionStatement {
expression: Box<Expression>, expression: Box<Expression>,
@ -31,8 +34,30 @@ impl ExpressionStatement {
self.expression.check_name_usages(symbol_table) self.expression.check_name_usages(symbol_table)
} }
pub fn type_check(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> { pub fn type_check(
self.expression.type_check(symbol_table) &mut self,
symbol_table: &SymbolTable,
is_last: bool,
function_symbol: &Rc<RefCell<FunctionSymbol>>,
) -> Result<(), Vec<Diagnostic>> {
self.expression.type_check(symbol_table)?;
if is_last {
let expression_type = self.expression.type_info();
let borrowed_symbol = function_symbol.borrow();
let return_type = borrowed_symbol.return_type_info();
if !return_type.is_assignable_from(expression_type) {
return Err(vec![Diagnostic::new(
&format!(
"Incompatible type on return expression: expected {} but found {}",
return_type, expression_type
),
self.expression.source_range().start(),
self.expression.source_range().end(),
)]);
}
}
Ok(())
} }
pub fn to_ir( pub fn to_ir(

View File

@ -168,12 +168,19 @@ impl Function {
.collect(), .collect(),
); );
let function_symbol = self.function_symbol.as_ref().unwrap();
let statements_len = self.statements.len();
// statements // statements
diagnostics.append( diagnostics.append(
&mut self &mut self
.statements .statements
.iter_mut() .iter_mut()
.map(|statement| statement.type_check(symbol_table)) .enumerate()
.map(|(i, statement)| {
let is_last = i == statements_len - 1;
statement.type_check(symbol_table, is_last, function_symbol)
})
.filter_map(Result::err) .filter_map(Result::err)
.flatten() .flatten()
.collect(), .collect(),

View File

@ -2,7 +2,10 @@ use crate::ast::expression_statement::ExpressionStatement;
use crate::ast::ir_builder::IrBuilder; use crate::ast::ir_builder::IrBuilder;
use crate::ast::let_statement::LetStatement; use crate::ast::let_statement::LetStatement;
use crate::diagnostic::Diagnostic; use crate::diagnostic::Diagnostic;
use crate::symbol::FunctionSymbol;
use crate::symbol_table::SymbolTable; use crate::symbol_table::SymbolTable;
use std::cell::RefCell;
use std::rc::Rc;
pub enum Statement { pub enum Statement {
Let(LetStatement), Let(LetStatement),
@ -31,11 +34,16 @@ impl Statement {
} }
} }
pub fn type_check(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> { pub fn type_check(
&mut self,
symbol_table: &SymbolTable,
is_last: bool,
function_symbol: &Rc<RefCell<FunctionSymbol>>,
) -> Result<(), Vec<Diagnostic>> {
match self { match self {
Statement::Let(let_statement) => let_statement.type_check(symbol_table), Statement::Let(let_statement) => let_statement.type_check(symbol_table),
Statement::Expression(expression_statement) => { Statement::Expression(expression_statement) => {
expression_statement.type_check(symbol_table) expression_statement.type_check(symbol_table, is_last, function_symbol)
} }
} }
} }

View File

@ -9,7 +9,7 @@ pub struct StringLiteral {
impl StringLiteral { impl StringLiteral {
pub fn new(content: &str, source_range: SourceRange) -> Self { pub fn new(content: &str, source_range: SourceRange) -> Self {
const TYPE_INFO: TypeInfo = TypeInfo::Integer; const TYPE_INFO: TypeInfo = TypeInfo::String;
Self { Self {
content: content.into(), content: content.into(),
source_range, source_range,

View File

@ -68,10 +68,16 @@ impl TypeInfo {
} }
} }
pub fn add_result(&self, _rhs: &Self) -> TypeInfo { pub fn add_result(&self, rhs: &Self) -> TypeInfo {
match self { match self {
TypeInfo::Integer => match rhs {
TypeInfo::Integer => TypeInfo::Integer, TypeInfo::Integer => TypeInfo::Integer,
TypeInfo::String => TypeInfo::String, TypeInfo::String => TypeInfo::String,
_ => panic!(
"Adding things other than integers/strings to integer not yet supported."
),
},
TypeInfo::String => TypeInfo::String,
_ => panic!("Adding things other than integers and strings not yet supported"), _ => panic!("Adding things other than integers and strings not yet supported"),
} }
} }

View File

@ -137,3 +137,47 @@ mod e2e_tests {
assert_result("fn sub() -> Int 3 - 2 end", "sub", &vec![], Value::Int(1)) assert_result("fn sub() -> Int 3 - 2 end", "sub", &vec![], Value::Int(1))
} }
} }
#[cfg(test)]
mod diagnostic_tests {
use dmc_lib::diagnostic::Diagnostic;
use dmc_lib::parser::parse_compilation_unit;
use dmc_lib::symbol_table::SymbolTable;
fn get_diagnostics(input: &str) -> Vec<Diagnostic> {
let parse_result = parse_compilation_unit(input);
let mut compilation_unit = match parse_result {
Ok(compilation_unit) => compilation_unit,
Err(diagnostics) => {
return diagnostics;
}
};
let mut symbol_table = SymbolTable::new();
match compilation_unit.gather_declared_names(&mut symbol_table) {
Ok(_) => {}
Err(diagnostics) => {
return diagnostics;
}
}
match compilation_unit.check_name_usages(&symbol_table) {
Ok(_) => {}
Err(diagnostics) => {
return diagnostics;
}
}
match compilation_unit.type_check(&symbol_table) {
Ok(_) => vec![],
Err(diagnostics) => diagnostics,
}
}
#[test]
fn wrong_return_type() {
let diagnostics = get_diagnostics("fn main() -> String 42 end");
assert_eq!(diagnostics.len(), 1);
}
}