Refactor to use BinaryExpression.

This commit is contained in:
Jesse Brault 2026-03-15 12:18:38 -05:00
parent 5a123419bd
commit 863c3fef5d
8 changed files with 294 additions and 291 deletions

View File

@ -1,102 +0,0 @@
use crate::ast::expression::Expression;
use crate::ast::ir_builder::IrBuilder;
use crate::diagnostic::Diagnostic;
use crate::ir::ir_add::IrAdd;
use crate::source_range::SourceRange;
use crate::symbol_table::SymbolTable;
use crate::type_info::TypeInfo;
pub struct AddExpression {
lhs: Box<Expression>,
rhs: Box<Expression>,
source_range: SourceRange,
type_info: Option<TypeInfo>,
}
impl AddExpression {
pub fn new(lhs: Expression, rhs: Expression, source_range: SourceRange) -> Self {
Self {
lhs: lhs.into(),
rhs: rhs.into(),
source_range,
type_info: None,
}
}
pub fn lhs(&self) -> &Expression {
&self.lhs
}
pub fn rhs(&self) -> &Expression {
&self.rhs
}
pub fn gather_declared_names(
&mut self,
symbol_table: &mut SymbolTable,
) -> Result<(), Vec<Diagnostic>> {
let diagnostics: Vec<Diagnostic> = [self.lhs.as_mut(), self.rhs.as_mut()]
.iter_mut()
.map(|expression| expression.gather_declared_names(symbol_table))
.filter_map(Result::err)
.flatten()
.collect();
if diagnostics.is_empty() {
Ok(())
} else {
Err(diagnostics)
}
}
pub fn check_name_usages(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
let diagnostics: Vec<Diagnostic> = [self.lhs.as_mut(), self.rhs.as_mut()]
.iter_mut()
.map(|expression| expression.check_name_usages(symbol_table))
.filter_map(Result::err)
.flatten()
.collect();
if diagnostics.is_empty() {
Ok(())
} else {
Err(diagnostics)
}
}
pub fn type_check(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
self.lhs.type_check(symbol_table)?;
self.rhs.type_check(symbol_table)?;
let lhs_type_info = self.lhs.type_info();
let rhs_type_info = self.rhs.type_info();
if lhs_type_info.can_add(rhs_type_info) {
self.type_info = Some(lhs_type_info.add_result(rhs_type_info));
Ok(())
} else {
Err(vec![Diagnostic::new(
&format!("Cannot add {} to {}", rhs_type_info, lhs_type_info),
self.source_range.start(),
self.source_range.end(),
)])
}
}
pub fn to_ir(&self, builder: &mut IrBuilder, symbol_table: &SymbolTable) -> IrAdd {
let lhs_ir_expression = self
.lhs
.to_ir_expression(builder, symbol_table)
.expect("Attempt to add non-expression");
let rhs_ir_expression = self
.rhs
.to_ir_expression(builder, symbol_table)
.expect("Attempt to add non-expression");
IrAdd::new(lhs_ir_expression, rhs_ir_expression)
}
pub fn type_info(&self) -> &TypeInfo {
self.type_info.as_ref().unwrap()
}
pub fn source_range(&self) -> &SourceRange {
&self.source_range
}
}

View File

