diff --git a/dmc-lib/src/ast/additive_expression.rs b/dmc-lib/src/ast/additive_expression.rs index 856933c..6e2773c 100644 --- a/dmc-lib/src/ast/additive_expression.rs +++ b/dmc-lib/src/ast/additive_expression.rs @@ -21,6 +21,14 @@ impl AdditiveExpression { } } + pub fn lhs(&self) -> &Expression { + &self.lhs + } + + pub fn rhs(&self) -> &Expression { + &self.rhs + } + pub fn gather_declared_names( &mut self, symbol_table: &mut SymbolTable, diff --git a/dmc-lib/src/ast/expression.rs b/dmc-lib/src/ast/expression.rs index 3ad63da..c8ee0ff 100644 --- a/dmc-lib/src/ast/expression.rs +++ b/dmc-lib/src/ast/expression.rs @@ -3,7 +3,9 @@ use crate::ast::call::Call; use crate::ast::identifier::Identifier; use crate::ast::integer_literal::IntegerLiteral; use crate::ast::ir_builder::IrBuilder; +use crate::ast::negative_expression::NegativeExpression; use crate::ast::string_literal::StringLiteral; +use crate::ast::subtract_expression::SubtractExpression; use crate::diagnostic::Diagnostic; use crate::ir::ir_assign::IrAssign; use crate::ir::ir_expression::IrExpression; @@ -22,6 +24,8 @@ pub enum Expression { String(StringLiteral), Identifier(Identifier), Additive(AdditiveExpression), + Subtract(SubtractExpression), + Negative(NegativeExpression), } impl Expression { @@ -79,6 +83,8 @@ impl Expression { Expression::String(_) => TypeInfo::String, Expression::Identifier(identifier) => identifier.type_info(), Expression::Additive(additive_expression) => additive_expression.type_info(), + Expression::Subtract(subtract_expression) => todo!(), + Expression::Negative(_) => todo!(), } } @@ -89,6 +95,8 @@ impl Expression { Expression::String(string_literal) => string_literal.source_range(), Expression::Identifier(identifier) => identifier.source_range(), Expression::Additive(additive_expression) => additive_expression.source_range(), + Expression::Subtract(subtract_expression) => subtract_expression.source_range(), + Expression::Negative(negative_expression) => negative_expression.source_range(), } } @@ -143,6 +151,8 @@ impl Expression { .add_statement(IrStatement::Assign(assign)); Some(IrExpression::Variable(as_rc)) } + Expression::Subtract(subtract_expression) => todo!(), + Expression::Negative(_) => todo!(), } } } diff --git a/dmc-lib/src/ast/let_statement.rs b/dmc-lib/src/ast/let_statement.rs index b326e02..9e79af9 100644 --- a/dmc-lib/src/ast/let_statement.rs +++ b/dmc-lib/src/ast/let_statement.rs @@ -101,6 +101,8 @@ impl LetStatement { Expression::Additive(additive_expression) => { IrOperation::Add(additive_expression.to_ir(builder, symbol_table)) } + Expression::Subtract(subtract_expression) => todo!(), + Expression::Negative(_) => todo!(), }; let destination_symbol = diff --git a/dmc-lib/src/ast/mod.rs b/dmc-lib/src/ast/mod.rs index 09bb6d8..0f0d6c3 100644 --- a/dmc-lib/src/ast/mod.rs +++ b/dmc-lib/src/ast/mod.rs @@ -11,7 +11,9 @@ pub mod integer_literal; pub mod ir_builder; pub mod let_statement; pub mod module_level_declaration; +pub mod negative_expression; pub mod parameter; pub mod statement; pub mod string_literal; +pub mod subtract_expression; pub mod type_use; diff --git a/dmc-lib/src/ast/negative_expression.rs b/dmc-lib/src/ast/negative_expression.rs new file mode 100644 index 0000000..bc96396 --- /dev/null +++ b/dmc-lib/src/ast/negative_expression.rs @@ -0,0 +1,28 @@ +use crate::ast::expression::Expression; +use crate::source_range::SourceRange; + +pub struct NegativeExpression { + operand: Box, + source_range: SourceRange, +} + +impl NegativeExpression { + pub fn new(operand: Expression, source_range: SourceRange) -> Self { + Self { + operand: operand.into(), + source_range, + } + } + + pub fn source_range(&self) -> &SourceRange { + &self.source_range + } + + pub fn operand(&self) -> &Expression { + &self.operand + } + + pub fn operand_mut(&mut self) -> &mut Expression { + &mut self.operand + } +} diff --git a/dmc-lib/src/ast/subtract_expression.rs b/dmc-lib/src/ast/subtract_expression.rs new file mode 100644 index 0000000..06a06c3 --- /dev/null +++ b/dmc-lib/src/ast/subtract_expression.rs @@ -0,0 +1,30 @@ +use crate::ast::expression::Expression; +use crate::source_range::SourceRange; + +pub struct SubtractExpression { + lhs: Box, + rhs: Box, + source_range: SourceRange, +} + +impl SubtractExpression { + pub fn new(lhs: Expression, rhs: Expression, source_range: SourceRange) -> Self { + Self { + lhs: lhs.into(), + rhs: rhs.into(), + source_range, + } + } + + pub fn lhs(&self) -> &Expression { + &self.lhs + } + + pub fn rhs(&self) -> &Expression { + &self.rhs + } + + pub fn source_range(&self) -> &SourceRange { + &self.source_range + } +} diff --git a/dmc-lib/src/lexer.rs b/dmc-lib/src/lexer.rs index 7acaa43..53df915 100644 --- a/dmc-lib/src/lexer.rs +++ b/dmc-lib/src/lexer.rs @@ -36,6 +36,8 @@ impl<'a> Lexer<'a> { let token = if chunk.starts_with("->") { Token::new(self.position, self.position + 2, TokenKind::RightArrow) + } else if chunk.starts_with("-") { + Token::new(self.position, self.position + 1, TokenKind::Minus) } else if chunk.starts_with("(") { Token::new(self.position, self.position + 1, TokenKind::LeftParentheses) } else if chunk.starts_with(")") { diff --git a/dmc-lib/src/parser.rs b/dmc-lib/src/parser.rs index 42108fd..cb6cedd 100644 --- a/dmc-lib/src/parser.rs +++ b/dmc-lib/src/parser.rs @@ -9,9 +9,11 @@ use crate::ast::identifier::Identifier; use crate::ast::integer_literal::IntegerLiteral; use crate::ast::let_statement::LetStatement; use crate::ast::module_level_declaration::ModuleLevelDeclaration; +use crate::ast::negative_expression::NegativeExpression; use crate::ast::parameter::Parameter; use crate::ast::statement::Statement; use crate::ast::string_literal::StringLiteral; +use crate::ast::subtract_expression::SubtractExpression; use crate::ast::type_use::TypeUse; use crate::diagnostic::Diagnostic; use crate::lexer::Lexer; @@ -24,6 +26,12 @@ pub fn parse_compilation_unit(input: &str) -> Result Result> { + let mut parser = Parser::new(input); + parser.advance(); // get started + parser.expression() +} + struct Parser<'a> { input: &'a str, lexer: Lexer<'a>, @@ -391,24 +399,63 @@ impl<'a> Parser<'a> { } fn additive_expression(&mut self) -> Result> { - let mut result = self.suffix_expression()?; + let mut result = self.prefix_expression()?; while self.current.is_some() { let current = self.get_current(); match current.kind() { TokenKind::Plus => { self.advance(); // plus - let rhs = self.suffix_expression()?; + let rhs = self.prefix_expression()?; let source_range = SourceRange::new(result.source_range().start(), rhs.source_range().end()); result = Expression::Additive(AdditiveExpression::new(result, rhs, source_range)); } + TokenKind::Minus => { + self.advance(); // minus + let rhs = self.prefix_expression()?; + let source_range = + SourceRange::new(result.source_range().start(), rhs.source_range().end()); + result = + Expression::Subtract(SubtractExpression::new(result, rhs, source_range)); + } _ => break, } } Ok(result) } + fn prefix_expression(&mut self) -> Result> { + // first, collect all consecutive operators + let mut operator_tokens = vec![]; + while self.current.is_some() { + let current = self.get_current(); + match current.kind() { + TokenKind::Minus => { + operator_tokens.push(current.clone()); // unfortunately necessary + self.advance(); + } + _ => break, + } + } + + // now go in reverse and build up expressions + // the parser is currently just after the prefix operators, so we need a suffix expression + // as a base + let mut result = self.suffix_expression()?; + while let Some(operator_token) = operator_tokens.pop() { + let source_range = + SourceRange::new(operator_token.start(), result.source_range().end()); + match operator_token.kind() { + TokenKind::Minus => { + result = Expression::Negative(NegativeExpression::new(result, source_range)); + } + _ => unreachable!(), + } + } + Ok(result) + } + fn suffix_expression(&mut self) -> Result> { let mut result = self.expression_base()?; while self.current.is_some() { @@ -453,7 +500,7 @@ impl<'a> Parser<'a> { source_range, ))) } - _ => unreachable!(), + _ => unreachable!("Unreachable token type found: {:?}", current.kind()), } } @@ -539,24 +586,90 @@ mod smoke_tests { fn add_two_numbers() { smoke_test("fn main() 1 + 2 end"); } + + #[test] + fn negative_return() { + smoke_test("fn main() -> Int -1 end"); + } + + #[test] + fn negative_left_add() { + smoke_test("fn main() -> Int -1 + 1 end"); + } + + #[test] + fn negative_right_add() { + smoke_test("fn main() -> Int 1 + -1 end"); + } + + #[test] + fn two_negatives() { + smoke_test("fn main() -> Int -1 + -1 end"); + } + + #[test] + fn minus_positive_number() { + smoke_test("fn main() -> Int 1 - 1 end"); + } + + #[test] + fn minus_negative_number() { + smoke_test("fn main() -> Int 1 - -1 end"); + } } #[cfg(test)] mod concrete_tests { use super::*; - #[test] - fn parses_extern_fn() { - let parse_result = parse_compilation_unit("extern fn println() -> Void"); - let compilation_unit = match parse_result { + fn report_diagnostics(diagnostics: &[Diagnostic]) -> ! { + for diagnostic in diagnostics { + eprintln!("{:?}", diagnostic); + } + panic!(); + } + + fn assert_compilation_unit(input: &str) -> CompilationUnit { + let parse_result = parse_compilation_unit(input); + match parse_result { Ok(compilation_unit) => compilation_unit, Err(diagnostics) => { - for diagnostic in diagnostics { - eprintln!("{:?}", diagnostic); - } - panic!(); + report_diagnostics(&diagnostics); } - }; + } + } + + fn assert_expression(input: &str) -> Expression { + let parse_result = parse_expression(input); + match parse_result { + Ok(expression) => expression, + Err(diagnostics) => { + report_diagnostics(&diagnostics); + } + } + } + + fn assert_function_in<'a>( + compilation_unit: &'a CompilationUnit, + function_name: &str, + ) -> &'a Function { + let declarations = compilation_unit.declarations(); + for declaration in declarations { + match declaration { + ModuleLevelDeclaration::Function(function) => { + if function.declared_name() == function_name { + return function; + } + } + _ => {} + } + } + panic!("Function {} not found", function_name) + } + + #[test] + fn parses_extern_fn() { + let compilation_unit = assert_compilation_unit("extern fn println() -> Void"); let declarations = compilation_unit.declarations(); assert_eq!(declarations.len(), 1); let extern_function = match &declarations[0] { @@ -568,23 +681,8 @@ mod concrete_tests { #[test] fn hello_world() { - let parse_result = parse_compilation_unit("fn main() println(\"Hello, World!\") end"); - let compilation_unit = match parse_result { - Ok(compilation_unit) => compilation_unit, - Err(diagnostics) => { - for diagnostic in &diagnostics { - eprintln!("{:?}", diagnostic) - } - panic!() - } - }; - let declarations = compilation_unit.declarations(); - assert_eq!(declarations.len(), 1); - let function = match &declarations[0] { - ModuleLevelDeclaration::Function(function) => function, - _ => panic!(), - }; - assert_eq!(function.declared_name(), "main"); + let compilation_unit = assert_compilation_unit("fn main() println(\"Hello, World!\") end"); + let function = assert_function_in(&compilation_unit, "main"); let statements = function.statements(); assert_eq!(statements.len(), 1); if let Statement::Expression(expression_statement) = statements[0] { @@ -612,6 +710,69 @@ mod concrete_tests { panic!("Expected expression"); } } + + #[test] + fn negative_expression() { + let expression = assert_expression("-1"); + match expression { + Expression::Negative(negative_expression) => match negative_expression.operand() { + Expression::IntegerLiteral(integer_literal) => { + assert_eq!(integer_literal.value(), 1); + } + _ => panic!("Expected integer literal"), + }, + _ => panic!("Expected negative expression"), + } + } + + #[test] + fn add_negative() { + let expression = assert_expression("1 + -1"); + match expression { + Expression::Additive(add_expression) => { + match add_expression.lhs() { + Expression::IntegerLiteral(integer_literal) => { + assert_eq!(integer_literal.value(), 1); + } + _ => panic!("Expected integer literal"), + } + match add_expression.rhs() { + Expression::Negative(negative_expression) => { + match negative_expression.operand() { + Expression::IntegerLiteral(integer_literal) => { + assert_eq!(integer_literal.value(), 1); + } + _ => panic!("Expected integer literal"), + } + } + _ => panic!("Expected negative expression"), + } + } + _ => panic!("Expected additive expression"), + } + } + + #[test] + fn simple_subtract() { + let expression = assert_expression("1 - 1"); + match expression { + Expression::Subtract(subtract_expression) => { + match subtract_expression.lhs() { + Expression::IntegerLiteral(integer_literal) => { + assert_eq!(integer_literal.value(), 1); + } + _ => panic!("Expected integer literal"), + } + match subtract_expression.rhs() { + Expression::IntegerLiteral(integer_literal) => { + assert_eq!(integer_literal.value(), 1); + } + _ => panic!("Expected integer literal"), + } + } + _ => panic!("Expected subtract expression"), + } + } } #[cfg(test)] diff --git a/dmc-lib/src/token.rs b/dmc-lib/src/token.rs index 531585e..b96f3f7 100644 --- a/dmc-lib/src/token.rs +++ b/dmc-lib/src/token.rs @@ -40,4 +40,5 @@ pub enum TokenKind { Colon, RightArrow, Plus, + Minus, }