From 5d565ccf91bc573089bf1cec3bddf5819ed59306 Mon Sep 17 00:00:00 2001 From: Jesse Brault Date: Fri, 6 Mar 2026 20:16:00 -0600 Subject: [PATCH] Fix register allocation algorithm. --- dm/src/main.rs | 8 +- dmc-lib/src/ir/ir_add.rs | 5 + dmc-lib/src/ir/ir_assign.rs | 2 + dmc-lib/src/ir/ir_block.rs | 450 +++++++++++++++++++++++---------- dmc-lib/src/ir/ir_function.rs | 9 +- dmc-lib/src/ir/ir_operation.rs | 10 +- 6 files changed, 344 insertions(+), 140 deletions(-) diff --git a/dm/src/main.rs b/dm/src/main.rs index fc4e65c..97c3f45 100644 --- a/dm/src/main.rs +++ b/dm/src/main.rs @@ -25,6 +25,9 @@ struct Cli { #[arg(long)] show_ir: bool, + + #[arg(long, default_value = "8")] + register_count: usize, } fn main() { @@ -58,8 +61,9 @@ fn main() { for declaration in compilation_unit.declarations() { if let ModuleLevelDeclaration::Function(function) = declaration { let mut ir_function = function.to_ir(&symbol_table); - ir_function.assign_registers(); - println!("{}", ir_function) + let register_assignments = ir_function.register_assignments(args.register_count); + println!("{}", ir_function); + println!("{:?}", register_assignments); } } } diff --git a/dmc-lib/src/ir/ir_add.rs b/dmc-lib/src/ir/ir_add.rs index 8bb0189..da41704 100644 --- a/dmc-lib/src/ir/ir_add.rs +++ b/dmc-lib/src/ir/ir_add.rs @@ -23,6 +23,11 @@ impl IrAdd { set.extend(self.right.vr_uses()); set } + + pub fn propagate_spills(&mut self, spills: &HashSet>) { + self.left.propagate_spills(spills); + self.right.propagate_spills(spills); + } } impl Display for IrAdd { diff --git a/dmc-lib/src/ir/ir_assign.rs b/dmc-lib/src/ir/ir_assign.rs index e8bd4fb..6062c93 100644 --- a/dmc-lib/src/ir/ir_assign.rs +++ b/dmc-lib/src/ir/ir_assign.rs @@ -34,8 +34,10 @@ impl IrAssign { } pub fn propagate_spills(&mut self, spills: &HashSet>) { + self.initializer.propagate_spills(spills); if let IrVariable::VirtualRegister(vr_variable) = self.destination.deref() { if spills.contains(vr_variable) { + println!("changing vr to stack: {}", vr_variable.name()); self.destination = Box::new(IrVariable::Stack(IrStackVariable::new( vr_variable.name(), vr_variable.type_info().clone(), diff --git a/dmc-lib/src/ir/ir_block.rs b/dmc-lib/src/ir/ir_block.rs index 6f4aaab..e882cae 100644 --- a/dmc-lib/src/ir/ir_block.rs +++ b/dmc-lib/src/ir/ir_block.rs @@ -14,6 +14,8 @@ pub struct IrBlock { } type LivenessMapByStatement = HashMap>>; +type InterferenceGraph = + HashMap, HashSet>>; impl IrBlock { pub fn new(id: usize, debug_label: &str, statements: Vec) -> Self { @@ -98,19 +100,14 @@ impl IrBlock { (live_in, live_out) } - fn interference_graph( - &self, - ) -> HashMap, HashSet>> { + fn interference_graph(&self) -> InterferenceGraph { let mut all_vr_variables: HashSet> = HashSet::new(); for statement in &self.statements { all_vr_variables.extend(statement.vr_definitions()); all_vr_variables.extend(statement.vr_uses()); } - let mut graph: HashMap< - Rc, - HashSet>, - > = HashMap::new(); + let mut graph: InterferenceGraph = HashMap::new(); for variable in all_vr_variables { graph.insert(variable, HashSet::new()); } @@ -139,10 +136,16 @@ impl IrBlock { graph } - pub fn assign_registers(&mut self) { + pub fn register_assignments( + &mut self, + register_count: usize, + ) -> HashMap, usize> { let mut spills: HashSet> = HashSet::new(); loop { - let (registers, new_spills) = register_assignment::assign_registers(self); + let mut interference_graph = self.interference_graph(); + let (registers, new_spills) = + register_assignment::registers_and_spills(&mut interference_graph, register_count); + if spills != new_spills { spills = new_spills; // mutate all IrVirtualRegisters to constituent statements @@ -150,8 +153,7 @@ impl IrBlock { statement.propagate_spills(&spills); } } else { - println!("{:?}", registers); - break; + return registers; } } } @@ -185,6 +187,46 @@ impl Display for IrBlock { } } +#[cfg(test)] +mod tests { + use crate::ast::module_level_declaration::ModuleLevelDeclaration; + use crate::parser::parse_compilation_unit; + use crate::symbol_table::SymbolTable; + + #[test] + fn overlapping_assignments_bug_when_k_2() { + let mut compilation_unit = parse_compilation_unit( + " + fn main() + let a = 1 + let b = 2 + let c = 3 + let x = a + b + c + end + ", + ) + .unwrap(); + let mut symbol_table = SymbolTable::new(); + compilation_unit.gather_declared_names(&mut symbol_table); + compilation_unit.check_name_usages(&mut symbol_table); + compilation_unit.type_check(&mut symbol_table); + + let main = compilation_unit + .declarations() + .iter() + .find(|d| matches!(d, ModuleLevelDeclaration::Function(_))) + .unwrap(); + + if let ModuleLevelDeclaration::Function(main) = main { + let mut main_ir = main.to_ir(&symbol_table); + let register_assignments = main_ir.register_assignments(2); + assert_eq!(register_assignments.len(), 4); + } else { + unreachable!() + } + } +} + mod register_assignment { use super::*; @@ -195,148 +237,51 @@ mod register_assignment { color: bool, } - pub fn assign_registers( - block: &IrBlock, + pub fn registers_and_spills( + interference_graph: &mut InterferenceGraph, + k: usize, ) -> ( HashMap, usize>, HashSet>, ) { - let k = 8; + let mut work_stack: Vec = vec![]; + + while !interference_graph.is_empty() { + let next = next_work_item(interference_graph, k); + work_stack.push(next); + } + + // 3. assign colors to registers + let mut rebuilt_graph: InterferenceGraph = HashMap::new(); + let mut register_assignments: HashMap, usize> = + HashMap::new(); let mut spills: HashSet> = HashSet::new(); - loop { - // 1. get interference graph - let mut interference_graph = block.interference_graph(); - let mut work_stack: Vec = vec![]; - - loop { - // 2. coloring by simplification - // try to find a node (virtual register) with less than k outgoing edges, - // and mark as color - // if not, pick any, and mark as spill for step 3 - let register_lt_k = interference_graph.iter().find_map(|(vr, neighbors)| { - if neighbors.len() < k { Some(vr) } else { None } - }); - if let Some(vr) = register_lt_k { - let vr = vr.clone(); - // remove both outgoing and incoming edges; save either set for WorkItem - // first, outgoing: - let outgoing_edges = interference_graph.remove(&vr).unwrap(); - - // second, incoming - interference_graph.iter_mut().for_each(|(_, neighbors)| { - neighbors.remove(&vr); - }); - - // push to work stack - work_stack.push(WorkItem { - vr, - edges: outgoing_edges, - color: true, - }) - } else { - // pick any - let vr = interference_graph.iter().last().unwrap().0.clone(); - - // first, outgoing - let outgoing_edges = interference_graph.remove(&vr).unwrap(); - - // second, incoming - interference_graph.iter_mut().for_each(|(_, neighbors)| { - neighbors.remove(&vr); - }); - - work_stack.push(WorkItem { - vr, - edges: outgoing_edges, - color: false, // spill - }); - } - - if interference_graph.is_empty() { - break; - } - } - - // 3. assign colors to registers - let mut rebuilt_graph: HashMap< - Rc, - HashSet>, - > = HashMap::new(); - let mut register_assignments: HashMap, usize> = - HashMap::new(); - let mut new_spills: HashSet> = HashSet::new(); - - while let Some(work_item) = work_stack.pop() { - if work_item.color { - assign_register(&work_item, &mut rebuilt_graph, k, &mut register_assignments); - } else { - // first, see if we can optimistically color - // find how many assignments have been made for the outgoing edges - // if it's less than k, we can do it - let mut number_of_assigned_edges = 0; - for edge in &work_item.edges { - if register_assignments.contains_key(edge) { - number_of_assigned_edges += 1; - } - } - - if number_of_assigned_edges < k { - // optimistically color - assign_register( - &work_item, - &mut rebuilt_graph, - k, - &mut register_assignments, - ); - } else { - // spill - new_spills.insert(work_item.vr.clone()); - } - } - } - - if spills.eq(&new_spills) { - return (register_assignments, spills); + while let Some(work_item) = work_stack.pop() { + if work_item.color { + assign_register(&work_item, &mut rebuilt_graph, k, &mut register_assignments); + } else if can_optimistically_color(&work_item, &mut register_assignments, k) { + assign_register(&work_item, &mut rebuilt_graph, k, &mut register_assignments); } else { - spills = new_spills; + // spill + spills.insert(work_item.vr.clone()); } } + + (register_assignments, spills) } fn assign_register( work_item: &WorkItem, - rebuilt_graph: &mut HashMap< - Rc, - HashSet>, - >, + graph: &mut InterferenceGraph, k: usize, register_assignments: &mut HashMap, usize>, ) { - let this_vertex_vr = &work_item.vr; - // init the vertex - rebuilt_graph.insert(this_vertex_vr.clone(), HashSet::new()); - - // add edges, both outgoing and incoming - let neighbors = rebuilt_graph.get_mut(this_vertex_vr).unwrap(); - // outgoing - for edge in &work_item.edges { - neighbors.insert(edge.clone()); - } - // incoming - for neighbor in neighbors.clone() { - if rebuilt_graph.contains_key(&neighbor) { - rebuilt_graph - .get_mut(&neighbor) - .unwrap() - .insert(this_vertex_vr.clone()); - } - } + rebuild_vr_and_edges(graph, work_item); // find a register which is not yet shared by all outgoing edges' vertices - // I think the bug is somewhere here 'outer: for i in 0..k { - for edge in rebuilt_graph.get_mut(this_vertex_vr).unwrap().iter() { + for edge in graph.get_mut(&work_item.vr).unwrap().iter() { if register_assignments.contains_key(edge) { let assignment = register_assignments.get(edge).unwrap(); if *assignment == i { @@ -344,8 +289,243 @@ mod register_assignment { } } } - register_assignments.insert(this_vertex_vr.clone(), i); + register_assignments.insert(work_item.vr.clone(), i); break; } } + + fn find_vr_lt_k( + interference_graph: &InterferenceGraph, + k: usize, + ) -> Option<&Rc> { + interference_graph.iter().find_map( + |(vr, neighbors)| { + if neighbors.len() < k { Some(vr) } else { None } + }, + ) + } + + fn remove_vr_and_edges( + interference_graph: &mut InterferenceGraph, + vr: &Rc, + ) -> HashSet> { + // first, outgoing + let outgoing_edges = interference_graph.remove(vr).unwrap(); + + // second, incoming + for neighbor in &outgoing_edges { + let neighbor_edges = interference_graph.get_mut(neighbor).unwrap(); + neighbor_edges.remove(vr); + } + + outgoing_edges + } + + fn next_work_item(interference_graph: &mut InterferenceGraph, k: usize) -> WorkItem { + // try to find a node (virtual register) with less than k outgoing edges, and mark as color + // for step 3. + // if not, pick any, and mark as spill for step 3. + let register_lt_k = find_vr_lt_k(interference_graph, k); + if let Some(vr) = register_lt_k { + let vr = vr.clone(); + + // remove edges; save outgoing to work_item + let edges = remove_vr_and_edges(interference_graph, &vr); + + // push to work stack + WorkItem { + vr, + edges, + color: true, + } + } else { + // pick any + let vr = interference_graph.iter().last().unwrap().0.clone(); + + // remove edges + let edges = remove_vr_and_edges(interference_graph, &vr); + + WorkItem { + vr, + edges, + color: false, // spill + } + } + } + + fn rebuild_vr_and_edges(graph: &mut InterferenceGraph, work_item: &WorkItem) { + // init the vertex + graph.insert(work_item.vr.clone(), HashSet::new()); + + // outgoing + for neighbor in &work_item.edges { + // check if neighbor exists in the graph first; if it was marked spill earlier and could + // not optimistically color, it won't be in the graph + if graph.contains_key(neighbor) { + // get outgoing set and insert neighbor + graph + .get_mut(&work_item.vr) + .unwrap() + .insert(neighbor.clone()); + } + } + + // incoming + for neighbor in &work_item.edges { + // like above, neighbor may not have been added because of failure to optimistically + // color + if graph.contains_key(neighbor) { + graph + .get_mut(neighbor) + .unwrap() + .insert(work_item.vr.clone()); + } + } + } + + fn can_optimistically_color( + work_item: &WorkItem, + register_assignments: &HashMap, usize>, + k: usize, + ) -> bool { + // see if we can optimistically color + // find how many assignments have been made for the outgoing edges + // if it's less than k, we can do it + let mut number_of_assigned_edges = 0; + for edge in &work_item.edges { + if register_assignments.contains_key(edge) { + number_of_assigned_edges += 1; + } + } + number_of_assigned_edges < k + } + + #[cfg(test)] + mod tests { + use super::*; + use crate::type_info::TypeInfo; + + fn line_graph() -> InterferenceGraph { + let mut graph: InterferenceGraph = HashMap::new(); + let v0 = Rc::new(IrVirtualRegisterVariable::new("v0", TypeInfo::Integer)); + let v1 = Rc::new(IrVirtualRegisterVariable::new("v1", TypeInfo::Integer)); + let v2 = Rc::new(IrVirtualRegisterVariable::new("v2", TypeInfo::Integer)); + + // v1 -- v0 -- v2 + graph.insert(v0.clone(), HashSet::from([v1.clone(), v2.clone()])); + graph.insert(v1.clone(), HashSet::from([v0.clone()])); + graph.insert(v2.clone(), HashSet::from([v0.clone()])); + graph + } + + fn triangle_graph() -> InterferenceGraph { + let mut graph: InterferenceGraph = HashMap::new(); + let v0 = Rc::new(IrVirtualRegisterVariable::new("v0", TypeInfo::Integer)); + let v1 = Rc::new(IrVirtualRegisterVariable::new("v1", TypeInfo::Integer)); + let v2 = Rc::new(IrVirtualRegisterVariable::new("v2", TypeInfo::Integer)); + + // triangle: each has two edges + // v0 + // | \ + // v1--v2 + graph.insert(v0.clone(), HashSet::from([v1.clone(), v2.clone()])); + graph.insert(v1.clone(), HashSet::from([v0.clone(), v2.clone()])); + graph.insert(v2.clone(), HashSet::from([v0.clone(), v1.clone()])); + graph + } + + fn get_vrs(graph: &InterferenceGraph) -> Vec> { + let v0 = graph.keys().find(|k| k.name() == "v0").unwrap(); + let v1 = graph.keys().find(|k| k.name() == "v1").unwrap(); + let v2 = graph.keys().find(|k| k.name() == "v2").unwrap(); + vec![v0.clone(), v1.clone(), v2.clone()] + } + + #[test] + fn find_vr_lt_k_when_k_2() { + let graph = line_graph(); + let found = find_vr_lt_k(&graph, 2); + assert!(found.is_some()); + assert!(found.unwrap().name() == "v1" || found.unwrap().name() == "v2"); + } + + #[test] + fn find_vr_lt_k_when_k_1() { + let graph = line_graph(); + let found = find_vr_lt_k(&graph, 1); + assert!(found.is_none()); + } + + #[test] + fn remove_edges_v0() { + let mut graph = line_graph(); + let vrs = get_vrs(&graph); + + let v0_outgoing = remove_vr_and_edges(&mut graph, &vrs[0]); + assert!(v0_outgoing.contains(&vrs[1])); + assert!(v0_outgoing.contains(&vrs[2])); + + // check that incoming edges were removed + let v1_outgoing = graph.get(&vrs[1]).unwrap(); + assert!(v1_outgoing.is_empty()); + let v2_outgoing = graph.get(&vrs[2]).unwrap(); + assert!(v2_outgoing.is_empty()); + } + + fn triangle_work_stack_k_2() -> Vec { + let k = 2; + let mut graph = triangle_graph(); + + let mut work_stack = vec![]; + + // run three times, once for each register + work_stack.push(next_work_item(&mut graph, k)); + work_stack.push(next_work_item(&mut graph, k)); + work_stack.push(next_work_item(&mut graph, k)); + + work_stack + } + + #[test] + fn next_work_item_k_2() { + let work_stack = triangle_work_stack_k_2(); + + // the actual edges may be different, depending on the underlying order in the sets + // (HashSet seems to use randomness in order) + // however, the bottommost item must be a spill, and the edge counts must be (from the + // bottom of the stack) 2-1-0 + assert!(!work_stack[0].color); + assert_eq!(work_stack[0].edges.len(), 2); + assert_eq!(work_stack[1].edges.len(), 1); + assert_eq!(work_stack[2].edges.len(), 0); + } + + #[test] + fn rebuild_graph_triangle_k_2() { + let mut work_stack = triangle_work_stack_k_2(); + let mut rebuilt_graph: InterferenceGraph = HashMap::new(); + + // it should be possible to rebuild the graph from the stack, without yet worrying + // about spilling/etc. + while let Some(work_item) = work_stack.pop() { + rebuild_vr_and_edges(&mut rebuilt_graph, &work_item); + } + + // we should have a triangle graph again + let vrs = get_vrs(&rebuilt_graph); + for vr in &vrs { + assert!(rebuilt_graph.contains_key(vr)); + assert_eq!(rebuilt_graph.get(vr).unwrap().len(), 2); + } + } + + #[test] + fn registers_and_spills_triangle_k_2() { + let mut graph = triangle_graph(); + let (registers, spills) = registers_and_spills(&mut graph, 2); + // there should be one spill when k is 2 + assert_eq!(registers.len(), 2); + assert_eq!(spills.len(), 1); + } + } } diff --git a/dmc-lib/src/ir/ir_function.rs b/dmc-lib/src/ir/ir_function.rs index 471a23d..494a84f 100644 --- a/dmc-lib/src/ir/ir_function.rs +++ b/dmc-lib/src/ir/ir_function.rs @@ -1,7 +1,9 @@ use crate::ir::ir_block::IrBlock; use crate::ir::ir_parameter::IrParameter; +use crate::ir::ir_variable::IrVirtualRegisterVariable; use crate::type_info::TypeInfo; use std::cell::RefCell; +use std::collections::HashMap; use std::fmt::Display; use std::rc::Rc; @@ -27,8 +29,11 @@ impl IrFunction { } } - pub fn assign_registers(&mut self) { - self.entry.borrow_mut().assign_registers(); + pub fn register_assignments( + &mut self, + register_count: usize, + ) -> HashMap, usize> { + self.entry.borrow_mut().register_assignments(register_count) } } diff --git a/dmc-lib/src/ir/ir_operation.rs b/dmc-lib/src/ir/ir_operation.rs index 99cd47a..ad447ec 100644 --- a/dmc-lib/src/ir/ir_operation.rs +++ b/dmc-lib/src/ir/ir_operation.rs @@ -1,7 +1,7 @@ use crate::ir::ir_add::IrAdd; use crate::ir::ir_call::IrCall; use crate::ir::ir_expression::IrExpression; -use crate::ir::ir_variable::{IrVariable, IrVirtualRegisterVariable}; +use crate::ir::ir_variable::IrVirtualRegisterVariable; use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::rc::Rc; @@ -36,4 +36,12 @@ impl IrOperation { IrOperation::Call(ir_call) => ir_call.vr_uses(), } } + + pub fn propagate_spills(&mut self, spills: &HashSet>) { + match self { + IrOperation::Load(ir_expression) => ir_expression.propagate_spills(spills), + IrOperation::Add(ir_add) => ir_add.propagate_spills(spills), + IrOperation::Call(ir_call) => ir_call.propagate_spills(spills), + } + } }