@ -0,0 +1,212 @@
use crate::ast::expression::Expression;
use crate::ast::ir_builder::IrBuilder;
use crate::diagnostic::Diagnostic;
use crate::error_codes::BINARY_INCOMPATIBLE_TYPES;
use crate::ir::ir_add::IrAdd;
use crate::ir::ir_assign::IrAssign;
use crate::ir::ir_expression::IrExpression;
use crate::ir::ir_operation::IrOperation;
use crate::ir::ir_statement::IrStatement;
use crate::ir::ir_subtract::IrSubtract;
use crate::ir::ir_variable::IrVariable;
use crate::source_range::SourceRange;
use crate::symbol_table::SymbolTable;
use crate::type_info::TypeInfo;
use crate::{diagnostics_result, handle_diagnostic, handle_diagnostics, maybe_return_diagnostics};
use std::cell::RefCell;
use std::rc::Rc;
pub enum BinaryOperation {
Multiply,
Divide,
Add,
Subtract,
}
pub struct BinaryExpression {
lhs: Box<Expression>,
rhs: Box<Expression>,
op: BinaryOperation,
source_range: SourceRange,
type_info: Option<TypeInfo>,
}
impl BinaryExpression {
pub fn new(
lhs: Expression,
rhs: Expression,
op: BinaryOperation,
source_range: SourceRange,
) -> Self {
Self {
lhs: lhs.into(),
rhs: rhs.into(),
op,
source_range,
type_info: None,
}
}
pub fn lhs(&self) -> &Expression {
&self.lhs
}
pub fn rhs(&self) -> &Expression {
&self.rhs
}
pub fn op(&self) -> &BinaryOperation {
&self.op
}
pub fn source_range(&self) -> &SourceRange {
&self.source_range
}
pub fn type_info(&self) -> &TypeInfo {
self.type_info.as_ref().unwrap()
}
pub fn gather_declared_names(
&mut self,
symbol_table: &mut SymbolTable,
) -> Result<(), Vec<Diagnostic>> {
let diagnostics = [&mut self.lhs, &mut self.rhs]
.iter_mut()
.map(|expression| expression.gather_declared_names(symbol_table))
.filter_map(|result| result.err())
.flatten()
.collect::<Vec<Diagnostic>>();
if diagnostics.is_empty() {
Ok(())
} else {
Err(diagnostics)
}
}
pub fn check_name_usages(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
let diagnostics: Vec<Diagnostic> = [&mut self.lhs, &mut self.rhs]
.iter_mut()
.map(|expression| expression.check_name_usages(symbol_table))
.filter_map(Result::err)
.flatten()
.collect();
if diagnostics.is_empty() {
Ok(())
} else {
Err(diagnostics)
}
}
fn check_op(
&mut self,
check: impl Fn(&TypeInfo, &TypeInfo) -> bool,
op_result: impl Fn(&TypeInfo, &TypeInfo) -> TypeInfo,
lazy_diagnostic_message: impl Fn(&TypeInfo, &TypeInfo) -> String,
) -> Result<(), Diagnostic> {
let lhs_type_info = self.lhs.type_info();
let rhs_type_info = self.rhs.type_info();
if check(lhs_type_info, rhs_type_info) {
self.type_info = Some(op_result(lhs_type_info, rhs_type_info));
Ok(())
} else {
let diagnostic = Diagnostic::new(
&lazy_diagnostic_message(lhs_type_info, rhs_type_info),
self.source_range.start(),
self.source_range.end(),
)
.with_primary_label_message("Incompatible types for addition.")
.with_reporter(file!(), line!())
.with_error_code(BINARY_INCOMPATIBLE_TYPES);
Err(diagnostic)
}
}
pub fn type_check(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
let mut diagnostics: Vec<Diagnostic> = vec![];
handle_diagnostics!(self.lhs.type_check(symbol_table), diagnostics);
handle_diagnostics!(self.rhs.type_check(symbol_table), diagnostics);
maybe_return_diagnostics!(diagnostics);
match &self.op {
BinaryOperation::Multiply => {
todo!()
}
BinaryOperation::Divide => {
todo!()
}
BinaryOperation::Add => {
handle_diagnostic!(
self.check_op(
|lhs, rhs| lhs.can_add(rhs),
|lhs, rhs| lhs.add_result(&rhs),
|lhs, rhs| format!("Incompatible types: cannot add {} to {}.", rhs, lhs)
),
diagnostics
);
}
BinaryOperation::Subtract => {
handle_diagnostic!(
self.check_op(
|lhs, rhs| lhs.can_subtract(rhs),
|lhs, rhs| lhs.subtract_result(rhs),
|lhs, rhs| format!(
"Incompatible types: cannot subtract {} from {}.",
rhs, lhs
)
),
diagnostics
)
}
}
diagnostics_result!(diagnostics)
}
pub fn to_ir_operation(
&self,
builder: &mut IrBuilder,
symbol_table: &SymbolTable,
) -> IrOperation {
let lhs = self
.lhs
.to_ir_expression(builder, symbol_table)
.expect("Attempt to use a non-value expression in binary expression.");
let rhs = self
.rhs
.to_ir_expression(builder, symbol_table)
.expect("Attempt to use a non-value expression in binary expression.");
match self.op {
BinaryOperation::Multiply => {
todo!()
}
BinaryOperation::Divide => {
todo!()
}
BinaryOperation::Add => IrOperation::Add(IrAdd::new(lhs, rhs)),
BinaryOperation::Subtract => IrOperation::Subtract(IrSubtract::new(lhs, rhs)),
}
}
pub fn to_ir_expression(
&self,
builder: &mut IrBuilder,
symbol_table: &SymbolTable,
) -> IrExpression {
let ir_operation = self.to_ir_operation(builder, symbol_table);
let t_var = IrVariable::new_vr(
builder.new_t_var().into(),
builder.current_block().id(),
self.type_info(),
);
let as_rc = Rc::new(RefCell::new(t_var));
let ir_assign = IrAssign::new(as_rc.clone(), ir_operation);
builder
.current_block_mut()
.add_statement(IrStatement::Assign(ir_assign));
IrExpression::Variable(as_rc)
}
}

