From 7de866cf9d80819820b0de2f2095d08084b6ec8e Mon Sep 17 00:00:00 2001 From: Jesse Brault Date: Tue, 10 Mar 2026 12:42:53 -0500 Subject: [PATCH] Add return-type checking and fix string literal type-info bug. --- dmc-lib/src/ast/expression_statement.rs | 29 ++++++++++++++-- dmc-lib/src/ast/function.rs | 9 ++++- dmc-lib/src/ast/statement.rs | 12 +++++-- dmc-lib/src/ast/string_literal.rs | 2 +- dmc-lib/src/type_info.rs | 10 ++++-- e2e-tests/src/lib.rs | 44 +++++++++++++++++++++++++ 6 files changed, 98 insertions(+), 8 deletions(-) diff --git a/dmc-lib/src/ast/expression_statement.rs b/dmc-lib/src/ast/expression_statement.rs index d721fad..db2334f 100644 --- a/dmc-lib/src/ast/expression_statement.rs +++ b/dmc-lib/src/ast/expression_statement.rs @@ -3,7 +3,10 @@ use crate::ast::ir_builder::IrBuilder; use crate::diagnostic::Diagnostic; use crate::ir::ir_return::IrReturn; use crate::ir::ir_statement::IrStatement; +use crate::symbol::FunctionSymbol; use crate::symbol_table::SymbolTable; +use std::cell::RefCell; +use std::rc::Rc; pub struct ExpressionStatement { expression: Box, @@ -31,8 +34,30 @@ impl ExpressionStatement { self.expression.check_name_usages(symbol_table) } - pub fn type_check(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec> { - self.expression.type_check(symbol_table) + pub fn type_check( + &mut self, + symbol_table: &SymbolTable, + is_last: bool, + function_symbol: &Rc>, + ) -> Result<(), Vec> { + self.expression.type_check(symbol_table)?; + + if is_last { + let expression_type = self.expression.type_info(); + let borrowed_symbol = function_symbol.borrow(); + let return_type = borrowed_symbol.return_type_info(); + if !return_type.is_assignable_from(expression_type) { + return Err(vec![Diagnostic::new( + &format!( + "Incompatible type on return expression: expected {} but found {}", + return_type, expression_type + ), + self.expression.source_range().start(), + self.expression.source_range().end(), + )]); + } + } + Ok(()) } pub fn to_ir( diff --git a/dmc-lib/src/ast/function.rs b/dmc-lib/src/ast/function.rs index 57b19f9..efeead8 100644 --- a/dmc-lib/src/ast/function.rs +++ b/dmc-lib/src/ast/function.rs @@ -168,12 +168,19 @@ impl Function { .collect(), ); + let function_symbol = self.function_symbol.as_ref().unwrap(); + let statements_len = self.statements.len(); + // statements diagnostics.append( &mut self .statements .iter_mut() - .map(|statement| statement.type_check(symbol_table)) + .enumerate() + .map(|(i, statement)| { + let is_last = i == statements_len - 1; + statement.type_check(symbol_table, is_last, function_symbol) + }) .filter_map(Result::err) .flatten() .collect(), diff --git a/dmc-lib/src/ast/statement.rs b/dmc-lib/src/ast/statement.rs index 767a8f5..d251785 100644 --- a/dmc-lib/src/ast/statement.rs +++ b/dmc-lib/src/ast/statement.rs @@ -2,7 +2,10 @@ use crate::ast::expression_statement::ExpressionStatement; use crate::ast::ir_builder::IrBuilder; use crate::ast::let_statement::LetStatement; use crate::diagnostic::Diagnostic; +use crate::symbol::FunctionSymbol; use crate::symbol_table::SymbolTable; +use std::cell::RefCell; +use std::rc::Rc; pub enum Statement { Let(LetStatement), @@ -31,11 +34,16 @@ impl Statement { } } - pub fn type_check(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec> { + pub fn type_check( + &mut self, + symbol_table: &SymbolTable, + is_last: bool, + function_symbol: &Rc>, + ) -> Result<(), Vec> { match self { Statement::Let(let_statement) => let_statement.type_check(symbol_table), Statement::Expression(expression_statement) => { - expression_statement.type_check(symbol_table) + expression_statement.type_check(symbol_table, is_last, function_symbol) } } } diff --git a/dmc-lib/src/ast/string_literal.rs b/dmc-lib/src/ast/string_literal.rs index 2956ac2..5192684 100644 --- a/dmc-lib/src/ast/string_literal.rs +++ b/dmc-lib/src/ast/string_literal.rs @@ -9,7 +9,7 @@ pub struct StringLiteral { impl StringLiteral { pub fn new(content: &str, source_range: SourceRange) -> Self { - const TYPE_INFO: TypeInfo = TypeInfo::Integer; + const TYPE_INFO: TypeInfo = TypeInfo::String; Self { content: content.into(), source_range, diff --git a/dmc-lib/src/type_info.rs b/dmc-lib/src/type_info.rs index a48081c..28d75cf 100644 --- a/dmc-lib/src/type_info.rs +++ b/dmc-lib/src/type_info.rs @@ -68,9 +68,15 @@ impl TypeInfo { } } - pub fn add_result(&self, _rhs: &Self) -> TypeInfo { + pub fn add_result(&self, rhs: &Self) -> TypeInfo { match self { - TypeInfo::Integer => TypeInfo::Integer, + TypeInfo::Integer => match rhs { + TypeInfo::Integer => TypeInfo::Integer, + TypeInfo::String => TypeInfo::String, + _ => panic!( + "Adding things other than integers/strings to integer not yet supported." + ), + }, TypeInfo::String => TypeInfo::String, _ => panic!("Adding things other than integers and strings not yet supported"), } diff --git a/e2e-tests/src/lib.rs b/e2e-tests/src/lib.rs index ba3cec5..1c29b8b 100644 --- a/e2e-tests/src/lib.rs +++ b/e2e-tests/src/lib.rs @@ -137,3 +137,47 @@ mod e2e_tests { assert_result("fn sub() -> Int 3 - 2 end", "sub", &vec![], Value::Int(1)) } } + +#[cfg(test)] +mod diagnostic_tests { + use dmc_lib::diagnostic::Diagnostic; + use dmc_lib::parser::parse_compilation_unit; + use dmc_lib::symbol_table::SymbolTable; + + fn get_diagnostics(input: &str) -> Vec { + let parse_result = parse_compilation_unit(input); + let mut compilation_unit = match parse_result { + Ok(compilation_unit) => compilation_unit, + Err(diagnostics) => { + return diagnostics; + } + }; + + let mut symbol_table = SymbolTable::new(); + + match compilation_unit.gather_declared_names(&mut symbol_table) { + Ok(_) => {} + Err(diagnostics) => { + return diagnostics; + } + } + + match compilation_unit.check_name_usages(&symbol_table) { + Ok(_) => {} + Err(diagnostics) => { + return diagnostics; + } + } + + match compilation_unit.type_check(&symbol_table) { + Ok(_) => vec![], + Err(diagnostics) => diagnostics, + } + } + + #[test] + fn wrong_return_type() { + let diagnostics = get_diagnostics("fn main() -> String 42 end"); + assert_eq!(diagnostics.len(), 1); + } +}