diff --git a/dmc-lib/src/ast/ast_to_ir_util.rs b/dmc-lib/src/ast/ast_to_ir_util.rs index df4f4a3..9c0dfc1 100644 --- a/dmc-lib/src/ast/ast_to_ir_util.rs +++ b/dmc-lib/src/ast/ast_to_ir_util.rs @@ -4,9 +4,10 @@ use crate::ir::ir_assign::IrAssign; use crate::ir::ir_expression::IrExpression; use crate::ir::ir_operation::IrOperation; use crate::ir::ir_statement::IrStatement; -use crate::ir::ir_variable::{IrVariable, IrVirtualRegisterVariable}; +use crate::ir::ir_variable::IrVariable; use crate::symbol_table::SymbolTable; use crate::type_info::TypeInfo; +use std::cell::RefCell; use std::rc::Rc; pub fn expression_to_ir_expression( @@ -23,13 +24,17 @@ pub fn expression_to_ir_expression( .add_statement(IrStatement::Call(ir_call)); None } else { - let t_var = IrVirtualRegisterVariable::new(&builder.new_t_var(), call.type_info()); - let as_rc = Rc::new(t_var); + let t_var = IrVariable::new_vr( + builder.new_t_var().into(), + builder.current_block().id(), + call.type_info(), + ); + let as_rc = Rc::new(RefCell::new(t_var)); let assign = IrAssign::new(as_rc.clone(), IrOperation::Call(ir_call)); builder .current_block_mut() .add_statement(IrStatement::Assign(assign)); - Some(IrExpression::Variable(IrVariable::VirtualRegister(as_rc))) + Some(IrExpression::Variable(as_rc)) } } Expression::IntegerLiteral(integer_literal) => { @@ -44,16 +49,17 @@ pub fn expression_to_ir_expression( } Expression::Additive(additive_expression) => { let ir_add = additive_expression.to_ir(builder, symbol_table); - let t_var = IrVirtualRegisterVariable::new( - &builder.new_t_var(), + let t_var = IrVariable::new_vr( + builder.new_t_var().into(), + builder.current_block().id(), additive_expression.type_info(), ); - let as_rc = Rc::new(t_var); + let as_rc = Rc::new(RefCell::new(t_var)); let assign = IrAssign::new(as_rc.clone(), IrOperation::Add(ir_add)); builder .current_block_mut() .add_statement(IrStatement::Assign(assign)); - Some(IrExpression::Variable(IrVariable::VirtualRegister(as_rc))) + Some(IrExpression::Variable(as_rc)) } } } diff --git a/dmc-lib/src/ast/ir_builder.rs b/dmc-lib/src/ast/ir_builder.rs index 25c592b..da2cfb4 100644 --- a/dmc-lib/src/ast/ir_builder.rs +++ b/dmc-lib/src/ast/ir_builder.rs @@ -1,7 +1,6 @@ use crate::ir::ir_block::IrBlock; use crate::ir::ir_parameter::IrParameter; use crate::ir::ir_statement::IrStatement; -use crate::ir::ir_variable::IrVariable; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; @@ -53,6 +52,12 @@ impl IrBuilder { .expect("No current block builder") } + pub fn current_block(&self) -> &IrBlockBuilder { + self.current_block_builder + .as_ref() + .expect("No current block") + } + pub fn finish_block(&mut self) { let builder = self .current_block_builder @@ -82,6 +87,10 @@ impl IrBlockBuilder { } } + pub fn id(&self) -> usize { + self.id + } + pub fn add_statement(&mut self, statement: IrStatement) { self.statements.push(statement); } diff --git a/dmc-lib/src/ast/let_statement.rs b/dmc-lib/src/ast/let_statement.rs index 12131b9..270f835 100644 --- a/dmc-lib/src/ast/let_statement.rs +++ b/dmc-lib/src/ast/let_statement.rs @@ -8,10 +8,11 @@ use crate::ir::ir_assign::IrAssign; use crate::ir::ir_expression::IrExpression; use crate::ir::ir_operation::IrOperation; use crate::ir::ir_statement::IrStatement; -use crate::ir::ir_variable::{IrVariable, IrVirtualRegisterVariable}; +use crate::ir::ir_variable::IrVariable; use crate::source_range::SourceRange; use crate::symbol::{ExpressibleSymbol, VariableSymbol}; use crate::symbol_table::{SymbolInsertError, SymbolTable}; +use std::cell::RefCell; use std::rc::Rc; pub struct LetStatement { @@ -107,9 +108,12 @@ impl LetStatement { let destination_symbol = symbol_table.get_variable_symbol(self.scope_id.unwrap(), &self.declared_name); - let destination_vr_variable = - IrVirtualRegisterVariable::new(self.declared_name(), self.initializer().type_info()); - let as_rc = Rc::new(destination_vr_variable); + let destination_vr_variable = IrVariable::new_vr( + self.declared_name().into(), + builder.current_block().id(), + self.initializer.type_info(), + ); + let as_rc = Rc::new(RefCell::new(destination_vr_variable)); let ir_assign = IrAssign::new(as_rc.clone(), init_operation); destination_symbol.borrow_mut().set_vr_variable(as_rc); diff --git a/dmc-lib/src/ir/ir_add.rs b/dmc-lib/src/ir/ir_add.rs index da41704..13695f9 100644 --- a/dmc-lib/src/ir/ir_add.rs +++ b/dmc-lib/src/ir/ir_add.rs @@ -1,8 +1,8 @@ use crate::ir::ir_expression::IrExpression; -use crate::ir::ir_variable::IrVirtualRegisterVariable; +use crate::ir::ir_variable::IrVrVariableDescriptor; +use crate::ir::register_allocation::VrUser; use std::collections::HashSet; use std::fmt::{Display, Formatter}; -use std::rc::Rc; pub struct IrAdd { left: Box, @@ -16,18 +16,6 @@ impl IrAdd { right: right.into(), } } - - pub fn vr_uses(&self) -> HashSet> { - let mut set = HashSet::new(); - set.extend(self.left.vr_uses()); - set.extend(self.right.vr_uses()); - set - } - - pub fn propagate_spills(&mut self, spills: &HashSet>) { - self.left.propagate_spills(spills); - self.right.propagate_spills(spills); - } } impl Display for IrAdd { @@ -35,3 +23,24 @@ impl Display for IrAdd { write!(f, "{} + {}", self.left, self.right) } } + +impl VrUser for IrAdd { + fn vr_definitions(&self) -> HashSet { + [self.left.as_ref(), self.right.as_ref()] + .iter() + .flat_map(|e| e.vr_definitions()) + .collect() + } + + fn vr_uses(&self) -> HashSet { + [self.left.as_ref(), self.right.as_ref()] + .iter() + .flat_map(|e| e.vr_uses()) + .collect() + } + + fn propagate_spills(&mut self, spills: &HashSet) { + self.left.propagate_spills(spills); + self.right.propagate_spills(spills); + } +} diff --git a/dmc-lib/src/ir/ir_assign.rs b/dmc-lib/src/ir/ir_assign.rs index 6062c93..cc7be49 100644 --- a/dmc-lib/src/ir/ir_assign.rs +++ b/dmc-lib/src/ir/ir_assign.rs @@ -1,47 +1,52 @@ use crate::ir::ir_operation::IrOperation; -use crate::ir::ir_variable::{IrStackVariable, IrVariable, IrVirtualRegisterVariable}; +use crate::ir::ir_variable::{IrVariable, IrVariableDescriptor, IrVrVariableDescriptor}; +use crate::ir::register_allocation::VrUser; +use std::cell::RefCell; use std::collections::HashSet; use std::fmt::{Display, Formatter}; -use std::ops::Deref; use std::rc::Rc; pub struct IrAssign { - destination: Box, + destination: Rc>, initializer: Box, } impl IrAssign { - pub fn new(destination: Rc, initializer: IrOperation) -> Self { + pub fn new(destination: Rc>, initializer: IrOperation) -> Self { Self { - destination: Box::new(IrVariable::VirtualRegister(destination)), + destination, initializer: initializer.into(), } } +} - pub fn vr_definitions(&self) -> HashSet> { - match self.destination.deref() { - IrVariable::VirtualRegister(vr_variable) => { - let mut set = HashSet::new(); - set.insert(vr_variable.clone()); - set +impl VrUser for IrAssign { + fn vr_definitions(&self) -> HashSet { + match self.destination.borrow().descriptor() { + IrVariableDescriptor::VirtualRegister(vr_descriptor) => { + HashSet::from([vr_descriptor.clone()]) } - IrVariable::Stack(_) => HashSet::new(), + IrVariableDescriptor::Stack(_) => HashSet::new(), } } - pub fn vr_uses(&self) -> HashSet> { + fn vr_uses(&self) -> HashSet { self.initializer.vr_uses() } - pub fn propagate_spills(&mut self, spills: &HashSet>) { - self.initializer.propagate_spills(spills); - if let IrVariable::VirtualRegister(vr_variable) = self.destination.deref() { + fn propagate_spills(&mut self, spills: &HashSet) { + let borrowed_destination = self.destination.borrow(); + if let IrVariableDescriptor::VirtualRegister(vr_variable) = + borrowed_destination.descriptor() + { if spills.contains(vr_variable) { - println!("changing vr to stack: {}", vr_variable.name()); - self.destination = Box::new(IrVariable::Stack(IrStackVariable::new( - vr_variable.name(), - vr_variable.type_info().clone(), - ))); + let replacement = IrVariable::new_stack( + vr_variable.name().into(), + vr_variable.block_id(), + borrowed_destination.type_info().clone(), + ); + drop(borrowed_destination); + self.destination.replace(replacement); } } } @@ -52,8 +57,8 @@ impl Display for IrAssign { write!( f, "{}: {} = {}", - self.destination, - self.destination.type_info(), + self.destination.borrow(), + self.destination.borrow().type_info(), self.initializer ) } diff --git a/dmc-lib/src/ir/ir_block.rs b/dmc-lib/src/ir/ir_block.rs index e882cae..93e84c5 100644 --- a/dmc-lib/src/ir/ir_block.rs +++ b/dmc-lib/src/ir/ir_block.rs @@ -1,5 +1,6 @@ use crate::ir::ir_statement::IrStatement; -use crate::ir::ir_variable::IrVirtualRegisterVariable; +use crate::ir::ir_variable::IrVrVariableDescriptor; +use crate::ir::register_allocation::{HasVrUsers, VrUser, registers_and_spills}; use std::cell::RefCell; use std::collections::{HashMap, HashSet}; use std::fmt::{Display, Formatter}; @@ -13,10 +14,6 @@ pub struct IrBlock { statements: Vec, } -type LivenessMapByStatement = HashMap>>; -type InterferenceGraph = - HashMap, HashSet>>; - impl IrBlock { pub fn new(id: usize, debug_label: &str, statements: Vec) -> Self { Self { @@ -36,115 +33,15 @@ impl IrBlock { &self.statements } - fn vr_definitions(&self) -> HashSet> { - let mut set = HashSet::new(); - for statement in &self.statements { - set.extend(statement.vr_definitions()); - } - set - } - - fn vr_uses(&self) -> HashSet> { - let mut set = HashSet::new(); - for statement in &self.statements { - set.extend(statement.vr_uses()); - } - set - } - - fn live_in_live_out(&self) -> (LivenessMapByStatement, LivenessMapByStatement) { - let mut live_in: LivenessMapByStatement = HashMap::new(); - let mut live_out: LivenessMapByStatement = HashMap::new(); - - loop { - let mut did_work = false; - for (statement_index, statement) in self.statements.iter().enumerate().rev() { - // init if necessary - if !live_in.contains_key(&statement_index) { - live_in.insert(statement_index, HashSet::new()); - } - if !live_out.contains_key(&statement_index) { - live_out.insert(statement_index, HashSet::new()); - } - - // out (union of successors ins) - // for now, a statement can only have one successor - // this will need to be updated when we add jumps - if let Some(successor_live_in) = live_in.get(&(statement_index + 1)) { - let statement_live_out = live_out.get_mut(&statement_index).unwrap(); - for vr_variable in successor_live_in { - if statement_live_out.insert(vr_variable.clone()) { - did_work = true; - } - } - } - - // in: use(s) U ( out(s) - def(s) ) - let mut new_ins = statement.vr_uses(); - let statement_live_out = live_out.get(&statement_index).unwrap(); - let defs = statement.vr_definitions(); - let rhs = statement_live_out - &defs; - new_ins.extend(rhs); - - let statement_live_in = live_in.get_mut(&statement_index).unwrap(); - for new_in in new_ins { - if statement_live_in.insert(new_in) { - did_work = true; - } - } - } - if !did_work { - break; - } - } - (live_in, live_out) - } - - fn interference_graph(&self) -> InterferenceGraph { - let mut all_vr_variables: HashSet> = HashSet::new(); - for statement in &self.statements { - all_vr_variables.extend(statement.vr_definitions()); - all_vr_variables.extend(statement.vr_uses()); - } - - let mut graph: InterferenceGraph = HashMap::new(); - for variable in all_vr_variables { - graph.insert(variable, HashSet::new()); - } - - let (_, live_out) = self.live_in_live_out(); - - for (statement_index, statement) in self.statements.iter().enumerate() { - let statement_live_out = live_out.get(&statement_index).unwrap(); - for definition_vr_variable in statement.vr_definitions() { - for live_out_variable in statement_live_out { - // we do the following check to avoid adding an edge to itself - if definition_vr_variable != *live_out_variable { - graph - .get_mut(&definition_vr_variable) - .unwrap() - .insert(live_out_variable.clone()); - graph - .get_mut(live_out_variable) - .unwrap() - .insert(definition_vr_variable.clone()); - } - } - } - } - - graph - } - pub fn register_assignments( &mut self, register_count: usize, - ) -> HashMap, usize> { - let mut spills: HashSet> = HashSet::new(); + ) -> HashMap { + let mut spills: HashSet = HashSet::new(); loop { let mut interference_graph = self.interference_graph(); let (registers, new_spills) = - register_assignment::registers_and_spills(&mut interference_graph, register_count); + registers_and_spills(&mut interference_graph, register_count); if spills != new_spills { spills = new_spills; @@ -159,6 +56,31 @@ impl IrBlock { } } +impl HasVrUsers for IrBlock { + fn vr_users(&self) -> Vec<&dyn VrUser> { + self.statements.iter().map(|s| s as &dyn VrUser).collect() + } +} + +impl VrUser for IrBlock { + fn vr_definitions(&self) -> HashSet { + self.statements + .iter() + .flat_map(|s| s.vr_definitions()) + .collect() + } + + fn vr_uses(&self) -> HashSet { + self.statements.iter().flat_map(|s| s.vr_uses()).collect() + } + + fn propagate_spills(&mut self, spills: &HashSet) { + for statement in &mut self.statements { + statement.propagate_spills(spills); + } + } +} + impl Display for IrBlock { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { writeln!(f, " {}:", self.debug_label)?; @@ -226,306 +148,3 @@ mod tests { } } } - -mod register_assignment { - use super::*; - - #[derive(Debug)] - struct WorkItem { - vr: Rc, - edges: HashSet>, - color: bool, - } - - pub fn registers_and_spills( - interference_graph: &mut InterferenceGraph, - k: usize, - ) -> ( - HashMap, usize>, - HashSet>, - ) { - let mut work_stack: Vec = vec![]; - - while !interference_graph.is_empty() { - let next = next_work_item(interference_graph, k); - work_stack.push(next); - } - - // 3. assign colors to registers - let mut rebuilt_graph: InterferenceGraph = HashMap::new(); - let mut register_assignments: HashMap, usize> = - HashMap::new(); - let mut spills: HashSet> = HashSet::new(); - - while let Some(work_item) = work_stack.pop() { - if work_item.color { - assign_register(&work_item, &mut rebuilt_graph, k, &mut register_assignments); - } else if can_optimistically_color(&work_item, &mut register_assignments, k) { - assign_register(&work_item, &mut rebuilt_graph, k, &mut register_assignments); - } else { - // spill - spills.insert(work_item.vr.clone()); - } - } - - (register_assignments, spills) - } - - fn assign_register( - work_item: &WorkItem, - graph: &mut InterferenceGraph, - k: usize, - register_assignments: &mut HashMap, usize>, - ) { - rebuild_vr_and_edges(graph, work_item); - - // find a register which is not yet shared by all outgoing edges' vertices - 'outer: for i in 0..k { - for edge in graph.get_mut(&work_item.vr).unwrap().iter() { - if register_assignments.contains_key(edge) { - let assignment = register_assignments.get(edge).unwrap(); - if *assignment == i { - continue 'outer; - } - } - } - register_assignments.insert(work_item.vr.clone(), i); - break; - } - } - - fn find_vr_lt_k( - interference_graph: &InterferenceGraph, - k: usize, - ) -> Option<&Rc> { - interference_graph.iter().find_map( - |(vr, neighbors)| { - if neighbors.len() < k { Some(vr) } else { None } - }, - ) - } - - fn remove_vr_and_edges( - interference_graph: &mut InterferenceGraph, - vr: &Rc, - ) -> HashSet> { - // first, outgoing - let outgoing_edges = interference_graph.remove(vr).unwrap(); - - // second, incoming - for neighbor in &outgoing_edges { - let neighbor_edges = interference_graph.get_mut(neighbor).unwrap(); - neighbor_edges.remove(vr); - } - - outgoing_edges - } - - fn next_work_item(interference_graph: &mut InterferenceGraph, k: usize) -> WorkItem { - // try to find a node (virtual register) with less than k outgoing edges, and mark as color - // for step 3. - // if not, pick any, and mark as spill for step 3. - let register_lt_k = find_vr_lt_k(interference_graph, k); - if let Some(vr) = register_lt_k { - let vr = vr.clone(); - - // remove edges; save outgoing to work_item - let edges = remove_vr_and_edges(interference_graph, &vr); - - // push to work stack - WorkItem { - vr, - edges, - color: true, - } - } else { - // pick any - let vr = interference_graph.iter().last().unwrap().0.clone(); - - // remove edges - let edges = remove_vr_and_edges(interference_graph, &vr); - - WorkItem { - vr, - edges, - color: false, // spill - } - } - } - - fn rebuild_vr_and_edges(graph: &mut InterferenceGraph, work_item: &WorkItem) { - // init the vertex - graph.insert(work_item.vr.clone(), HashSet::new()); - - // outgoing - for neighbor in &work_item.edges { - // check if neighbor exists in the graph first; if it was marked spill earlier and could - // not optimistically color, it won't be in the graph - if graph.contains_key(neighbor) { - // get outgoing set and insert neighbor - graph - .get_mut(&work_item.vr) - .unwrap() - .insert(neighbor.clone()); - } - } - - // incoming - for neighbor in &work_item.edges { - // like above, neighbor may not have been added because of failure to optimistically - // color - if graph.contains_key(neighbor) { - graph - .get_mut(neighbor) - .unwrap() - .insert(work_item.vr.clone()); - } - } - } - - fn can_optimistically_color( - work_item: &WorkItem, - register_assignments: &HashMap, usize>, - k: usize, - ) -> bool { - // see if we can optimistically color - // find how many assignments have been made for the outgoing edges - // if it's less than k, we can do it - let mut number_of_assigned_edges = 0; - for edge in &work_item.edges { - if register_assignments.contains_key(edge) { - number_of_assigned_edges += 1; - } - } - number_of_assigned_edges < k - } - - #[cfg(test)] - mod tests { - use super::*; - use crate::type_info::TypeInfo; - - fn line_graph() -> InterferenceGraph { - let mut graph: InterferenceGraph = HashMap::new(); - let v0 = Rc::new(IrVirtualRegisterVariable::new("v0", TypeInfo::Integer)); - let v1 = Rc::new(IrVirtualRegisterVariable::new("v1", TypeInfo::Integer)); - let v2 = Rc::new(IrVirtualRegisterVariable::new("v2", TypeInfo::Integer)); - - // v1 -- v0 -- v2 - graph.insert(v0.clone(), HashSet::from([v1.clone(), v2.clone()])); - graph.insert(v1.clone(), HashSet::from([v0.clone()])); - graph.insert(v2.clone(), HashSet::from([v0.clone()])); - graph - } - - fn triangle_graph() -> InterferenceGraph { - let mut graph: InterferenceGraph = HashMap::new(); - let v0 = Rc::new(IrVirtualRegisterVariable::new("v0", TypeInfo::Integer)); - let v1 = Rc::new(IrVirtualRegisterVariable::new("v1", TypeInfo::Integer)); - let v2 = Rc::new(IrVirtualRegisterVariable::new("v2", TypeInfo::Integer)); - - // triangle: each has two edges - // v0 - // | \ - // v1--v2 - graph.insert(v0.clone(), HashSet::from([v1.clone(), v2.clone()])); - graph.insert(v1.clone(), HashSet::from([v0.clone(), v2.clone()])); - graph.insert(v2.clone(), HashSet::from([v0.clone(), v1.clone()])); - graph - } - - fn get_vrs(graph: &InterferenceGraph) -> Vec> { - let v0 = graph.keys().find(|k| k.name() == "v0").unwrap(); - let v1 = graph.keys().find(|k| k.name() == "v1").unwrap(); - let v2 = graph.keys().find(|k| k.name() == "v2").unwrap(); - vec![v0.clone(), v1.clone(), v2.clone()] - } - - #[test] - fn find_vr_lt_k_when_k_2() { - let graph = line_graph(); - let found = find_vr_lt_k(&graph, 2); - assert!(found.is_some()); - assert!(found.unwrap().name() == "v1" || found.unwrap().name() == "v2"); - } - - #[test] - fn find_vr_lt_k_when_k_1() { - let graph = line_graph(); - let found = find_vr_lt_k(&graph, 1); - assert!(found.is_none()); - } - - #[test] - fn remove_edges_v0() { - let mut graph = line_graph(); - let vrs = get_vrs(&graph); - - let v0_outgoing = remove_vr_and_edges(&mut graph, &vrs[0]); - assert!(v0_outgoing.contains(&vrs[1])); - assert!(v0_outgoing.contains(&vrs[2])); - - // check that incoming edges were removed - let v1_outgoing = graph.get(&vrs[1]).unwrap(); - assert!(v1_outgoing.is_empty()); - let v2_outgoing = graph.get(&vrs[2]).unwrap(); - assert!(v2_outgoing.is_empty()); - } - - fn triangle_work_stack_k_2() -> Vec { - let k = 2; - let mut graph = triangle_graph(); - - let mut work_stack = vec![]; - - // run three times, once for each register - work_stack.push(next_work_item(&mut graph, k)); - work_stack.push(next_work_item(&mut graph, k)); - work_stack.push(next_work_item(&mut graph, k)); - - work_stack - } - - #[test] - fn next_work_item_k_2() { - let work_stack = triangle_work_stack_k_2(); - - // the actual edges may be different, depending on the underlying order in the sets - // (HashSet seems to use randomness in order) - // however, the bottommost item must be a spill, and the edge counts must be (from the - // bottom of the stack) 2-1-0 - assert!(!work_stack[0].color); - assert_eq!(work_stack[0].edges.len(), 2); - assert_eq!(work_stack[1].edges.len(), 1); - assert_eq!(work_stack[2].edges.len(), 0); - } - - #[test] - fn rebuild_graph_triangle_k_2() { - let mut work_stack = triangle_work_stack_k_2(); - let mut rebuilt_graph: InterferenceGraph = HashMap::new(); - - // it should be possible to rebuild the graph from the stack, without yet worrying - // about spilling/etc. - while let Some(work_item) = work_stack.pop() { - rebuild_vr_and_edges(&mut rebuilt_graph, &work_item); - } - - // we should have a triangle graph again - let vrs = get_vrs(&rebuilt_graph); - for vr in &vrs { - assert!(rebuilt_graph.contains_key(vr)); - assert_eq!(rebuilt_graph.get(vr).unwrap().len(), 2); - } - } - - #[test] - fn registers_and_spills_triangle_k_2() { - let mut graph = triangle_graph(); - let (registers, spills) = registers_and_spills(&mut graph, 2); - // there should be one spill when k is 2 - assert_eq!(registers.len(), 2); - assert_eq!(spills.len(), 1); - } - } -} diff --git a/dmc-lib/src/ir/ir_call.rs b/dmc-lib/src/ir/ir_call.rs index 6d65149..134d9ae 100644 --- a/dmc-lib/src/ir/ir_call.rs +++ b/dmc-lib/src/ir/ir_call.rs @@ -1,6 +1,6 @@ use crate::ir::ir_expression::IrExpression; -use crate::ir::ir_variable::{IrStackVariable, IrVariable, IrVirtualRegisterVariable}; -use crate::type_info::TypeInfo; +use crate::ir::ir_variable::IrVrVariableDescriptor; +use crate::ir::register_allocation::VrUser; use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::rc::Rc; @@ -17,16 +17,18 @@ impl IrCall { arguments, } } +} - pub fn vr_uses(&self) -> HashSet> { - let mut set = HashSet::new(); - for argument in &self.arguments { - set.extend(argument.vr_uses()) - } - set +impl VrUser for IrCall { + fn vr_definitions(&self) -> HashSet { + HashSet::new() } - pub fn propagate_spills(&mut self, spills: &HashSet>) { + fn vr_uses(&self) -> HashSet { + self.arguments.iter().flat_map(|a| a.vr_uses()).collect() + } + + fn propagate_spills(&mut self, spills: &HashSet) { for argument in &mut self.arguments { argument.propagate_spills(spills); } diff --git a/dmc-lib/src/ir/ir_expression.rs b/dmc-lib/src/ir/ir_expression.rs index 7264ba2..9e955cd 100644 --- a/dmc-lib/src/ir/ir_expression.rs +++ b/dmc-lib/src/ir/ir_expression.rs @@ -1,46 +1,79 @@ use crate::ir::ir_parameter::IrParameter; -use crate::ir::ir_variable::{IrStackVariable, IrVariable, IrVirtualRegisterVariable}; +use crate::ir::ir_variable::{ + IrStackVariableDescriptor, IrVariable, IrVariableDescriptor, IrVrVariableDescriptor, +}; +use crate::ir::register_allocation::VrUser; +use std::cell::RefCell; use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::rc::Rc; pub enum IrExpression { Parameter(Rc), - Variable(IrVariable), + Variable(Rc>), Int(i32), String(Rc), } -impl IrExpression { - pub fn vr_uses(&self) -> HashSet> { +impl Display for IrExpression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + IrExpression::Parameter(ir_parameter) => { + write!(f, "{}", ir_parameter) + } + IrExpression::Variable(ir_variable) => { + write!(f, "{}", ir_variable.borrow()) + } + IrExpression::Int(i) => { + write!(f, "{}", i) + } + IrExpression::String(s) => { + write!(f, "\"{}\"", s) + } + } + } +} + +impl VrUser for IrExpression { + fn vr_definitions(&self) -> HashSet { + HashSet::new() + } + + fn vr_uses(&self) -> HashSet { match self { IrExpression::Parameter(_) => HashSet::new(), IrExpression::Variable(ir_variable) => { - let mut set = HashSet::new(); - if let IrVariable::VirtualRegister(vr_variable) = ir_variable { - set.insert(vr_variable.clone()); + if let IrVariableDescriptor::VirtualRegister(vr_variable) = + ir_variable.borrow().descriptor() + { + HashSet::from([vr_variable.clone()]) + } else { + HashSet::new() } - set } IrExpression::Int(_) => HashSet::new(), IrExpression::String(_) => HashSet::new(), } } - pub fn propagate_spills(&mut self, spills: &HashSet>) { + fn propagate_spills(&mut self, spills: &HashSet) { match self { IrExpression::Parameter(_) => { // no-op } IrExpression::Variable(ir_variable) => { - if let IrVariable::VirtualRegister(vr_variable) = ir_variable { - if spills.contains(vr_variable) { - let name = vr_variable.name().to_string(); - let type_info = vr_variable.type_info().clone(); - let _ = std::mem::replace( - ir_variable, - IrVariable::Stack(IrStackVariable::new(&name, type_info)), - ); + if let IrVariableDescriptor::VirtualRegister(vr_variable) = + ir_variable.borrow().descriptor() + { + if spills.contains(&vr_variable) { + ir_variable + .borrow_mut() + .set_descriptor(IrVariableDescriptor::Stack( + IrStackVariableDescriptor::new( + vr_variable.name().into(), + vr_variable.block_id(), + ), + )); } } } @@ -53,22 +86,3 @@ impl IrExpression { } } } - -impl Display for IrExpression { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - IrExpression::Parameter(ir_parameter) => { - write!(f, "{}", ir_parameter) - } - IrExpression::Variable(ir_variable) => { - write!(f, "{}", ir_variable) - } - IrExpression::Int(i) => { - write!(f, "{}", i) - } - IrExpression::String(s) => { - write!(f, "\"{}\"", s) - } - } - } -} diff --git a/dmc-lib/src/ir/ir_function.rs b/dmc-lib/src/ir/ir_function.rs index 494a84f..7610a15 100644 --- a/dmc-lib/src/ir/ir_function.rs +++ b/dmc-lib/src/ir/ir_function.rs @@ -1,6 +1,6 @@ use crate::ir::ir_block::IrBlock; use crate::ir::ir_parameter::IrParameter; -use crate::ir::ir_variable::IrVirtualRegisterVariable; +use crate::ir::ir_variable::IrVrVariableDescriptor; use crate::type_info::TypeInfo; use std::cell::RefCell; use std::collections::HashMap; @@ -32,7 +32,7 @@ impl IrFunction { pub fn register_assignments( &mut self, register_count: usize, - ) -> HashMap, usize> { + ) -> HashMap { self.entry.borrow_mut().register_assignments(register_count) } } diff --git a/dmc-lib/src/ir/ir_operation.rs b/dmc-lib/src/ir/ir_operation.rs index ad447ec..de1bbcd 100644 --- a/dmc-lib/src/ir/ir_operation.rs +++ b/dmc-lib/src/ir/ir_operation.rs @@ -1,10 +1,10 @@ use crate::ir::ir_add::IrAdd; use crate::ir::ir_call::IrCall; use crate::ir::ir_expression::IrExpression; -use crate::ir::ir_variable::IrVirtualRegisterVariable; +use crate::ir::ir_variable::IrVrVariableDescriptor; +use crate::ir::register_allocation::VrUser; use std::collections::HashSet; use std::fmt::{Display, Formatter}; -use std::rc::Rc; pub enum IrOperation { Load(IrExpression), @@ -28,8 +28,16 @@ impl Display for IrOperation { } } -impl IrOperation { - pub fn vr_uses(&self) -> HashSet> { +impl VrUser for IrOperation { + fn vr_definitions(&self) -> HashSet { + match self { + IrOperation::Load(ir_expression) => ir_expression.vr_definitions(), + IrOperation::Add(ir_add) => ir_add.vr_definitions(), + IrOperation::Call(ir_call) => ir_call.vr_definitions(), + } + } + + fn vr_uses(&self) -> HashSet { match self { IrOperation::Load(ir_expression) => ir_expression.vr_uses(), IrOperation::Add(ir_add) => ir_add.vr_uses(), @@ -37,11 +45,17 @@ impl IrOperation { } } - pub fn propagate_spills(&mut self, spills: &HashSet>) { + fn propagate_spills(&mut self, spills: &HashSet) { match self { - IrOperation::Load(ir_expression) => ir_expression.propagate_spills(spills), - IrOperation::Add(ir_add) => ir_add.propagate_spills(spills), - IrOperation::Call(ir_call) => ir_call.propagate_spills(spills), + IrOperation::Load(ir_expression) => { + ir_expression.propagate_spills(spills); + } + IrOperation::Add(ir_add) => { + ir_add.propagate_spills(spills); + } + IrOperation::Call(ir_call) => { + ir_call.propagate_spills(spills); + } } } } diff --git a/dmc-lib/src/ir/ir_return.rs b/dmc-lib/src/ir/ir_return.rs index 95cf92e..caaa754 100644 --- a/dmc-lib/src/ir/ir_return.rs +++ b/dmc-lib/src/ir/ir_return.rs @@ -1,8 +1,8 @@ use crate::ir::ir_expression::IrExpression; -use crate::ir::ir_variable::IrVirtualRegisterVariable; +use crate::ir::ir_variable::IrVrVariableDescriptor; +use crate::ir::register_allocation::VrUser; use std::collections::HashSet; use std::fmt::{Display, Formatter}; -use std::rc::Rc; pub struct IrReturn { value: Option, @@ -12,12 +12,22 @@ impl IrReturn { pub fn new(value: Option) -> Self { Self { value } } +} - pub fn vr_uses(&self) -> HashSet> { - self.value.as_ref().map_or(HashSet::new(), |v| v.vr_uses()) +impl VrUser for IrReturn { + fn vr_definitions(&self) -> HashSet { + HashSet::new() } - pub fn propagate_spills(&mut self, spills: &HashSet>) { + fn vr_uses(&self) -> HashSet { + if let Some(ir_expression) = self.value.as_ref() { + ir_expression.vr_uses() + } else { + HashSet::new() + } + } + + fn propagate_spills(&mut self, spills: &HashSet) { if let Some(ir_expression) = self.value.as_mut() { ir_expression.propagate_spills(spills); } diff --git a/dmc-lib/src/ir/ir_statement.rs b/dmc-lib/src/ir/ir_statement.rs index 8df7d74..e5d16b8 100644 --- a/dmc-lib/src/ir/ir_statement.rs +++ b/dmc-lib/src/ir/ir_statement.rs @@ -1,10 +1,10 @@ use crate::ir::ir_assign::IrAssign; use crate::ir::ir_call::IrCall; use crate::ir::ir_return::IrReturn; -use crate::ir::ir_variable::IrVirtualRegisterVariable; +use crate::ir::ir_variable::IrVrVariableDescriptor; +use crate::ir::register_allocation::VrUser; use std::collections::HashSet; use std::fmt::{Display, Formatter}; -use std::rc::Rc; pub enum IrStatement { Assign(IrAssign), @@ -12,16 +12,16 @@ pub enum IrStatement { Return(IrReturn), } -impl IrStatement { - pub fn vr_definitions(&self) -> HashSet> { +impl VrUser for IrStatement { + fn vr_definitions(&self) -> HashSet { match self { IrStatement::Assign(ir_assign) => ir_assign.vr_definitions(), - IrStatement::Call(_) => HashSet::new(), - IrStatement::Return(_) => HashSet::new(), + IrStatement::Call(ir_call) => ir_call.vr_definitions(), + IrStatement::Return(ir_return) => ir_return.vr_definitions(), } } - pub fn vr_uses(&self) -> HashSet> { + fn vr_uses(&self) -> HashSet { match self { IrStatement::Assign(ir_assign) => ir_assign.vr_uses(), IrStatement::Call(ir_call) => ir_call.vr_uses(), @@ -29,7 +29,7 @@ impl IrStatement { } } - pub fn propagate_spills(&mut self, spills: &HashSet>) { + fn propagate_spills(&mut self, spills: &HashSet) { match self { IrStatement::Assign(ir_assign) => { ir_assign.propagate_spills(spills); diff --git a/dmc-lib/src/ir/ir_variable.rs b/dmc-lib/src/ir/ir_variable.rs index 7911caa..3ad9753 100644 --- a/dmc-lib/src/ir/ir_variable.rs +++ b/dmc-lib/src/ir/ir_variable.rs @@ -1,49 +1,89 @@ use crate::type_info::TypeInfo; use std::fmt::{Debug, Display, Formatter}; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::rc::Rc; -pub enum IrVariable { - VirtualRegister(Rc), - Stack(IrStackVariable), +pub struct IrVariable { + descriptor: IrVariableDescriptor, + type_info: TypeInfo, } impl IrVariable { - pub fn new(name: &str, type_info: TypeInfo) -> IrVariable { - IrVariable::VirtualRegister(Rc::new(IrVirtualRegisterVariable::new(name, type_info))) + pub fn new_vr(name: Rc, block_id: usize, type_info: TypeInfo) -> Self { + Self { + descriptor: IrVariableDescriptor::VirtualRegister(IrVrVariableDescriptor::new( + name, block_id, + )), + type_info, + } + } + + pub fn new_stack(name: Rc, block_id: usize, type_info: TypeInfo) -> Self { + Self { + descriptor: IrVariableDescriptor::Stack(IrStackVariableDescriptor::new(name, block_id)), + type_info, + } } pub fn type_info(&self) -> &TypeInfo { - match self { - IrVariable::VirtualRegister(vr_variable) => vr_variable.type_info(), - IrVariable::Stack(stack_variable) => stack_variable.type_info(), - } + &self.type_info + } + + pub fn descriptor(&self) -> &IrVariableDescriptor { + &self.descriptor + } + + pub fn set_descriptor(&mut self, descriptor: IrVariableDescriptor) { + self.descriptor = descriptor; } } impl Display for IrVariable { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.descriptor) + } +} + +pub enum IrVariableDescriptor { + VirtualRegister(IrVrVariableDescriptor), + Stack(IrStackVariableDescriptor), +} + +impl Display for IrVariableDescriptor { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - IrVariable::VirtualRegister(vr_variable) => { + IrVariableDescriptor::VirtualRegister(vr_variable) => { write!(f, "{}", vr_variable) } - IrVariable::Stack(stack_variable) => { + IrVariableDescriptor::Stack(stack_variable) => { write!(f, "{}", stack_variable) } } } } -pub struct IrVirtualRegisterVariable { - name: Rc, - type_info: TypeInfo, +impl IrVariableDescriptor { + pub fn name(&self) -> &str { + match self { + IrVariableDescriptor::VirtualRegister(vr_variable) => vr_variable.name(), + IrVariableDescriptor::Stack(stack_variable) => stack_variable.name(), + } + } } -impl IrVirtualRegisterVariable { - pub fn new(name: &str, type_info: TypeInfo) -> Self { +#[derive(Clone, Hash, PartialEq, Eq)] +pub struct IrVrVariableDescriptor { + name: Rc, + block_id: usize, + assigned_register: Option, +} + +impl IrVrVariableDescriptor { + pub fn new(name: Rc, block_id: usize) -> Self { Self { - name: name.into(), - type_info, + name, + block_id, + assigned_register: None, } } @@ -51,59 +91,45 @@ impl IrVirtualRegisterVariable { &self.name } - pub fn type_info(&self) -> &TypeInfo { - &self.type_info + pub fn block_id(&self) -> usize { + self.block_id } } -impl Display for IrVirtualRegisterVariable { +impl Display for IrVrVariableDescriptor { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.name) } } -impl Debug for IrVirtualRegisterVariable { +impl Debug for IrVrVariableDescriptor { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.name) } } -impl Eq for IrVirtualRegisterVariable {} - -impl PartialEq for IrVirtualRegisterVariable { - fn eq(&self, other: &Self) -> bool { - self.name == other.name - } -} - -impl Hash for IrVirtualRegisterVariable { - fn hash(&self, state: &mut H) { - self.name.hash(state); - } -} - -pub struct IrStackVariable { +pub struct IrStackVariableDescriptor { name: Rc, - type_info: TypeInfo, - offset: Option, + block_id: usize, + offset: Option, } -impl IrStackVariable { - pub fn new(name: &str, type_info: TypeInfo) -> Self { +impl IrStackVariableDescriptor { + pub fn new(name: Rc, block_id: usize) -> Self { Self { - name: name.into(), - type_info, + name, + block_id, offset: None, } } - pub fn type_info(&self) -> &TypeInfo { - &self.type_info + pub fn name(&self) -> &str { + &self.name } } -impl Display for IrStackVariable { +impl Display for IrStackVariableDescriptor { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.name) + write!(f, "{}_b{}", self.name, self.block_id) } } diff --git a/dmc-lib/src/ir/mod.rs b/dmc-lib/src/ir/mod.rs index af5c0c4..69806eb 100644 --- a/dmc-lib/src/ir/mod.rs +++ b/dmc-lib/src/ir/mod.rs @@ -9,3 +9,4 @@ pub mod ir_parameter; pub mod ir_return; pub mod ir_statement; pub mod ir_variable; +mod register_allocation; diff --git a/dmc-lib/src/ir/register_allocation.rs b/dmc-lib/src/ir/register_allocation.rs new file mode 100644 index 0000000..cae0c9e --- /dev/null +++ b/dmc-lib/src/ir/register_allocation.rs @@ -0,0 +1,404 @@ +use crate::ir::ir_variable::IrVrVariableDescriptor; +use std::collections::{HashMap, HashSet}; + +pub type InterferenceGraph = HashMap>; +pub type LivenessMap = HashMap>; + +pub trait HasVrUsers { + fn vr_users(&self) -> Vec<&dyn VrUser>; + + fn live_in_live_out(&self) -> (LivenessMap, LivenessMap) { + let mut live_in: LivenessMap = HashMap::new(); + let mut live_out: LivenessMap = HashMap::new(); + + loop { + let mut did_work = false; + for (block_index, statement) in self.vr_users().iter().enumerate().rev() { + // init if necessary + if !live_in.contains_key(&block_index) { + live_in.insert(block_index, HashSet::new()); + } + if !live_out.contains_key(&block_index) { + live_out.insert(block_index, HashSet::new()); + } + + // out (union of successors ins) + // for now, a statement can only have one successor + // this will need to be updated when we add jumps + if let Some(successor_live_in) = live_in.get(&(block_index + 1)) { + let statement_live_out = live_out.get_mut(&block_index).unwrap(); + for vr_variable in successor_live_in { + if statement_live_out.insert(vr_variable.clone()) { + did_work = true; + } + } + } + + // in: use(s) U ( out(s) - def(s) ) + let mut new_ins = statement + .vr_uses() + .iter() + .map(|u| (*u).clone()) + .collect::>(); + let statement_live_out = live_out.get(&block_index).unwrap(); + let defs = statement + .vr_definitions() + .iter() + .map(|d| (*d).clone()) + .collect::>(); + let rhs = statement_live_out - &defs; + new_ins.extend(rhs); + + let statement_live_in = live_in.get_mut(&block_index).unwrap(); + for new_in in new_ins { + if statement_live_in.insert(new_in) { + did_work = true; + } + } + } + if !did_work { + break; + } + } + (live_in, live_out) + } + + fn interference_graph(&self) -> InterferenceGraph { + let mut all_vr_variables: HashSet = HashSet::new(); + + for vr_user in self.vr_users() { + all_vr_variables.extend(vr_user.vr_definitions()); + all_vr_variables.extend(vr_user.vr_uses()); + } + + let mut graph: InterferenceGraph = HashMap::new(); + for variable in all_vr_variables { + graph.insert(variable, HashSet::new()); + } + + let (_, live_out) = self.live_in_live_out(); + + for (index, vr_user) in self.vr_users().iter().enumerate() { + let user_live_in = live_out.get(&index).unwrap(); + for definition_vr_variable in vr_user.vr_definitions() { + for live_out_variable in user_live_in { + // we do the following check to avoid adding an edge to itself + if definition_vr_variable.ne(live_out_variable) { + graph + .get_mut(&definition_vr_variable) + .unwrap() + .insert(live_out_variable.clone()); + graph + .get_mut(live_out_variable) + .unwrap() + .insert(definition_vr_variable.clone()); + } + } + } + } + + graph + } +} + +pub trait VrUser { + fn vr_definitions(&self) -> HashSet; + fn vr_uses(&self) -> HashSet; + fn propagate_spills(&mut self, spills: &HashSet); +} + +#[derive(Debug)] +struct WorkItem { + vr: IrVrVariableDescriptor, + edges: HashSet, + color: bool, +} + +pub fn registers_and_spills( + interference_graph: &mut InterferenceGraph, + k: usize, +) -> ( + HashMap, + HashSet, +) { + let mut work_stack: Vec = vec![]; + + while !interference_graph.is_empty() { + work_stack.push(next_work_item(interference_graph, k)); + } + + // 3. assign colors to registers + let mut rebuilt_graph: InterferenceGraph = HashMap::new(); + let mut register_assignments: HashMap = HashMap::new(); + let mut spills: HashSet = HashSet::new(); + + while let Some(work_item) = work_stack.pop() { + if work_item.color { + assign_register(&work_item, &mut rebuilt_graph, k, &mut register_assignments); + } else if can_optimistically_color(&work_item, &mut register_assignments, k) { + assign_register(&work_item, &mut rebuilt_graph, k, &mut register_assignments); + } else { + // spill + spills.insert(work_item.vr.clone()); + } + } + + (register_assignments, spills) +} + +fn assign_register( + work_item: &WorkItem, + graph: &mut InterferenceGraph, + k: usize, + register_assignments: &mut HashMap, +) { + rebuild_vr_and_edges(graph, work_item); + + // find a register which is not yet shared by all outgoing edges' vertices + 'outer: for i in 0..k { + for edge in graph.get_mut(&work_item.vr).unwrap().iter() { + if register_assignments.contains_key(edge) { + let assignment = register_assignments.get(edge).unwrap(); + if *assignment == i { + continue 'outer; + } + } + } + register_assignments.insert(work_item.vr.clone(), i); + break; + } +} + +fn find_vr_lt_k( + interference_graph: &InterferenceGraph, + k: usize, +) -> Option<&IrVrVariableDescriptor> { + interference_graph.iter().find_map( + |(vr, neighbors)| { + if neighbors.len() < k { Some(vr) } else { None } + }, + ) +} + +fn remove_vr_and_edges( + interference_graph: &mut InterferenceGraph, + vr: &IrVrVariableDescriptor, +) -> HashSet { + // first, outgoing + let outgoing_edges = interference_graph.remove(vr).unwrap(); + + // second, incoming + for neighbor in &outgoing_edges { + let neighbor_edges = interference_graph.get_mut(neighbor).unwrap(); + neighbor_edges.remove(vr); + } + + outgoing_edges +} + +fn next_work_item(interference_graph: &mut InterferenceGraph, k: usize) -> WorkItem { + // try to find a node (virtual register) with less than k outgoing edges, and mark as color + // for step 3. + // if not, pick any, and mark as spill for step 3. + let register_lt_k = find_vr_lt_k(interference_graph, k); + if let Some(vr) = register_lt_k { + let vr = vr.clone(); + + // remove edges; save outgoing to work_item + let edges = remove_vr_and_edges(interference_graph, &vr); + + // push to work stack + WorkItem { + vr, + edges, + color: true, + } + } else { + // pick any + let vr = interference_graph.iter().last().unwrap().0.clone(); + + // remove edges + let edges = remove_vr_and_edges(interference_graph, &vr); + + WorkItem { + vr, + edges, + color: false, // spill + } + } +} + +fn rebuild_vr_and_edges(graph: &mut InterferenceGraph, work_item: &WorkItem) { + // init the vertex + graph.insert(work_item.vr.clone(), HashSet::new()); + + // outgoing + for neighbor in &work_item.edges { + // check if neighbor exists in the graph first; if it was marked spill earlier and could + // not optimistically color, it won't be in the graph + if graph.contains_key(neighbor) { + // get outgoing set and insert neighbor + graph + .get_mut(&work_item.vr) + .unwrap() + .insert(neighbor.clone()); + } + } + + // incoming + for neighbor in &work_item.edges { + // like above, neighbor may not have been added because of failure to optimistically + // color + if graph.contains_key(neighbor) { + graph + .get_mut(neighbor) + .unwrap() + .insert(work_item.vr.clone()); + } + } +} + +fn can_optimistically_color( + work_item: &WorkItem, + register_assignments: &HashMap, + k: usize, +) -> bool { + // see if we can optimistically color + // find how many assignments have been made for the outgoing edges + // if it's less than k, we can do it + let mut number_of_assigned_edges = 0; + for edge in &work_item.edges { + if register_assignments.contains_key(edge) { + number_of_assigned_edges += 1; + } + } + number_of_assigned_edges < k +} + +#[cfg(test)] +mod tests { + use super::*; + + fn line_graph() -> InterferenceGraph { + let mut graph: InterferenceGraph = HashMap::new(); + let v0 = IrVrVariableDescriptor::new("v0".into(), 0); + let v1 = IrVrVariableDescriptor::new("v1".into(), 0); + let v2 = IrVrVariableDescriptor::new("v2".into(), 0); + + // v1 -- v0 -- v2 + graph.insert(v0.clone(), HashSet::from([v1.clone(), v2.clone()])); + graph.insert(v1.clone(), HashSet::from([v0.clone()])); + graph.insert(v2.clone(), HashSet::from([v0.clone()])); + graph + } + + fn triangle_graph() -> InterferenceGraph { + let mut graph: InterferenceGraph = HashMap::new(); + let v0 = IrVrVariableDescriptor::new("v0".into(), 0); + let v1 = IrVrVariableDescriptor::new("v1".into(), 0); + let v2 = IrVrVariableDescriptor::new("v2".into(), 0); + + // triangle: each has two edges + // v0 + // | \ + // v1--v2 + graph.insert(v0.clone(), HashSet::from([v1.clone(), v2.clone()])); + graph.insert(v1.clone(), HashSet::from([v0.clone(), v2.clone()])); + graph.insert(v2.clone(), HashSet::from([v0.clone(), v1.clone()])); + graph + } + + fn get_vrs(graph: &InterferenceGraph) -> Vec { + let v0 = graph.keys().find(|k| k.name() == "v0").unwrap(); + let v1 = graph.keys().find(|k| k.name() == "v1").unwrap(); + let v2 = graph.keys().find(|k| k.name() == "v2").unwrap(); + vec![v0.clone(), v1.clone(), v2.clone()] + } + + #[test] + fn find_vr_lt_k_when_k_2() { + let graph = line_graph(); + let found = find_vr_lt_k(&graph, 2); + assert!(found.is_some()); + assert!(found.unwrap().name() == "v1" || found.unwrap().name() == "v2"); + } + + #[test] + fn find_vr_lt_k_when_k_1() { + let graph = line_graph(); + let found = find_vr_lt_k(&graph, 1); + assert!(found.is_none()); + } + + #[test] + fn remove_edges_v0() { + let mut graph = line_graph(); + let vrs = get_vrs(&graph); + + let v0_outgoing = remove_vr_and_edges(&mut graph, &vrs[0]); + assert!(v0_outgoing.contains(&vrs[1])); + assert!(v0_outgoing.contains(&vrs[2])); + + // check that incoming edges were removed + let v1_outgoing = graph.get(&vrs[1]).unwrap(); + assert!(v1_outgoing.is_empty()); + let v2_outgoing = graph.get(&vrs[2]).unwrap(); + assert!(v2_outgoing.is_empty()); + } + + fn triangle_work_stack_k_2() -> Vec { + let k = 2; + let mut graph = triangle_graph(); + + let mut work_stack = vec![]; + + // run three times, once for each register + work_stack.push(next_work_item(&mut graph, k)); + work_stack.push(next_work_item(&mut graph, k)); + work_stack.push(next_work_item(&mut graph, k)); + + work_stack + } + + #[test] + fn next_work_item_k_2() { + let work_stack = triangle_work_stack_k_2(); + + // the actual edges may be different, depending on the underlying order in the sets + // (HashSet seems to use randomness in order) + // however, the bottommost item must be a spill, and the edge counts must be (from the + // bottom of the stack) 2-1-0 + assert!(!work_stack[0].color); + assert_eq!(work_stack[0].edges.len(), 2); + assert_eq!(work_stack[1].edges.len(), 1); + assert_eq!(work_stack[2].edges.len(), 0); + } + + #[test] + fn rebuild_graph_triangle_k_2() { + let mut work_stack = triangle_work_stack_k_2(); + let mut rebuilt_graph: InterferenceGraph = HashMap::new(); + + // it should be possible to rebuild the graph from the stack, without yet worrying + // about spilling/etc. + while let Some(work_item) = work_stack.pop() { + rebuild_vr_and_edges(&mut rebuilt_graph, &work_item); + } + + // we should have a triangle graph again + let vrs = get_vrs(&rebuilt_graph); + for vr in &vrs { + assert!(rebuilt_graph.contains_key(vr)); + assert_eq!(rebuilt_graph.get(vr).unwrap().len(), 2); + } + } + + #[test] + fn registers_and_spills_triangle_k_2() { + let mut graph = triangle_graph(); + let (registers, spills) = registers_and_spills(&mut graph, 2); + // there should be one spill when k is 2 + assert_eq!(registers.len(), 2); + assert_eq!(spills.len(), 1); + } +} diff --git a/dmc-lib/src/symbol.rs b/dmc-lib/src/symbol.rs index d51dbbc..fbecf5e 100644 --- a/dmc-lib/src/symbol.rs +++ b/dmc-lib/src/symbol.rs @@ -1,6 +1,6 @@ use crate::ir::ir_expression::IrExpression; use crate::ir::ir_parameter::IrParameter; -use crate::ir::ir_variable::{IrVariable, IrVirtualRegisterVariable}; +use crate::ir::ir_variable::IrVariable; use crate::type_info::TypeInfo; use std::cell::RefCell; use std::rc::Rc; @@ -100,7 +100,7 @@ impl ParameterSymbol { pub struct VariableSymbol { name: Rc, type_info: Option, - vr_variable: Option>, + vr_variable: Option>>, #[deprecated] register: Option, @@ -134,11 +134,11 @@ impl VariableSymbol { .expect("TypeInfo not initialized. Did you type check?") } - pub fn set_vr_variable(&mut self, ir_variable: Rc) { + pub fn set_vr_variable(&mut self, ir_variable: Rc>) { self.vr_variable = Some(ir_variable); } - pub fn vr_variable(&self) -> &Rc { + pub fn vr_variable(&self) -> &Rc> { self.vr_variable .as_ref() .expect("ir_variable not yet initialized") @@ -168,9 +168,9 @@ impl ExpressibleSymbol { ExpressibleSymbol::Parameter(parameter_symbol) => { IrExpression::Parameter(parameter_symbol.borrow().ir_parameter().clone()) } - ExpressibleSymbol::Variable(variable_symbol) => IrExpression::Variable( - IrVariable::VirtualRegister(variable_symbol.borrow().vr_variable().clone()), - ), + ExpressibleSymbol::Variable(variable_symbol) => { + IrExpression::Variable(variable_symbol.borrow().vr_variable().clone()) + } } } }