View File

@ -1,4 +1,4 @@
use crate::ast::add_expression::AddExpression;
use crate::ast::binary_expression::BinaryExpression;
use crate::ast::call::Call;
use crate::ast::double_literal::DoubleLiteral;
use crate::ast::identifier::Identifier;
@ -6,7 +6,6 @@ use crate::ast::integer_literal::IntegerLiteral;
use crate::ast::ir_builder::IrBuilder;
use crate::ast::negative_expression::NegativeExpression;
use crate::ast::string_literal::StringLiteral;
use crate::ast::subtract_expression::SubtractExpression;
use crate::diagnostic::Diagnostic;
use crate::ir::ir_assign::IrAssign;
use crate::ir::ir_expression::IrExpression;
@ -20,8 +19,7 @@ use std::cell::RefCell;
use std::rc::Rc;
pub enum Expression {
Add(AddExpression),
Subtract(SubtractExpression),
Binary(BinaryExpression),
Negative(NegativeExpression),
Call(Call),
Identifier(Identifier),
@ -36,9 +34,8 @@ impl Expression {
symbol_table: &mut SymbolTable,
) -> Result<(), Vec<Diagnostic>> {
match self {
Expression::Add(add_expression) => add_expression.gather_declared_names(symbol_table),
Expression::Subtract(subtract_expression) => {
subtract_expression.gather_declared_names(symbol_table)
Expression::Binary(binary_expression) => {
binary_expression.gather_declared_names(symbol_table)
}
Expression::Negative(negative_expression) => {
negative_expression.gather_declared_names(symbol_table)
@ -53,9 +50,8 @@ impl Expression {
pub fn check_name_usages(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
match self {
Expression::Add(add_expression) => add_expression.check_name_usages(symbol_table),
Expression::Subtract(subtract_expression) => {
subtract_expression.check_name_usages(symbol_table)
Expression::Binary(binary_expression) => {
binary_expression.check_name_usages(symbol_table)
}
Expression::Negative(negative_expression) => {
negative_expression.check_name_usages(symbol_table)
@ -70,10 +66,7 @@ impl Expression {
pub fn type_check(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
match self {
Expression::Add(add_expression) => add_expression.type_check(symbol_table),
Expression::Subtract(subtract_expression) => {
subtract_expression.type_check(symbol_table)
}
Expression::Binary(binary_expression) => binary_expression.type_check(symbol_table),
Expression::Negative(negative_expression) => {
negative_expression.type_check(symbol_table)
}
@ -87,8 +80,7 @@ impl Expression {
pub fn type_info(&self) -> &TypeInfo {
match self {
Expression::Add(add_expression) => add_expression.type_info(),
Expression::Subtract(subtract_expression) => subtract_expression.type_info(),
Expression::Binary(binary_expression) => binary_expression.type_info(),
Expression::Negative(negative_expression) => negative_expression.type_info(),
Expression::Call(call) => call.return_type_info(),
Expression::Identifier(identifier) => identifier.type_info(),
@ -100,8 +92,7 @@ impl Expression {
pub fn source_range(&self) -> &SourceRange {
match self {
Expression::Add(additive_expression) => additive_expression.source_range(),
Expression::Subtract(subtract_expression) => subtract_expression.source_range(),
Expression::Binary(binary_expression) => binary_expression.source_range(),
Expression::Negative(negative_expression) => negative_expression.source_range(),
Expression::Call(call) => call.source_range(),
Expression::Identifier(identifier) => identifier.source_range(),
@ -117,6 +108,9 @@ impl Expression {
symbol_table: &SymbolTable,
) -> IrOperation {
match self {
Expression::Binary(binary_expression) => {
binary_expression.to_ir_operation(builder, symbol_table)
}
Expression::Call(call) => IrOperation::Call(call.to_ir(builder, symbol_table)),
Expression::Integer(integer_literal) => {
IrOperation::Load(IrExpression::Int(integer_literal.value()))
@ -130,12 +124,6 @@ impl Expression {
Expression::Identifier(identifier) => {
IrOperation::Load(identifier.expressible_symbol().ir_expression(builder))
}
Expression::Add(additive_expression) => {
IrOperation::Add(additive_expression.to_ir(builder, symbol_table))
}
Expression::Subtract(subtract_expression) => {
IrOperation::Subtract(subtract_expression.to_ir_subtract(builder, symbol_table))
}
Expression::Negative(negative_expression) => {
IrOperation::Load(negative_expression.to_ir(builder, symbol_table))
}
@ -148,6 +136,9 @@ impl Expression {
symbol_table: &SymbolTable,
) -> Option<IrExpression> {
match self {
Expression::Binary(binary_expression) => {
Some(binary_expression.to_ir_expression(builder, symbol_table))
}
Expression::Call(call) => {
let ir_call = call.to_ir(builder, symbol_table);
if matches!(call.return_type_info(), TypeInfo::Void) {
@ -182,23 +173,6 @@ impl Expression {
let expressible_symbol = identifier.expressible_symbol();
Some(expressible_symbol.ir_expression(builder))
}
Expression::Add(additive_expression) => {
let ir_add = additive_expression.to_ir(builder, symbol_table);
let t_var = IrVariable::new_vr(
builder.new_t_var().into(),
builder.current_block().id(),
additive_expression.type_info(),
);
let as_rc = Rc::new(RefCell::new(t_var));
let assign = IrAssign::new(as_rc.clone(), IrOperation::Add(ir_add));
builder
.current_block_mut()
.add_statement(IrStatement::Assign(assign));
Some(IrExpression::Variable(as_rc))
}
Expression::Subtract(subtract_expression) => {
Some(subtract_expression.to_ir_expression(builder, symbol_table))
}
Expression::Negative(negative_expression) => {
Some(negative_expression.to_ir(builder, symbol_table))
}

View File

@ -1,5 +1,5 @@
pub mod add_expression;
pub mod assign_statement;
pub mod binary_expression;
pub mod call;
pub mod class;
pub mod compilation_unit;
@ -22,5 +22,5 @@ pub mod negative_expression;
pub mod parameter;
pub mod statement;
pub mod string_literal;
pub mod subtract_expression;
pub mod type_use;
mod util;

View File

@ -1,135 +0,0 @@
use crate::ast::expression::Expression;
use crate::ast::ir_builder::IrBuilder;
use crate::diagnostic::Diagnostic;
use crate::ir::ir_assign::IrAssign;
use crate::ir::ir_expression::IrExpression;
use crate::ir::ir_operation::IrOperation;
use crate::ir::ir_statement::IrStatement;
use crate::ir::ir_subtract::IrSubtract;
use crate::ir::ir_variable::IrVariable;
use crate::source_range::SourceRange;
use crate::symbol_table::SymbolTable;
use crate::type_info::TypeInfo;
use std::cell::RefCell;
use std::rc::Rc;
pub struct SubtractExpression {
lhs: Box<Expression>,
rhs: Box<Expression>,
source_range: SourceRange,
type_info: Option<TypeInfo>,
}
impl SubtractExpression {
pub fn new(lhs: Expression, rhs: Expression, source_range: SourceRange) -> Self {
Self {
lhs: lhs.into(),
rhs: rhs.into(),
source_range,
type_info: None,
}
}
pub fn lhs(&self) -> &Expression {
&self.lhs
}
pub fn rhs(&self) -> &Expression {
&self.rhs
}
pub fn gather_declared_names(
&mut self,
symbol_table: &mut SymbolTable,
) -> Result<(), Vec<Diagnostic>> {
let diagnostics = [&mut self.lhs, &mut self.rhs]
.iter_mut()
.map(|expression| expression.gather_declared_names(symbol_table))
.filter_map(|result| result.err())
.flatten()
.collect::<Vec<Diagnostic>>();
if diagnostics.is_empty() {
Ok(())
} else {
Err(diagnostics)
}
}
pub fn check_name_usages(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
let diagnostics: Vec<Diagnostic> = [&mut self.lhs, &mut self.rhs]
.iter_mut()
.map(|expression| expression.check_name_usages(symbol_table))
.filter_map(Result::err)
.flatten()
.collect();
if diagnostics.is_empty() {
Ok(())
} else {
Err(diagnostics)
}
}
pub fn type_check(&mut self, symbol_table: &SymbolTable) -> Result<(), Vec<Diagnostic>> {
self.lhs.type_check(symbol_table)?;
self.rhs.type_check(symbol_table)?;
let lhs_type_info = self.lhs.type_info();
let rhs_type_info = self.rhs.type_info();
if lhs_type_info.can_subtract(rhs_type_info) {
self.type_info = Some(lhs_type_info.add_result(rhs_type_info));
Ok(())
} else {
Err(vec![Diagnostic::new(
&format!(
"Incompatible types: cannot subtract {} from {}",
rhs_type_info, lhs_type_info
), // n.b. order
self.lhs.source_range().start(),
self.lhs.source_range().end(),
)])
}
}
pub fn type_info(&self) -> &TypeInfo {
self.type_info.as_ref().unwrap()
}
pub fn source_range(&self) -> &SourceRange {
&self.source_range
}
pub fn to_ir_subtract(
&self,
builder: &mut IrBuilder,
symbol_table: &SymbolTable,
) -> IrSubtract {
let lhs = self
.lhs
.to_ir_expression(builder, symbol_table)
.expect("Attempt to subtract non-expression");
let rhs = self
.rhs
.to_ir_expression(builder, symbol_table)
.expect("Attempt to subtract non-expression");
IrSubtract::new(lhs, rhs)
}
pub fn to_ir_expression(
&self,
builder: &mut IrBuilder,
symbol_table: &SymbolTable,
) -> IrExpression {
let ir_subtract = self.to_ir_subtract(builder, symbol_table);
let t_var = IrVariable::new_vr(
builder.new_t_var().into(),
builder.current_block().id(),
self.type_info(),
);
let as_rc = Rc::new(RefCell::new(t_var));
let assign = IrAssign::new(as_rc.clone(), IrOperation::Subtract(ir_subtract));
builder
.current_block_mut()
.add_statement(IrStatement::Assign(assign));
IrExpression::Variable(as_rc)
}
}

43
dmc-lib/src/ast/util.rs Normal file
View File

@ -0,0 +1,43 @@
#[macro_export]
macro_rules! handle_diagnostic {
( $result: expr, $diagnostics: expr ) => {
match $result {
Ok(_) => {}
Err(diagnostic) => {
$diagnostics.push(diagnostic);
}
}
};
}
#[macro_export]
macro_rules! handle_diagnostics {
( $result: expr, $diagnostics: expr ) => {
match $result {
Ok(_) => {}
Err(mut result_diagnostics) => {
$diagnostics.append(&mut result_diagnostics);
}
}
};
}
#[macro_export]
macro_rules! maybe_return_diagnostics {
( $diagnostics: expr ) => {
if !$diagnostics.is_empty() {
return Err($diagnostics);
}
};
}
#[macro_export]
macro_rules! diagnostics_result {
( $diagnostics: expr ) => {
if $diagnostics.is_empty() {
Ok(())
} else {
Err($diagnostics)
}
};
}

View File

@ -1,5 +1,6 @@
pub type ErrorCode = usize;
pub const BINARY_INCOMPATIBLE_TYPES: ErrorCode = 15;
pub const ASSIGN_MISMATCHED_TYPES: ErrorCode = 16;
pub const ASSIGN_NO_L_VALUE: ErrorCode = 17;
pub const ASSIGN_LHS_IMMUTABLE: ErrorCode = 18;

View File

@ -1,5 +1,5 @@
use crate::ast::add_expression::AddExpression;
use crate::ast::assign_statement::AssignStatement;
use crate::ast::binary_expression::{BinaryExpression, BinaryOperation};
use crate::ast::call::Call;
use crate::ast::class::Class;
use crate::ast::compilation_unit::CompilationUnit;
@ -17,7 +17,6 @@ use crate::ast::negative_expression::NegativeExpression;
use crate::ast::parameter::Parameter;
use crate::ast::statement::Statement;
use crate::ast::string_literal::StringLiteral;
use crate::ast::subtract_expression::SubtractExpression;
use crate::ast::type_use::TypeUse;
use crate::diagnostic::Diagnostic;
use crate::lexer::Lexer;
@ -676,15 +675,24 @@ impl<'a> Parser<'a> {
let rhs = self.prefix_expression()?;
let source_range =
SourceRange::new(result.source_range().start(), rhs.source_range().end());
result = Expression::Add(AddExpression::new(result, rhs, source_range));
result = Expression::Binary(BinaryExpression::new(
result,
rhs,
BinaryOperation::Add,
source_range,
));
}
TokenKind::Minus => {
self.advance(); // minus
let rhs = self.prefix_expression()?;
let source_range =
SourceRange::new(result.source_range().start(), rhs.source_range().end());
result =
Expression::Subtract(SubtractExpression::new(result, rhs, source_range));
result = Expression::Binary(BinaryExpression::new(
result,
rhs,
BinaryOperation::Subtract,
source_range,
));
}
_ => break,
}
@ -1055,14 +1063,15 @@ mod concrete_tests {
fn add_negative() {
let expression = assert_expression("1 + -1");
match expression {
Expression::Add(add_expression) => {
match add_expression.lhs() {
Expression::Binary(binary_expression) => {
assert!(matches!(binary_expression.op(), BinaryOperation::Add));
match binary_expression.lhs() {
Expression::Integer(integer_literal) => {
assert_eq!(integer_literal.value(), 1);
}
_ => panic!("Expected integer literal"),
}
match add_expression.rhs() {
match binary_expression.rhs() {
Expression::Negative(negative_expression) => {
match negative_expression.operand() {
Expression::Integer(integer_literal) => {
@ -1082,14 +1091,15 @@ mod concrete_tests {
fn simple_subtract() {
let expression = assert_expression("1 - 1");
match expression {
Expression::Subtract(subtract_expression) => {
match subtract_expression.lhs() {
Expression::Binary(binary_expression) => {
assert!(matches!(binary_expression.op(), BinaryOperation::Subtract));
match binary_expression.lhs() {
Expression::Integer(integer_literal) => {
assert_eq!(integer_literal.value(), 1);
}
_ => panic!("Expected integer literal"),
}
match subtract_expression.rhs() {
match binary_expression.rhs() {
Expression::Integer(integer_literal) => {
assert_eq!(integer_literal.value(), 1);
}