From 55f7af0fa98fa386c4f292812dde33d4549da967 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Wed, 1 Jan 2025 16:21:40 +0100 Subject: [PATCH 01/17] Migrate post order --- crates/cubecl-opt/Cargo.toml | 9 ++-- crates/cubecl-opt/src/analyses/base.rs | 42 +++++++++++++++++++ crates/cubecl-opt/src/analyses/mod.rs | 4 ++ crates/cubecl-opt/src/analyses/post_order.rs | 23 ++++++++++ crates/cubecl-opt/src/debug.rs | 4 -- crates/cubecl-opt/src/lib.rs | 29 +++---------- crates/cubecl-opt/src/passes/dead_code.rs | 11 +++-- .../src/passes/inlined_if_to_select.rs | 8 ++-- crates/cubecl-opt/src/passes/liveness.rs | 4 +- 9 files changed, 92 insertions(+), 42 deletions(-) create mode 100644 crates/cubecl-opt/src/analyses/base.rs create mode 100644 crates/cubecl-opt/src/analyses/mod.rs create mode 100644 crates/cubecl-opt/src/analyses/post_order.rs diff --git a/crates/cubecl-opt/Cargo.toml b/crates/cubecl-opt/Cargo.toml index 752333df3..2587641ab 100644 --- a/crates/cubecl-opt/Cargo.toml +++ b/crates/cubecl-opt/Cargo.toml @@ -2,8 +2,8 @@ authors = ["Genna Wingert"] categories = ["algorithms"] description = "Compiler optimizations for CubeCL" -keywords = ["gpu", "compiler"] edition = "2021" +keywords = ["gpu", "compiler"] license.workspace = true name = "cubecl-opt" readme.workspace = true @@ -11,11 +11,7 @@ repository = "https://github.com/tracel-ai/cubecl/tree/main/cubecl-opt" version.workspace = true [features] -default = [ - "std", - "cubecl-common/default", - "cubecl-core/default", -] +default = ["std", "cubecl-common/default", "cubecl-core/default"] std = ["cubecl-common/std", "cubecl-core/std"] [dependencies] @@ -27,3 +23,4 @@ num = "0.4" petgraph = { version = "0.6" } smallvec = { version = "1", features = ["union", "const_generics"] } stable-vec = { version = "0.4" } +type-map = { version = "0.5" } diff --git a/crates/cubecl-opt/src/analyses/base.rs b/crates/cubecl-opt/src/analyses/base.rs new file mode 100644 index 000000000..9f1d12e18 --- /dev/null +++ b/crates/cubecl-opt/src/analyses/base.rs @@ -0,0 +1,42 @@ +use std::{any::Any, cell::RefCell, rc::Rc}; + +use type_map::TypeMap; + +use crate::Optimizer; + +pub trait Analysis { + fn init(opt: &mut Optimizer) -> Self; +} + +#[derive(Default, Clone, Debug)] +pub struct Analyses { + cache: Rc>, +} + +impl Analyses { + pub fn get(&self, opt: &mut Optimizer) -> Rc { + let analysis = self.cache.borrow().get::>().cloned(); + if let Some(analysis) = analysis { + analysis + } else { + let analysis = Rc::new(A::init(opt)); + self.cache.borrow_mut().insert(analysis.clone()); + analysis + } + } + + pub fn invalidate(&self) { + self.cache.borrow_mut().remove::>(); + } +} + +impl Optimizer { + pub fn analysis(&mut self) -> Rc { + let analyses = self.analyses.clone(); + analyses.get(self) + } + + pub fn invalidate_analysis(&self) { + self.analyses.invalidate::(); + } +} diff --git a/crates/cubecl-opt/src/analyses/mod.rs b/crates/cubecl-opt/src/analyses/mod.rs new file mode 100644 index 000000000..20d86695f --- /dev/null +++ b/crates/cubecl-opt/src/analyses/mod.rs @@ -0,0 +1,4 @@ +mod base; +pub mod post_order; + +pub use base::*; diff --git a/crates/cubecl-opt/src/analyses/post_order.rs b/crates/cubecl-opt/src/analyses/post_order.rs new file mode 100644 index 000000000..6b5890157 --- /dev/null +++ b/crates/cubecl-opt/src/analyses/post_order.rs @@ -0,0 +1,23 @@ +use crate::NodeIndex; +use petgraph::visit::{DfsPostOrder, Walker}; + +use super::Analysis; + +pub struct PostOrder(Vec); + +impl Analysis for PostOrder { + fn init(opt: &mut crate::Optimizer) -> Self { + let po = DfsPostOrder::new(&opt.program.graph, opt.entry()); + PostOrder(po.iter(&opt.program.graph).collect()) + } +} + +impl PostOrder { + pub fn forward(&self) -> Vec { + self.0.clone() + } + + pub fn reverse(&self) -> Vec { + self.0.clone().into_iter().rev().collect() + } +} diff --git a/crates/cubecl-opt/src/debug.rs b/crates/cubecl-opt/src/debug.rs index 5c6321020..93269e5fb 100644 --- a/crates/cubecl-opt/src/debug.rs +++ b/crates/cubecl-opt/src/debug.rs @@ -16,10 +16,6 @@ const DEBUG_GVN: bool = false; /// Debug display for the program state. impl Display for Optimizer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let post_order = self.post_order.iter().map(|it| format!("bb{}", it.index())); - let post_order = post_order.collect::>(); - writeln!(f, "Post Order: {}", post_order.join(", "))?; - writeln!(f)?; f.write_str("Slices:\n")?; for (var_id, slice) in self.program.slices.iter() { let end_op = slice.end_op.as_ref().map(|it| format!("{it}")); diff --git a/crates/cubecl-opt/src/lib.rs b/crates/cubecl-opt/src/lib.rs index 876cecf47..e9eefa3f2 100644 --- a/crates/cubecl-opt/src/lib.rs +++ b/crates/cubecl-opt/src/lib.rs @@ -25,12 +25,13 @@ use std::{ cell::RefCell, - collections::{HashMap, HashSet, VecDeque}, + collections::{HashMap, VecDeque}, ops::{Deref, DerefMut}, rc::Rc, sync::atomic::{AtomicUsize, Ordering}, }; +use analyses::Analyses; use cubecl_core::{ ir::{self as core, Branch, Operation, Operator, Variable, VariableKind}, CubeDim, @@ -49,6 +50,7 @@ use passes::{ }; use petgraph::{prelude::StableDiGraph, visit::EdgeRef, Direction}; +mod analyses; mod block; mod control_flow; mod debug; @@ -143,8 +145,6 @@ struct Range { pub struct Optimizer { /// The overall program state program: Program, - /// The post order of the graph for traversal - post_order: Vec, /// The current block while parsing current_block: Option, /// The current loop's break target @@ -158,6 +158,8 @@ pub struct Optimizer { /// The execution mode, `Unchecked` skips bounds check optimizations. pub(crate) mode: ExecutionMode, pub(crate) gvn: Rc>, + + analyses: Rc, } impl Default for Optimizer { @@ -170,8 +172,8 @@ impl Default for Optimizer { root_scope: Scope::root(), cube_dim: Default::default(), mode: Default::default(), - post_order: Default::default(), gvn: Default::default(), + analyses: Default::default(), } } } @@ -195,7 +197,6 @@ impl Optimizer { fn run_opt(&mut self, expand: Scope) { self.parse_graph(expand); self.split_critical_edges(); - self.determine_postorder(self.entry(), &mut HashSet::new()); self.analyze_liveness(); self.apply_pre_ssa_passes(); self.exempt_index_assign_locals(); @@ -243,24 +244,6 @@ impl Optimizer { } } - fn determine_postorder(&mut self, block: NodeIndex, visited: &mut HashSet) { - for successor in self.successors(block) { - if !visited.contains(&successor) { - visited.insert(successor); - self.determine_postorder(successor, visited); - } - } - self.post_order.push(block); - } - - pub fn post_order(&self) -> Vec { - self.post_order.clone() - } - - pub fn reverse_post_order(&self) -> Vec { - self.post_order.iter().rev().copied().collect() - } - fn apply_pre_ssa_passes(&mut self) { // Currently only one pre-ssa pass, but might add more let mut passes = vec![CompositeMerge]; diff --git a/crates/cubecl-opt/src/passes/dead_code.rs b/crates/cubecl-opt/src/passes/dead_code.rs index d2d1f4fb7..df33988de 100644 --- a/crates/cubecl-opt/src/passes/dead_code.rs +++ b/crates/cubecl-opt/src/passes/dead_code.rs @@ -7,7 +7,10 @@ use std::{ use cubecl_core::ir::{ConstantScalarValue, Instruction, Operation, VariableKind}; use petgraph::{graph::NodeIndex, visit::EdgeRef}; -use crate::{visit_noop, AtomicCounter, BasicBlock, BlockUse, ControlFlow, Optimizer}; +use crate::{ + analyses::post_order::PostOrder, visit_noop, AtomicCounter, BasicBlock, BlockUse, ControlFlow, + Optimizer, +}; use super::OptimizerPass; @@ -143,7 +146,7 @@ fn search_dead_blocks(opt: &mut Optimizer) -> bool { opt.program.remove_edge(edge); } opt.program.remove_node(block); - opt.post_order.retain(|it| *it != block); + opt.invalidate_analysis::(); return true; } } @@ -203,7 +206,7 @@ impl OptimizerPass for MergeBlocks { } fn merge_blocks(opt: &mut Optimizer) -> bool { - for block_idx in opt.reverse_post_order() { + for block_idx in opt.analysis::().reverse() { let successors = opt.successors(block_idx); if successors.len() == 1 && can_merge(opt, block_idx, successors[0]) { let mut new_block = BasicBlock::default(); @@ -237,7 +240,7 @@ fn merge_blocks(opt: &mut Optimizer) -> bool { } *opt.program.node_weight_mut(block_idx).unwrap() = new_block; opt.program.remove_node(successors[0]); - opt.post_order.retain(|it| *it != successors[0]); + opt.invalidate_analysis::(); update_references(opt, successors[0], block_idx); return true; } diff --git a/crates/cubecl-opt/src/passes/inlined_if_to_select.rs b/crates/cubecl-opt/src/passes/inlined_if_to_select.rs index 514c14c10..5e6eb0251 100644 --- a/crates/cubecl-opt/src/passes/inlined_if_to_select.rs +++ b/crates/cubecl-opt/src/passes/inlined_if_to_select.rs @@ -3,7 +3,10 @@ use std::mem::take; use cubecl_core::ir::{Instruction, Operator, Select}; use petgraph::{graph::NodeIndex, visit::EdgeRef}; -use crate::{passes::update_references, AtomicCounter, ControlFlow, Optimizer}; +use crate::{ + analyses::post_order::PostOrder, passes::update_references, AtomicCounter, ControlFlow, + Optimizer, +}; use super::OptimizerPass; @@ -79,8 +82,7 @@ impl OptimizerPass for EmptyBranchToSelect { opt.program.remove_node(then); opt.program.remove_node(or_else); opt.program.remove_node(merge); - opt.post_order - .retain(|it| *it != then && *it != or_else && *it != merge); + opt.invalidate_analysis::(); for merge_successor in merge_successors { opt.program.add_edge(block, merge_successor, ()); } diff --git a/crates/cubecl-opt/src/passes/liveness.rs b/crates/cubecl-opt/src/passes/liveness.rs index 478fd58d6..d662b96f4 100644 --- a/crates/cubecl-opt/src/passes/liveness.rs +++ b/crates/cubecl-opt/src/passes/liveness.rs @@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet, VecDeque}; use petgraph::graph::NodeIndex; -use crate::Optimizer; +use crate::{analyses::post_order::PostOrder, Optimizer}; #[derive(Clone)] struct BlockSets { @@ -19,7 +19,7 @@ impl Optimizer { /// Do a conservative block level liveness analysis pub fn analyze_liveness(&mut self) { let mut state = State { - worklist: VecDeque::from(self.post_order()), + worklist: VecDeque::from(self.analysis::().forward()), block_sets: HashMap::new(), }; while let Some(block) = state.worklist.pop_front() { From 6f4512ef7333f74ceded36a460dfb0655980a154 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Wed, 1 Jan 2025 16:36:13 +0100 Subject: [PATCH 02/17] Simplify dead block elimination --- crates/cubecl-opt/src/passes/dead_code.rs | 22 ++++------------ crates/cubecl-spirv/src/debug.rs | 32 +++++++++++------------ 2 files changed, 21 insertions(+), 33 deletions(-) diff --git a/crates/cubecl-opt/src/passes/dead_code.rs b/crates/cubecl-opt/src/passes/dead_code.rs index df33988de..e70ee8fad 100644 --- a/crates/cubecl-opt/src/passes/dead_code.rs +++ b/crates/cubecl-opt/src/passes/dead_code.rs @@ -132,26 +132,14 @@ pub struct EliminateDeadBlocks; impl OptimizerPass for EliminateDeadBlocks { fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { - while search_dead_blocks(opt) { - changes.inc(); - } - } -} - -fn search_dead_blocks(opt: &mut Optimizer) -> bool { - for block in opt.node_ids() { - if block != opt.entry() && opt.predecessors(block).is_empty() { - let edges: Vec<_> = opt.program.edges(block).map(|it| it.id()).collect(); - for edge in edges { - opt.program.remove_edge(edge); + let post_order = opt.analysis::().forward(); + for node in opt.node_ids() { + if !post_order.contains(&node) { + opt.program.remove_node(node); + changes.inc(); } - opt.program.remove_node(block); - opt.invalidate_analysis::(); - return true; } } - - false } /// Eliminates invalid phi nodes left over from other optimizations like branch elimination. diff --git a/crates/cubecl-spirv/src/debug.rs b/crates/cubecl-spirv/src/debug.rs index 7717d28ef..a05c671d0 100644 --- a/crates/cubecl-spirv/src/debug.rs +++ b/crates/cubecl-spirv/src/debug.rs @@ -437,23 +437,23 @@ impl SpirvCompiler { for inlined_at in &debug_info.definitions.inlined_at { self.declare_inlined_at(inlined_at); } - } - // Declare entry - let entry_name = self.debug_info().name_str.clone(); - let entry_def = self.definitions().functions[&entry_name]; - let args = self.debug_string(""); - let signature = self.debug_string(SIGNATURE); - self.void_debug( - None, - Instructions::DebugEntryPoint, - [ - entry_def.id, - entry_def.source.compilation_unit, - signature, - args, - ], - ); + // Declare entry + let entry_name = self.debug_info().name_str.clone(); + let entry_def = self.definitions().functions[&entry_name]; + let args = self.debug_string(""); + let signature = self.debug_string(SIGNATURE); + self.void_debug( + None, + Instructions::DebugEntryPoint, + [ + entry_def.id, + entry_def.source.compilation_unit, + signature, + args, + ], + ); + } } fn declare_debug_function(&mut self, function: &FunctionDefinition) { From ace99a2df95258e1dedb4b450da02a8e46c9a6d8 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Wed, 1 Jan 2025 17:08:47 +0100 Subject: [PATCH 03/17] Migrate liveness --- crates/cubecl-opt/src/analyses/base.rs | 4 + crates/cubecl-opt/src/analyses/liveness.rs | 105 +++++++++++++++++++++ crates/cubecl-opt/src/analyses/mod.rs | 1 + crates/cubecl-opt/src/block.rs | 13 +-- crates/cubecl-opt/src/control_flow.rs | 2 +- crates/cubecl-opt/src/debug.rs | 9 +- crates/cubecl-opt/src/lib.rs | 9 +- crates/cubecl-opt/src/passes/dead_code.rs | 7 +- crates/cubecl-opt/src/passes/liveness.rs | 74 --------------- crates/cubecl-opt/src/passes/mod.rs | 1 - crates/cubecl-opt/src/phi_frontiers.rs | 50 +++++----- 11 files changed, 149 insertions(+), 126 deletions(-) create mode 100644 crates/cubecl-opt/src/analyses/liveness.rs delete mode 100644 crates/cubecl-opt/src/passes/liveness.rs diff --git a/crates/cubecl-opt/src/analyses/base.rs b/crates/cubecl-opt/src/analyses/base.rs index 9f1d12e18..731cfee48 100644 --- a/crates/cubecl-opt/src/analyses/base.rs +++ b/crates/cubecl-opt/src/analyses/base.rs @@ -25,6 +25,10 @@ impl Analyses { } } + pub fn try_get(&self) -> Option> { + self.cache.borrow().get().cloned() + } + pub fn invalidate(&self) { self.cache.borrow_mut().remove::>(); } diff --git a/crates/cubecl-opt/src/analyses/liveness.rs b/crates/cubecl-opt/src/analyses/liveness.rs new file mode 100644 index 000000000..ccf01d125 --- /dev/null +++ b/crates/cubecl-opt/src/analyses/liveness.rs @@ -0,0 +1,105 @@ +use std::collections::{HashMap, HashSet, VecDeque}; + +use petgraph::graph::NodeIndex; + +use crate::{analyses::post_order::PostOrder, Optimizer}; + +use super::Analysis; + +pub struct Liveness { + live_vars: HashMap>, +} + +#[derive(Clone)] +struct BlockSets { + gen: HashSet<(u16, u8)>, + kill: HashSet<(u16, u8)>, +} + +struct State { + worklist: VecDeque, + block_sets: HashMap, +} + +impl Analysis for Liveness { + fn init(opt: &mut Optimizer) -> Self { + let mut this = Self::empty(opt); + this.analyze_liveness(opt); + this + } +} + +impl Liveness { + pub fn empty(opt: &Optimizer) -> Self { + let live_vars = opt + .node_ids() + .iter() + .map(|it| (*it, HashSet::new())) + .collect(); + Self { live_vars } + } + + pub fn at_block(&self, block: NodeIndex) -> &HashSet<(u16, u8)> { + &self.live_vars[&block] + } + + pub fn is_dead(&self, node: NodeIndex, var: (u16, u8)) -> bool { + !self.at_block(node).contains(&var) + } + + /// Do a conservative block level liveness analysis + pub fn analyze_liveness(&mut self, opt: &mut Optimizer) { + let mut state = State { + worklist: VecDeque::from(opt.analysis::().forward()), + block_sets: HashMap::new(), + }; + while let Some(block) = state.worklist.pop_front() { + self.analyze_block(opt, block, &mut state); + } + } + + fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) { + let BlockSets { gen, kill } = block_sets(opt, block, state); + + let mut live_vars = gen.clone(); + + for successor in opt.successors(block) { + let successor = &self.live_vars[&successor]; + live_vars.extend(successor.difference(kill)); + } + + if live_vars != self.live_vars[&block] { + state.worklist.extend(opt.predecessors(block)); + self.live_vars.insert(block, live_vars); + } + } +} + +fn block_sets<'a>(opt: &mut Optimizer, block: NodeIndex, state: &'a mut State) -> &'a BlockSets { + let block_sets = state.block_sets.entry(block); + block_sets.or_insert_with(|| calculate_block_sets(opt, block)) +} + +fn calculate_block_sets(opt: &mut Optimizer, block: NodeIndex) -> BlockSets { + let mut gen = HashSet::new(); + let mut kill = HashSet::new(); + + let ops = opt.program[block].ops.clone(); + + for op in ops.borrow_mut().values_mut().rev() { + // Reads must be tracked after writes + opt.visit_out(&mut op.out, |opt, var| { + if let Some(id) = opt.local_variable_id(var) { + kill.insert(id); + gen.remove(&id); + } + }); + opt.visit_operation(&mut op.operation, |opt, var| { + if let Some(id) = opt.local_variable_id(var) { + gen.insert(id); + } + }); + } + + BlockSets { gen, kill } +} diff --git a/crates/cubecl-opt/src/analyses/mod.rs b/crates/cubecl-opt/src/analyses/mod.rs index 20d86695f..75d60ba6a 100644 --- a/crates/cubecl-opt/src/analyses/mod.rs +++ b/crates/cubecl-opt/src/analyses/mod.rs @@ -1,4 +1,5 @@ mod base; +pub mod liveness; pub mod post_order; pub use base::*; diff --git a/crates/cubecl-opt/src/block.rs b/crates/cubecl-opt/src/block.rs index 4ff2a957e..d220ea64e 100644 --- a/crates/cubecl-opt/src/block.rs +++ b/crates/cubecl-opt/src/block.rs @@ -4,7 +4,7 @@ use cubecl_core::ir::{Instruction, Variable}; use petgraph::graph::NodeIndex; use stable_vec::StableVec; -use crate::{version::PhiInstruction, ControlFlow, Optimizer, Program}; +use crate::{version::PhiInstruction, ControlFlow, Optimizer}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BlockUse { @@ -21,8 +21,6 @@ pub struct BasicBlock { pub phi_nodes: Rc>>, /// The variables written to by this block. Only set during the SSA transformation. pub(crate) writes: HashSet<(u16, u8)>, - /// The live variables at the start of this block. Used for pruning phi nodes. - pub(crate) live_vars: HashSet<(u16, u8)>, /// The dominance frontiers of this block (where phi nodes must be inserted). pub(crate) dom_frontiers: HashSet, /// A stable list of operations performed in this block. @@ -61,12 +59,3 @@ impl Optimizer { } } } - -impl Program { - /// Check whether a variable is dead at the start of this block. Note that `false` does not mean - /// the variable is definitely live - just that it *may* be live and must be treated as such. - #[track_caller] - pub fn is_dead(&self, node: NodeIndex, var: (u16, u8)) -> bool { - !self[node].live_vars.contains(&var) - } -} diff --git a/crates/cubecl-opt/src/control_flow.rs b/crates/cubecl-opt/src/control_flow.rs index e1f9523ce..c33f46827 100644 --- a/crates/cubecl-opt/src/control_flow.rs +++ b/crates/cubecl-opt/src/control_flow.rs @@ -302,7 +302,7 @@ impl Optimizer { self.current_block = Some(next); // For loop constructs - self.program.insert_phi(header, i_id, range_loop.start.item); + self.insert_phi(header, i_id, range_loop.start.item); { let op = match range_loop.inclusive { true => Operator::LowerEqual, diff --git a/crates/cubecl-opt/src/debug.rs b/crates/cubecl-opt/src/debug.rs index 93269e5fb..8ae95936d 100644 --- a/crates/cubecl-opt/src/debug.rs +++ b/crates/cubecl-opt/src/debug.rs @@ -1,9 +1,10 @@ -use std::fmt::Display; +use std::{fmt::Display, rc::Rc}; use cubecl_core::ir::{FloatKind, IntKind, UIntKind}; use petgraph::visit::EdgeRef; use crate::{ + analyses::liveness::Liveness, gvn::{BlockSets, Constant, Expression, Instruction, Local, OpId, Value, ValueTable}, passes::var_id, ControlFlow, @@ -31,6 +32,10 @@ impl Display for Optimizer { f.write_str("\n\n")?; let global_nums = self.gvn.borrow(); + let liveness = self + .analyses + .try_get::() + .unwrap_or_else(|| Rc::new(Liveness::empty(self))); if DEBUG_GVN { writeln!(f, "# Value Table:")?; @@ -53,7 +58,7 @@ impl Display for Optimizer { if !bb.block_use.is_empty() { writeln!(f, " Uses: {:?}", bb.block_use)?; } - let live_vars = bb.live_vars.iter(); + let live_vars = liveness.at_block(node).iter(); let live_vars = live_vars.map(|it| format!("local({}, {})", it.0, it.1)); let live_vars = live_vars.collect::>(); writeln!(f, " Live variables: [{}]\n", live_vars.join(", "))?; diff --git a/crates/cubecl-opt/src/lib.rs b/crates/cubecl-opt/src/lib.rs index e9eefa3f2..696f7a06f 100644 --- a/crates/cubecl-opt/src/lib.rs +++ b/crates/cubecl-opt/src/lib.rs @@ -31,7 +31,7 @@ use std::{ sync::atomic::{AtomicUsize, Ordering}, }; -use analyses::Analyses; +use analyses::{liveness::Liveness, Analyses}; use cubecl_core::{ ir::{self as core, Branch, Operation, Operator, Variable, VariableKind}, CubeDim, @@ -197,7 +197,6 @@ impl Optimizer { fn run_opt(&mut self, expand: Scope) { self.parse_graph(expand); self.split_critical_edges(); - self.analyze_liveness(); self.apply_pre_ssa_passes(); self.exempt_index_assign_locals(); self.ssa_transform(); @@ -209,7 +208,7 @@ impl Optimizer { let arrays_prop = AtomicCounter::new(0); CopyPropagateArray.apply_post_ssa(self, arrays_prop.clone()); if arrays_prop.get() > 0 { - self.analyze_liveness(); + self.invalidate_analysis::(); self.ssa_transform(); self.apply_post_ssa_passes(); } @@ -317,8 +316,8 @@ impl Optimizer { } fn ssa_transform(&mut self) { - self.program.fill_dom_frontiers(); - self.program.place_phi_nodes(); + self.fill_dom_frontiers(); + self.place_phi_nodes(); self.version_program(); self.program.variables.clear(); for block in self.node_ids() { diff --git a/crates/cubecl-opt/src/passes/dead_code.rs b/crates/cubecl-opt/src/passes/dead_code.rs index e70ee8fad..cb23824e4 100644 --- a/crates/cubecl-opt/src/passes/dead_code.rs +++ b/crates/cubecl-opt/src/passes/dead_code.rs @@ -8,8 +8,8 @@ use cubecl_core::ir::{ConstantScalarValue, Instruction, Operation, VariableKind} use petgraph::{graph::NodeIndex, visit::EdgeRef}; use crate::{ - analyses::post_order::PostOrder, visit_noop, AtomicCounter, BasicBlock, BlockUse, ControlFlow, - Optimizer, + analyses::{liveness::Liveness, post_order::PostOrder}, + visit_noop, AtomicCounter, BasicBlock, BlockUse, ControlFlow, Optimizer, }; use super::OptimizerPass; @@ -205,8 +205,6 @@ fn merge_blocks(opt: &mut Optimizer) -> bool { let b_ops = block.ops.borrow().values().cloned().collect::>(); let s_ops = successor.ops.borrow().values().cloned().collect::>(); - new_block.live_vars.extend(block.live_vars.clone()); - new_block.live_vars.extend(successor.live_vars.clone()); new_block.phi_nodes.borrow_mut().extend(b_phi); new_block.phi_nodes.borrow_mut().extend(s_phi); new_block.ops.borrow_mut().extend(b_ops); @@ -229,6 +227,7 @@ fn merge_blocks(opt: &mut Optimizer) -> bool { *opt.program.node_weight_mut(block_idx).unwrap() = new_block; opt.program.remove_node(successors[0]); opt.invalidate_analysis::(); + opt.invalidate_analysis::(); update_references(opt, successors[0], block_idx); return true; } diff --git a/crates/cubecl-opt/src/passes/liveness.rs b/crates/cubecl-opt/src/passes/liveness.rs deleted file mode 100644 index d662b96f4..000000000 --- a/crates/cubecl-opt/src/passes/liveness.rs +++ /dev/null @@ -1,74 +0,0 @@ -use std::collections::{HashMap, HashSet, VecDeque}; - -use petgraph::graph::NodeIndex; - -use crate::{analyses::post_order::PostOrder, Optimizer}; - -#[derive(Clone)] -struct BlockSets { - gen: HashSet<(u16, u8)>, - kill: HashSet<(u16, u8)>, -} - -struct State { - worklist: VecDeque, - block_sets: HashMap, -} - -impl Optimizer { - /// Do a conservative block level liveness analysis - pub fn analyze_liveness(&mut self) { - let mut state = State { - worklist: VecDeque::from(self.analysis::().forward()), - block_sets: HashMap::new(), - }; - while let Some(block) = state.worklist.pop_front() { - self.analyze_block(block, &mut state); - } - } - - fn analyze_block(&mut self, block: NodeIndex, state: &mut State) { - let BlockSets { gen, kill } = self.block_sets(block, state); - - let mut live_vars = gen.clone(); - - for successor in self.successors(block) { - let successor = &self.program[successor].live_vars; - live_vars.extend(successor.difference(kill)); - } - - if live_vars != self.program[block].live_vars { - state.worklist.extend(self.predecessors(block)); - self.program[block].live_vars = live_vars; - } - } - - fn block_sets<'a>(&mut self, block: NodeIndex, state: &'a mut State) -> &'a BlockSets { - let block_sets = state.block_sets.entry(block); - block_sets.or_insert_with(|| self.calculate_block_sets(block)) - } - - fn calculate_block_sets(&mut self, block: NodeIndex) -> BlockSets { - let mut gen = HashSet::new(); - let mut kill = HashSet::new(); - - let ops = self.program[block].ops.clone(); - - for op in ops.borrow_mut().values_mut().rev() { - // Reads must be tracked after writes - self.visit_out(&mut op.out, |opt, var| { - if let Some(id) = opt.local_variable_id(var) { - kill.insert(id); - gen.remove(&id); - } - }); - self.visit_operation(&mut op.operation, |opt, var| { - if let Some(id) = opt.local_variable_id(var) { - gen.insert(id); - } - }); - } - - BlockSets { gen, kill } - } -} diff --git a/crates/cubecl-opt/src/passes/mod.rs b/crates/cubecl-opt/src/passes/mod.rs index c52ad6565..197fd3c46 100644 --- a/crates/cubecl-opt/src/passes/mod.rs +++ b/crates/cubecl-opt/src/passes/mod.rs @@ -7,7 +7,6 @@ mod in_bounds_analysis; mod index_merge; mod inlined_if_to_select; mod integer_range_analysis; -mod liveness; mod reduce_strength; pub use array_copy_propagate::*; diff --git a/crates/cubecl-opt/src/phi_frontiers.rs b/crates/cubecl-opt/src/phi_frontiers.rs index 80b3feda0..ee8e30a9b 100644 --- a/crates/cubecl-opt/src/phi_frontiers.rs +++ b/crates/cubecl-opt/src/phi_frontiers.rs @@ -1,25 +1,21 @@ use cubecl_core::ir::{Item, Variable, VariableKind}; -use petgraph::{algo::dominators::simple_fast, graph::NodeIndex, visit::EdgeRef, Direction}; +use petgraph::{algo::dominators::simple_fast, graph::NodeIndex}; -use super::{ - version::{PhiEntry, PhiInstruction}, - Program, -}; +use crate::{analyses::liveness::Liveness, Optimizer}; -impl Program { +use super::version::{PhiEntry, PhiInstruction}; + +impl Optimizer { /// Find dominance frontiers for each block pub fn fill_dom_frontiers(&mut self) { - let doms = simple_fast(&self.graph, self.root); - for node in self.node_indices().collect::>() { - let predecessors: Vec<_> = self - .edges_directed(node, Direction::Incoming) - .map(|it| it.source()) - .collect(); + let doms = simple_fast(&self.program.graph, self.program.root); + for node in self.node_ids() { + let predecessors = self.predecessors(node); if predecessors.len() >= 2 { for predecessor in predecessors { let mut runner = predecessor; while runner != doms.immediate_dominator(node).unwrap() { - self[runner].dom_frontiers.insert(node); + self.program[runner].dom_frontiers.insert(node); runner = doms.immediate_dominator(runner).unwrap(); } } @@ -29,21 +25,24 @@ impl Program { /// Places a phi node for each live variable at each frontier pub fn place_phi_nodes(&mut self) { - let keys: Vec<_> = self.variables.keys().cloned().collect(); + let keys: Vec<_> = self.program.variables.keys().cloned().collect(); + let liveness = self.analysis::(); for var in keys { let mut workset: Vec<_> = self - .node_indices() - .filter(|index| self[*index].writes.contains(&var)) + .node_ids() + .iter() + .filter(|index| self.program[**index].writes.contains(&var)) + .copied() .collect(); let mut considered = workset.clone(); let mut already_inserted = Vec::new(); while let Some(node) = workset.pop() { - for frontier in self[node].dom_frontiers.clone() { - if already_inserted.contains(&frontier) || self.is_dead(frontier, var) { + for frontier in self.program[node].dom_frontiers.clone() { + if already_inserted.contains(&frontier) || liveness.is_dead(frontier, var) { continue; } - self.insert_phi(frontier, var, self.variables[&var]); + self.insert_phi(frontier, var, self.program.variables[&var]); already_inserted.push(frontier); if !considered.contains(&frontier) { workset.push(frontier); @@ -64,17 +63,14 @@ impl Program { }, item, ); - let entries = self - .edges_directed(block, Direction::Incoming) - .map(|edge| edge.source()) - .map(|pred| PhiEntry { - block: pred, - value: var, - }); + let entries = self.predecessors(block).into_iter().map(|pred| PhiEntry { + block: pred, + value: var, + }); let phi = PhiInstruction { out: var, entries: entries.collect(), }; - self[block].phi_nodes.borrow_mut().push(phi); + self.program[block].phi_nodes.borrow_mut().push(phi); } } From b67fe0c78fe932a2933dea83171bd12218b9855a Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Wed, 1 Jan 2025 17:34:15 +0100 Subject: [PATCH 04/17] Invalidate structure in places where it was missing before --- crates/cubecl-opt/src/analyses/base.rs | 6 ++++++ crates/cubecl-opt/src/control_flow.rs | 1 + crates/cubecl-opt/src/lib.rs | 2 ++ crates/cubecl-opt/src/passes/dead_code.rs | 4 +++- crates/cubecl-opt/src/passes/inlined_if_to_select.rs | 7 ++----- 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/crates/cubecl-opt/src/analyses/base.rs b/crates/cubecl-opt/src/analyses/base.rs index 731cfee48..c1777a6ea 100644 --- a/crates/cubecl-opt/src/analyses/base.rs +++ b/crates/cubecl-opt/src/analyses/base.rs @@ -4,6 +4,8 @@ use type_map::TypeMap; use crate::Optimizer; +use super::post_order::PostOrder; + pub trait Analysis { fn init(opt: &mut Optimizer) -> Self; } @@ -43,4 +45,8 @@ impl Optimizer { pub fn invalidate_analysis(&self) { self.analyses.invalidate::(); } + + pub fn invalidate_structure(&self) { + self.invalidate_analysis::(); + } } diff --git a/crates/cubecl-opt/src/control_flow.rs b/crates/cubecl-opt/src/control_flow.rs index c33f46827..cd081dc37 100644 --- a/crates/cubecl-opt/src/control_flow.rs +++ b/crates/cubecl-opt/src/control_flow.rs @@ -349,6 +349,7 @@ impl Optimizer { let new_block = self.program.add_node(BasicBlock::default()); self.program.add_edge(block, new_block, ()); self.program.add_edge(new_block, *successor, ()); + self.invalidate_structure(); update_phi(self, *successor, block, new_block); update_control_flow(self, block, *successor, new_block); } diff --git a/crates/cubecl-opt/src/lib.rs b/crates/cubecl-opt/src/lib.rs index 696f7a06f..bdfc0ba43 100644 --- a/crates/cubecl-opt/src/lib.rs +++ b/crates/cubecl-opt/src/lib.rs @@ -241,6 +241,7 @@ impl Optimizer { if let Some(current_block) = self.current_block { self.program.add_edge(current_block, self.ret, ()); } + self.invalidate_structure(); } fn apply_pre_ssa_passes(&mut self) { @@ -444,6 +445,7 @@ impl Optimizer { let new_ret = self.program.add_node(BasicBlock::default()); self.program.add_edge(new_ret, self.ret, ()); self.ret = new_ret; + self.invalidate_structure(); new_ret } else { self.ret diff --git a/crates/cubecl-opt/src/passes/dead_code.rs b/crates/cubecl-opt/src/passes/dead_code.rs index cb23824e4..40d0f1da0 100644 --- a/crates/cubecl-opt/src/passes/dead_code.rs +++ b/crates/cubecl-opt/src/passes/dead_code.rs @@ -96,6 +96,7 @@ impl OptimizerPass for EliminateConstBranches { } *control_flow.borrow_mut() = ControlFlow::None; + opt.invalidate_structure(); changes.inc(); } ControlFlow::Switch { @@ -119,6 +120,7 @@ impl OptimizerPass for EliminateConstBranches { opt.program.remove_edge(edge); } *control_flow.borrow_mut() = ControlFlow::None; + opt.invalidate_structure(); changes.inc(); } _ => {} @@ -226,7 +228,7 @@ fn merge_blocks(opt: &mut Optimizer) -> bool { } *opt.program.node_weight_mut(block_idx).unwrap() = new_block; opt.program.remove_node(successors[0]); - opt.invalidate_analysis::(); + opt.invalidate_structure(); opt.invalidate_analysis::(); update_references(opt, successors[0], block_idx); return true; diff --git a/crates/cubecl-opt/src/passes/inlined_if_to_select.rs b/crates/cubecl-opt/src/passes/inlined_if_to_select.rs index 5e6eb0251..6e6e47210 100644 --- a/crates/cubecl-opt/src/passes/inlined_if_to_select.rs +++ b/crates/cubecl-opt/src/passes/inlined_if_to_select.rs @@ -3,10 +3,7 @@ use std::mem::take; use cubecl_core::ir::{Instruction, Operator, Select}; use petgraph::{graph::NodeIndex, visit::EdgeRef}; -use crate::{ - analyses::post_order::PostOrder, passes::update_references, AtomicCounter, ControlFlow, - Optimizer, -}; +use crate::{passes::update_references, AtomicCounter, ControlFlow, Optimizer}; use super::OptimizerPass; @@ -82,11 +79,11 @@ impl OptimizerPass for EmptyBranchToSelect { opt.program.remove_node(then); opt.program.remove_node(or_else); opt.program.remove_node(merge); - opt.invalidate_analysis::(); for merge_successor in merge_successors { opt.program.add_edge(block, merge_successor, ()); } *opt.program[block].control_flow.borrow_mut() = merge_control; + opt.invalidate_structure(); update_references(opt, merge, block); return true; } From f5f8a811d6200344bd11ed80d420a6c31dae495f Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Wed, 1 Jan 2025 17:55:17 +0100 Subject: [PATCH 05/17] Migrate dominators --- crates/cubecl-opt/src/analyses/base.rs | 9 ++++- crates/cubecl-opt/src/analyses/dominators.rs | 39 ++++++++++++++++++++ crates/cubecl-opt/src/analyses/mod.rs | 1 + crates/cubecl-opt/src/gvn/analysis.rs | 34 ++++++++--------- crates/cubecl-opt/src/gvn/apply.rs | 7 ++-- crates/cubecl-opt/src/gvn/base.rs | 15 +------- crates/cubecl-opt/src/phi_frontiers.rs | 9 +++-- 7 files changed, 74 insertions(+), 40 deletions(-) create mode 100644 crates/cubecl-opt/src/analyses/dominators.rs diff --git a/crates/cubecl-opt/src/analyses/base.rs b/crates/cubecl-opt/src/analyses/base.rs index c1777a6ea..064540581 100644 --- a/crates/cubecl-opt/src/analyses/base.rs +++ b/crates/cubecl-opt/src/analyses/base.rs @@ -4,7 +4,11 @@ use type_map::TypeMap; use crate::Optimizer; -use super::post_order::PostOrder; +use super::{ + dominators::{Dominators, PostDominators}, + liveness::Liveness, + post_order::PostOrder, +}; pub trait Analysis { fn init(opt: &mut Optimizer) -> Self; @@ -48,5 +52,8 @@ impl Optimizer { pub fn invalidate_structure(&self) { self.invalidate_analysis::(); + self.invalidate_analysis::(); + self.invalidate_analysis::(); + self.invalidate_analysis::(); } } diff --git a/crates/cubecl-opt/src/analyses/dominators.rs b/crates/cubecl-opt/src/analyses/dominators.rs new file mode 100644 index 000000000..b2d43618f --- /dev/null +++ b/crates/cubecl-opt/src/analyses/dominators.rs @@ -0,0 +1,39 @@ +use std::ops::Deref; + +use crate::NodeIndex; +use petgraph::algo::dominators; + +use super::Analysis; + +pub struct Dominators(dominators::Dominators); +pub struct PostDominators(dominators::Dominators); + +impl Deref for Dominators { + type Target = dominators::Dominators; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Deref for PostDominators { + type Target = dominators::Dominators; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Analysis for Dominators { + fn init(opt: &mut crate::Optimizer) -> Self { + Dominators(dominators::simple_fast(&opt.program.graph, opt.entry())) + } +} + +impl Analysis for PostDominators { + fn init(opt: &mut crate::Optimizer) -> Self { + let mut reversed = opt.program.graph.clone(); + reversed.reverse(); + PostDominators(dominators::simple_fast(&reversed, opt.ret)) + } +} diff --git a/crates/cubecl-opt/src/analyses/mod.rs b/crates/cubecl-opt/src/analyses/mod.rs index 75d60ba6a..e5c216ae1 100644 --- a/crates/cubecl-opt/src/analyses/mod.rs +++ b/crates/cubecl-opt/src/analyses/mod.rs @@ -1,4 +1,5 @@ mod base; +pub mod dominators; pub mod liveness; pub mod post_order; diff --git a/crates/cubecl-opt/src/gvn/analysis.rs b/crates/cubecl-opt/src/gvn/analysis.rs index 4e54f64ed..9f0684bc0 100644 --- a/crates/cubecl-opt/src/gvn/analysis.rs +++ b/crates/cubecl-opt/src/gvn/analysis.rs @@ -1,13 +1,12 @@ use std::collections::{HashMap, HashSet, LinkedList}; -use petgraph::{ - algo::dominators::{self}, - graph::NodeIndex, - Graph, -}; +use crate::NodeIndex; use smallvec::SmallVec; -use crate::{BasicBlock, Optimizer}; +use crate::{ + analyses::dominators::{Dominators, PostDominators}, + Optimizer, +}; use super::{convert::value_of_var, Expression, GvnPass, Value, ValueTable}; @@ -36,13 +35,9 @@ pub struct BlockSets { /// (which is required for the `Compiler` implementation) impl Default for GvnPass { fn default() -> Self { - let mut dummy = Graph::::new(); - let root = dummy.add_node(BasicBlock::default()); Self { values: Default::default(), block_sets: Default::default(), - dominators: dominators::simple_fast(&dummy, root), - post_doms: dominators::simple_fast(&dummy, root), } } } @@ -52,12 +47,10 @@ impl GvnPass { /// 1. Forward DFA that generates the available expressions, values and leaders for each block /// 2. Backward fixed-point DFA that generates the anticipated expressions/antileaders for each /// block - pub fn build_sets(&mut self, opt: &Optimizer) { - self.build_block_sets_fwd(opt, self.dominators.root(), HashMap::new()); + pub fn build_sets(&mut self, opt: &mut Optimizer) { + self.build_block_sets_fwd(opt, opt.entry(), HashMap::new()); let mut build_passes = 0; - while self.build_block_sets_bckwd(opt, self.post_doms.root()) - && build_passes < MAX_SET_PASSES - { + while self.build_block_sets_bckwd(opt, opt.ret) && build_passes < MAX_SET_PASSES { build_passes += 1; } @@ -86,7 +79,7 @@ impl GvnPass { /// variables that represent them are also available there. fn build_block_sets_fwd( &mut self, - opt: &Optimizer, + opt: &mut Optimizer, block: NodeIndex, mut leaders: HashMap, ) { @@ -100,6 +93,8 @@ impl GvnPass { // Values already added in this block. Used to deduplicate locally. let mut added_exprs = HashSet::new(); + let dominators = opt.analysis::(); + // Number phi outputs and add the out var as a leader for that value for phi in opt.program[block].phi_nodes.borrow().iter() { let (num, val) = self.values.lookup_or_add_phi(phi); @@ -135,7 +130,7 @@ impl GvnPass { antic_in: Default::default(), }; self.block_sets.insert(block, sets); - let successors: Vec<_> = self.dominators.immediately_dominated_by(block).collect(); + let successors: Vec<_> = dominators.immediately_dominated_by(block).collect(); for successor in successors { // Work around dominator bug if successor != block { @@ -147,8 +142,9 @@ impl GvnPass { /// Do a fixed point data backward flow analysis to find expected expressions at any given /// program point. Iterates through the post-dominator tree because it's the fastest way to /// converge. - fn build_block_sets_bckwd(&mut self, opt: &Optimizer, current: NodeIndex) -> bool { + fn build_block_sets_bckwd(&mut self, opt: &mut Optimizer, current: NodeIndex) -> bool { let mut changed = false; + let post_doms = opt.analysis::(); let successors = opt.successors(current); // Since we have no critical edges, if successors > 1 then they must have only one entry, @@ -224,7 +220,7 @@ impl GvnPass { } self.block_sets.get_mut(¤t).unwrap().antic_in = result; - let predecessors: Vec<_> = self.post_doms.immediately_dominated_by(current).collect(); + let predecessors: Vec<_> = post_doms.immediately_dominated_by(current).collect(); for predecessor in predecessors { // Work around dominator bug if predecessor != current { diff --git a/crates/cubecl-opt/src/gvn/apply.rs b/crates/cubecl-opt/src/gvn/apply.rs index bfccf1658..6ff99dd19 100644 --- a/crates/cubecl-opt/src/gvn/apply.rs +++ b/crates/cubecl-opt/src/gvn/apply.rs @@ -4,6 +4,7 @@ use cubecl_core::ir::{self, Operation}; use petgraph::graph::NodeIndex; use crate::{ + analyses::dominators::Dominators, gvn::{convert::value_of_var, phi_translate}, version::PhiEntry, AtomicCounter, Optimizer, PhiInstruction, @@ -21,7 +22,7 @@ impl GvnPass { let mut new_expr = HashMap::new(); - while self.insert_block(opt, self.dominators.root(), &mut new_expr, changes) { + while self.insert_block(opt, opt.entry(), &mut new_expr, changes) { loops += 1; } let inserted = changes.get() - changes_pre; @@ -37,6 +38,7 @@ impl GvnPass { changes: &AtomicCounter, ) -> bool { let mut changed = false; + let dominators = opt.analysis::(); let predecessors = opt.predecessors(current); if predecessors.len() > 1 { @@ -120,8 +122,7 @@ impl GvnPass { opt.program[current].phi_nodes.borrow_mut().extend(new_phis); } - let children = self - .dominators + let children = dominators .immediately_dominated_by(current) .collect::>(); for child in children { diff --git a/crates/cubecl-opt/src/gvn/base.rs b/crates/cubecl-opt/src/gvn/base.rs index 02fbfe94a..2ced26bdb 100644 --- a/crates/cubecl-opt/src/gvn/base.rs +++ b/crates/cubecl-opt/src/gvn/base.rs @@ -5,10 +5,7 @@ use cubecl_core::{ prelude::CubePrimitive, }; use float_ord::FloatOrd; -use petgraph::{ - algo::dominators::{self, Dominators}, - graph::NodeIndex, -}; +use petgraph::graph::NodeIndex; use smallvec::SmallVec; use crate::{passes::OptimizerPass, AtomicCounter, Optimizer, PhiInstruction}; @@ -25,8 +22,6 @@ impl OptimizerPass for GvnPass { pub struct GvnPass { pub values: ValueTable, pub block_sets: HashMap, - pub dominators: Dominators, - pub post_doms: Dominators, } impl GvnPass { @@ -38,18 +33,10 @@ impl GvnPass { /// 4. Replace fully redundant expressions with simple assignments from the leader of that /// expression to `out` pub fn run(&mut self, opt: &mut Optimizer, changes: &AtomicCounter) { - self.build_dominators(opt); self.build_sets(opt); self.insert(opt, changes); self.eliminate(opt, changes); } - - fn build_dominators(&mut self, opt: &Optimizer) { - self.dominators = dominators::simple_fast(&opt.program.graph, opt.entry()); - let mut rev_graph = opt.program.graph.clone(); - rev_graph.reverse(); - self.post_doms = dominators::simple_fast(&rev_graph, opt.ret); - } } /// A global value table that maps expressions and locals to the values they represent. diff --git a/crates/cubecl-opt/src/phi_frontiers.rs b/crates/cubecl-opt/src/phi_frontiers.rs index ee8e30a9b..f03e14d5a 100644 --- a/crates/cubecl-opt/src/phi_frontiers.rs +++ b/crates/cubecl-opt/src/phi_frontiers.rs @@ -1,14 +1,17 @@ use cubecl_core::ir::{Item, Variable, VariableKind}; -use petgraph::{algo::dominators::simple_fast, graph::NodeIndex}; +use petgraph::graph::NodeIndex; -use crate::{analyses::liveness::Liveness, Optimizer}; +use crate::{ + analyses::{dominators::Dominators, liveness::Liveness}, + Optimizer, +}; use super::version::{PhiEntry, PhiInstruction}; impl Optimizer { /// Find dominance frontiers for each block pub fn fill_dom_frontiers(&mut self) { - let doms = simple_fast(&self.program.graph, self.program.root); + let doms = self.analysis::(); for node in self.node_ids() { let predecessors = self.predecessors(node); if predecessors.len() >= 2 { From decd91e11584cf399d800a1a63c4ed754e87f3b1 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Wed, 1 Jan 2025 18:19:44 +0100 Subject: [PATCH 06/17] Migrate dominance frontiers --- .../src/analyses/dominance_frontiers.rs | 49 +++++++++++++++++++ crates/cubecl-opt/src/analyses/mod.rs | 2 + crates/cubecl-opt/src/analyses/writes.rs | 0 crates/cubecl-opt/src/block.rs | 3 -- crates/cubecl-opt/src/lib.rs | 5 +- crates/cubecl-opt/src/phi_frontiers.rs | 23 ++------- 6 files changed, 57 insertions(+), 25 deletions(-) create mode 100644 crates/cubecl-opt/src/analyses/dominance_frontiers.rs create mode 100644 crates/cubecl-opt/src/analyses/writes.rs diff --git a/crates/cubecl-opt/src/analyses/dominance_frontiers.rs b/crates/cubecl-opt/src/analyses/dominance_frontiers.rs new file mode 100644 index 000000000..726ff70f0 --- /dev/null +++ b/crates/cubecl-opt/src/analyses/dominance_frontiers.rs @@ -0,0 +1,49 @@ +use crate::{NodeIndex, Optimizer}; +use std::{ + collections::{HashMap, HashSet}, + ops::Deref, +}; + +use super::{dominators::Dominators, Analysis}; + +pub struct DomFrontiers { + /// The dominance frontiers of each block (where phi nodes must be inserted). + dom_frontiers: HashMap>, +} + +impl Deref for DomFrontiers { + type Target = HashMap>; + + fn deref(&self) -> &Self::Target { + &self.dom_frontiers + } +} + +impl DomFrontiers { + /// Find dominance frontiers for each block + pub fn new(opt: &mut Optimizer) -> Self { + let doms = opt.analysis::(); + let nodes = opt.node_ids().into_iter().map(|it| (it, HashSet::new())); + let mut dom_frontiers: HashMap> = nodes.collect(); + + for node in opt.node_ids() { + let predecessors = opt.predecessors(node); + if predecessors.len() >= 2 { + for predecessor in predecessors { + let mut runner = predecessor; + while runner != doms.immediate_dominator(node).unwrap() { + dom_frontiers.get_mut(&runner).unwrap().insert(node); + runner = doms.immediate_dominator(runner).unwrap(); + } + } + } + } + Self { dom_frontiers } + } +} + +impl Analysis for DomFrontiers { + fn init(opt: &mut Optimizer) -> Self { + DomFrontiers::new(opt) + } +} diff --git a/crates/cubecl-opt/src/analyses/mod.rs b/crates/cubecl-opt/src/analyses/mod.rs index e5c216ae1..63ebdf924 100644 --- a/crates/cubecl-opt/src/analyses/mod.rs +++ b/crates/cubecl-opt/src/analyses/mod.rs @@ -1,6 +1,8 @@ mod base; +pub mod dominance_frontiers; pub mod dominators; pub mod liveness; pub mod post_order; +pub mod writes; pub use base::*; diff --git a/crates/cubecl-opt/src/analyses/writes.rs b/crates/cubecl-opt/src/analyses/writes.rs new file mode 100644 index 000000000..e69de29bb diff --git a/crates/cubecl-opt/src/block.rs b/crates/cubecl-opt/src/block.rs index d220ea64e..c7f4b34f8 100644 --- a/crates/cubecl-opt/src/block.rs +++ b/crates/cubecl-opt/src/block.rs @@ -1,7 +1,6 @@ use std::{cell::RefCell, collections::HashSet, rc::Rc}; use cubecl_core::ir::{Instruction, Variable}; -use petgraph::graph::NodeIndex; use stable_vec::StableVec; use crate::{version::PhiInstruction, ControlFlow, Optimizer}; @@ -21,8 +20,6 @@ pub struct BasicBlock { pub phi_nodes: Rc>>, /// The variables written to by this block. Only set during the SSA transformation. pub(crate) writes: HashSet<(u16, u8)>, - /// The dominance frontiers of this block (where phi nodes must be inserted). - pub(crate) dom_frontiers: HashSet, /// A stable list of operations performed in this block. pub ops: Rc>>, /// The control flow that terminates this block. diff --git a/crates/cubecl-opt/src/lib.rs b/crates/cubecl-opt/src/lib.rs index bdfc0ba43..aa8492598 100644 --- a/crates/cubecl-opt/src/lib.rs +++ b/crates/cubecl-opt/src/lib.rs @@ -31,7 +31,7 @@ use std::{ sync::atomic::{AtomicUsize, Ordering}, }; -use analyses::{liveness::Liveness, Analyses}; +use analyses::{dominance_frontiers::DomFrontiers, liveness::Liveness, Analyses}; use cubecl_core::{ ir::{self as core, Branch, Operation, Operator, Variable, VariableKind}, CubeDim, @@ -317,13 +317,12 @@ impl Optimizer { } fn ssa_transform(&mut self) { - self.fill_dom_frontiers(); self.place_phi_nodes(); self.version_program(); self.program.variables.clear(); for block in self.node_ids() { self.program[block].writes.clear(); - self.program[block].dom_frontiers.clear(); + self.invalidate_analysis::(); } } diff --git a/crates/cubecl-opt/src/phi_frontiers.rs b/crates/cubecl-opt/src/phi_frontiers.rs index f03e14d5a..fb83e3b9a 100644 --- a/crates/cubecl-opt/src/phi_frontiers.rs +++ b/crates/cubecl-opt/src/phi_frontiers.rs @@ -2,34 +2,19 @@ use cubecl_core::ir::{Item, Variable, VariableKind}; use petgraph::graph::NodeIndex; use crate::{ - analyses::{dominators::Dominators, liveness::Liveness}, + analyses::{dominance_frontiers::DomFrontiers, liveness::Liveness}, Optimizer, }; use super::version::{PhiEntry, PhiInstruction}; impl Optimizer { - /// Find dominance frontiers for each block - pub fn fill_dom_frontiers(&mut self) { - let doms = self.analysis::(); - for node in self.node_ids() { - let predecessors = self.predecessors(node); - if predecessors.len() >= 2 { - for predecessor in predecessors { - let mut runner = predecessor; - while runner != doms.immediate_dominator(node).unwrap() { - self.program[runner].dom_frontiers.insert(node); - runner = doms.immediate_dominator(runner).unwrap(); - } - } - } - } - } - /// Places a phi node for each live variable at each frontier pub fn place_phi_nodes(&mut self) { let keys: Vec<_> = self.program.variables.keys().cloned().collect(); let liveness = self.analysis::(); + let dom_frontiers = self.analysis::(); + for var in keys { let mut workset: Vec<_> = self .node_ids() @@ -41,7 +26,7 @@ impl Optimizer { let mut already_inserted = Vec::new(); while let Some(node) = workset.pop() { - for frontier in self.program[node].dom_frontiers.clone() { + for frontier in dom_frontiers[&node].clone() { if already_inserted.contains(&frontier) || liveness.is_dead(frontier, var) { continue; } From ddf0f292e5ba0822d8c093592063a1e454d6c4be Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Wed, 1 Jan 2025 19:02:29 +0100 Subject: [PATCH 07/17] Migrate GVN --- crates/cubecl-opt/src/analyses/base.rs | 3 ++ crates/cubecl-opt/src/analyses/writes.rs | 43 +++++++++++++++++++ crates/cubecl-opt/src/block.rs | 4 +- crates/cubecl-opt/src/control_flow.rs | 3 +- crates/cubecl-opt/src/debug.rs | 4 +- crates/cubecl-opt/src/gvn/analysis.rs | 38 +++++++++------- crates/cubecl-opt/src/gvn/apply.rs | 4 +- crates/cubecl-opt/src/gvn/base.rs | 18 ++++---- crates/cubecl-opt/src/instructions.rs | 6 --- crates/cubecl-opt/src/lib.rs | 16 ++----- .../src/passes/array_copy_propagate.rs | 5 +-- crates/cubecl-opt/src/phi_frontiers.rs | 5 ++- 12 files changed, 92 insertions(+), 57 deletions(-) diff --git a/crates/cubecl-opt/src/analyses/base.rs b/crates/cubecl-opt/src/analyses/base.rs index 064540581..7d7b39da5 100644 --- a/crates/cubecl-opt/src/analyses/base.rs +++ b/crates/cubecl-opt/src/analyses/base.rs @@ -10,7 +10,10 @@ use super::{ post_order::PostOrder, }; +/// An analysis used by optimization passes. Unlike optimization passes, analyses can have state +/// and persist until they're invalidated. pub trait Analysis { + /// Perform the analysis for the current optimizer state and return the persistent analysis state fn init(opt: &mut Optimizer) -> Self; } diff --git a/crates/cubecl-opt/src/analyses/writes.rs b/crates/cubecl-opt/src/analyses/writes.rs index e69de29bb..a35ed07f6 100644 --- a/crates/cubecl-opt/src/analyses/writes.rs +++ b/crates/cubecl-opt/src/analyses/writes.rs @@ -0,0 +1,43 @@ +use std::{ + collections::{HashMap, HashSet}, + ops::Deref, +}; + +use crate::{NodeIndex, Optimizer}; + +use super::Analysis; + +pub struct Writes { + /// The variables written to by each block. + writes: HashMap>, +} + +impl Deref for Writes { + type Target = HashMap>; + + fn deref(&self) -> &Self::Target { + &self.writes + } +} + +impl Writes { + pub fn new(opt: &mut Optimizer) -> Self { + let nodes = opt.node_ids().into_iter().map(|it| (it, HashSet::new())); + let mut writes: HashMap> = nodes.collect(); + for block in opt.node_ids() { + let ops = opt.program[block].ops.clone(); + for inst in ops.borrow().values() { + if let Some(id) = inst.out.as_ref().and_then(|it| opt.local_variable_id(it)) { + writes.get_mut(&block).unwrap().insert(id); + } + } + } + Writes { writes } + } +} + +impl Analysis for Writes { + fn init(opt: &mut crate::Optimizer) -> Self { + Writes::new(opt) + } +} diff --git a/crates/cubecl-opt/src/block.rs b/crates/cubecl-opt/src/block.rs index c7f4b34f8..3766fdaa7 100644 --- a/crates/cubecl-opt/src/block.rs +++ b/crates/cubecl-opt/src/block.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, collections::HashSet, rc::Rc}; +use std::{cell::RefCell, rc::Rc}; use cubecl_core::ir::{Instruction, Variable}; use stable_vec::StableVec; @@ -18,8 +18,6 @@ pub struct BasicBlock { pub(crate) block_use: Vec, /// The phi nodes that are required to be generated at the start of this block. pub phi_nodes: Rc>>, - /// The variables written to by this block. Only set during the SSA transformation. - pub(crate) writes: HashSet<(u16, u8)>, /// A stable list of operations performed in this block. pub ops: Rc>>, /// The control flow that terminates this block. diff --git a/crates/cubecl-opt/src/control_flow.rs b/crates/cubecl-opt/src/control_flow.rs index cd081dc37..c00715d70 100644 --- a/crates/cubecl-opt/src/control_flow.rs +++ b/crates/cubecl-opt/src/control_flow.rs @@ -259,8 +259,7 @@ impl Optimizer { let i = range_loop.i; self.program.variables.insert(i_id, i.item); - let mut assign = Instruction::new(Operation::Copy(range_loop.start), i); - self.visit_out(&mut assign.out, |opt, var| opt.write_var(var)); + let assign = Instruction::new(Operation::Copy(range_loop.start), i); self.current_block_mut().ops.borrow_mut().push(assign); let current_block = self.current_block.unwrap(); diff --git a/crates/cubecl-opt/src/debug.rs b/crates/cubecl-opt/src/debug.rs index 8ae95936d..aebd47cf5 100644 --- a/crates/cubecl-opt/src/debug.rs +++ b/crates/cubecl-opt/src/debug.rs @@ -5,7 +5,7 @@ use petgraph::visit::EdgeRef; use crate::{ analyses::liveness::Liveness, - gvn::{BlockSets, Constant, Expression, Instruction, Local, OpId, Value, ValueTable}, + gvn::{BlockSets, Constant, Expression, GvnState, Instruction, Local, OpId, Value, ValueTable}, passes::var_id, ControlFlow, }; @@ -31,7 +31,7 @@ impl Display for Optimizer { } f.write_str("\n\n")?; - let global_nums = self.gvn.borrow(); + let global_nums = self.analyses.try_get::().unwrap_or_default(); let liveness = self .analyses .try_get::() diff --git a/crates/cubecl-opt/src/gvn/analysis.rs b/crates/cubecl-opt/src/gvn/analysis.rs index 9f0684bc0..9af5cd24e 100644 --- a/crates/cubecl-opt/src/gvn/analysis.rs +++ b/crates/cubecl-opt/src/gvn/analysis.rs @@ -1,6 +1,9 @@ -use std::collections::{HashMap, HashSet, LinkedList}; +use std::{ + cell::RefCell, + collections::{HashMap, HashSet, LinkedList}, +}; -use crate::NodeIndex; +use crate::{analyses::Analysis, NodeIndex}; use smallvec::SmallVec; use crate::{ @@ -8,10 +11,26 @@ use crate::{ Optimizer, }; -use super::{convert::value_of_var, Expression, GvnPass, Value, ValueTable}; +use super::{convert::value_of_var, Expression, Value, ValueTable}; const MAX_SET_PASSES: usize = 10; +pub struct GlobalValues(pub RefCell); + +#[derive(Debug, Clone, Default)] +pub struct GvnState { + pub values: ValueTable, + pub block_sets: HashMap, +} + +impl Analysis for GlobalValues { + fn init(opt: &mut Optimizer) -> Self { + let mut this = GvnState::default(); + this.build_sets(opt); + GlobalValues(RefCell::new(this)) + } +} + /// The set annotations for a given block #[derive(Debug, Clone, Default)] pub struct BlockSets { @@ -31,18 +50,7 @@ pub struct BlockSets { pub antic_in: LinkedList<(u32, Expression)>, } -/// Needed for Optimizer `Default`, which is required for SpirvCompiler to implement `Default` -/// (which is required for the `Compiler` implementation) -impl Default for GvnPass { - fn default() -> Self { - Self { - values: Default::default(), - block_sets: Default::default(), - } - } -} - -impl GvnPass { +impl GvnState { /// Build set annotations for each block. Executes two steps: /// 1. Forward DFA that generates the available expressions, values and leaders for each block /// 2. Backward fixed-point DFA that generates the anticipated expressions/antileaders for each diff --git a/crates/cubecl-opt/src/gvn/apply.rs b/crates/cubecl-opt/src/gvn/apply.rs index 6ff99dd19..a6c2baf3a 100644 --- a/crates/cubecl-opt/src/gvn/apply.rs +++ b/crates/cubecl-opt/src/gvn/apply.rs @@ -10,9 +10,9 @@ use crate::{ AtomicCounter, Optimizer, PhiInstruction, }; -use super::GvnPass; +use super::GvnState; -impl GvnPass { +impl GvnState { /// Find places where an expression is partially but not fully available, and hoist the /// computation into the blocks that do not currently have the value available to make the /// expression fully redundant diff --git a/crates/cubecl-opt/src/gvn/base.rs b/crates/cubecl-opt/src/gvn/base.rs index 2ced26bdb..b026913e7 100644 --- a/crates/cubecl-opt/src/gvn/base.rs +++ b/crates/cubecl-opt/src/gvn/base.rs @@ -10,7 +10,10 @@ use smallvec::SmallVec; use crate::{passes::OptimizerPass, AtomicCounter, Optimizer, PhiInstruction}; -use super::{convert::value_of_var, BlockSets}; +use super::{convert::value_of_var, GlobalValues}; + +#[derive(Debug, Clone, Default)] +pub struct GvnPass; impl OptimizerPass for GvnPass { fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { @@ -18,12 +21,6 @@ impl OptimizerPass for GvnPass { } } -#[derive(Debug, Clone)] -pub struct GvnPass { - pub values: ValueTable, - pub block_sets: HashMap, -} - impl GvnPass { /// Run the GVN-PRE algorithm /// 1. Build forward and backward dominator trees @@ -33,9 +30,10 @@ impl GvnPass { /// 4. Replace fully redundant expressions with simple assignments from the leader of that /// expression to `out` pub fn run(&mut self, opt: &mut Optimizer, changes: &AtomicCounter) { - self.build_sets(opt); - self.insert(opt, changes); - self.eliminate(opt, changes); + let analysis = opt.analysis::(); + + analysis.0.borrow_mut().insert(opt, changes); + analysis.0.borrow_mut().eliminate(opt, changes); } } diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index 5169405cc..1bb1a1b5e 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -257,10 +257,4 @@ impl Optimizer { visit_read(self, &mut binop.lhs); visit_read(self, &mut binop.rhs); } - - pub fn write_var(&mut self, var: &mut Variable) { - if let Some(id) = self.local_variable_id(var) { - self.current_block_mut().writes.insert(id); - } - } } diff --git a/crates/cubecl-opt/src/lib.rs b/crates/cubecl-opt/src/lib.rs index aa8492598..b8b767fd5 100644 --- a/crates/cubecl-opt/src/lib.rs +++ b/crates/cubecl-opt/src/lib.rs @@ -24,14 +24,13 @@ //! use std::{ - cell::RefCell, collections::{HashMap, VecDeque}, ops::{Deref, DerefMut}, rc::Rc, sync::atomic::{AtomicUsize, Ordering}, }; -use analyses::{dominance_frontiers::DomFrontiers, liveness::Liveness, Analyses}; +use analyses::{dominance_frontiers::DomFrontiers, liveness::Liveness, writes::Writes, Analyses}; use cubecl_core::{ ir::{self as core, Branch, Operation, Operator, Variable, VariableKind}, CubeDim, @@ -157,7 +156,6 @@ pub struct Optimizer { pub(crate) cube_dim: CubeDim, /// The execution mode, `Unchecked` skips bounds check optimizations. pub(crate) mode: ExecutionMode, - pub(crate) gvn: Rc>, analyses: Rc, } @@ -172,7 +170,6 @@ impl Default for Optimizer { root_scope: Scope::root(), cube_dim: Default::default(), mode: Default::default(), - gvn: Default::default(), analyses: Default::default(), } } @@ -214,8 +211,7 @@ impl Optimizer { } let gvn_count = AtomicCounter::new(0); - let gvn = self.gvn.clone(); - gvn.borrow_mut().apply_post_ssa(self, gvn_count.clone()); + GvnPass.apply_post_ssa(self, gvn_count.clone()); ReduceStrength.apply_post_ssa(self, gvn_count.clone()); CopyTransform.apply_post_ssa(self, gvn_count.clone()); @@ -320,10 +316,8 @@ impl Optimizer { self.place_phi_nodes(); self.version_program(); self.program.variables.clear(); - for block in self.node_ids() { - self.program[block].writes.clear(); - self.invalidate_analysis::(); - } + self.invalidate_analysis::(); + self.invalidate_analysis::(); } /// Mutable reference to the current basic block @@ -403,11 +397,9 @@ impl Optimizer { const_len, }, ); - self.visit_out(&mut instruction.out, |opt, var| opt.write_var(var)); self.current_block_mut().ops.borrow_mut().push(instruction); } _ => { - self.visit_out(&mut instruction.out, |opt, var| opt.write_var(var)); self.current_block_mut().ops.borrow_mut().push(instruction); } } diff --git a/crates/cubecl-opt/src/passes/array_copy_propagate.rs b/crates/cubecl-opt/src/passes/array_copy_propagate.rs index a1e98ed55..8830faa90 100644 --- a/crates/cubecl-opt/src/passes/array_copy_propagate.rs +++ b/crates/cubecl-opt/src/passes/array_copy_propagate.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use cubecl_core::ir::{Instruction, Item, Operation, Operator, Variable, VariableKind}; -use crate::{AtomicCounter, Optimizer}; +use crate::{analyses::writes::Writes, AtomicCounter, Optimizer}; use super::OptimizerPass; @@ -113,9 +113,8 @@ fn replace_const_arrays(opt: &mut Optimizer, arr_id: (u16, u8), vars: &[Variable if (id, depth) == arr_id { let const_index = assign.lhs.as_const().unwrap().as_i64() as usize; let out = vars[const_index]; - let out_id = opt.local_variable_id(&out).unwrap(); *op = Instruction::new(Operation::Copy(assign.rhs), out); - opt.program[block].writes.insert(out_id); + opt.invalidate_analysis::(); } } } diff --git a/crates/cubecl-opt/src/phi_frontiers.rs b/crates/cubecl-opt/src/phi_frontiers.rs index fb83e3b9a..dbd2081b7 100644 --- a/crates/cubecl-opt/src/phi_frontiers.rs +++ b/crates/cubecl-opt/src/phi_frontiers.rs @@ -2,7 +2,7 @@ use cubecl_core::ir::{Item, Variable, VariableKind}; use petgraph::graph::NodeIndex; use crate::{ - analyses::{dominance_frontiers::DomFrontiers, liveness::Liveness}, + analyses::{dominance_frontiers::DomFrontiers, liveness::Liveness, writes::Writes}, Optimizer, }; @@ -12,6 +12,7 @@ impl Optimizer { /// Places a phi node for each live variable at each frontier pub fn place_phi_nodes(&mut self) { let keys: Vec<_> = self.program.variables.keys().cloned().collect(); + let writes = self.analysis::(); let liveness = self.analysis::(); let dom_frontiers = self.analysis::(); @@ -19,7 +20,7 @@ impl Optimizer { let mut workset: Vec<_> = self .node_ids() .iter() - .filter(|index| self.program[**index].writes.contains(&var)) + .filter(|index| writes[*index].contains(&var)) .copied() .collect(); let mut considered = workset.clone(); From 8decc52b9aba0522fdac10c91515ec4a42a8cf44 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Wed, 1 Jan 2025 19:08:37 +0100 Subject: [PATCH 08/17] Cleanup --- crates/cubecl-opt/src/analyses/base.rs | 2 +- .../{dominance_frontiers.rs => dominance.rs} | 42 ++++++++++++++++++- crates/cubecl-opt/src/analyses/dominators.rs | 39 ----------------- crates/cubecl-opt/src/analyses/mod.rs | 3 +- crates/cubecl-opt/src/gvn/analysis.rs | 2 +- crates/cubecl-opt/src/gvn/apply.rs | 2 +- crates/cubecl-opt/src/lib.rs | 6 +-- crates/cubecl-opt/src/phi_frontiers.rs | 2 +- 8 files changed, 48 insertions(+), 50 deletions(-) rename crates/cubecl-opt/src/analyses/{dominance_frontiers.rs => dominance.rs} (59%) delete mode 100644 crates/cubecl-opt/src/analyses/dominators.rs diff --git a/crates/cubecl-opt/src/analyses/base.rs b/crates/cubecl-opt/src/analyses/base.rs index 7d7b39da5..93a05e0d1 100644 --- a/crates/cubecl-opt/src/analyses/base.rs +++ b/crates/cubecl-opt/src/analyses/base.rs @@ -5,7 +5,7 @@ use type_map::TypeMap; use crate::Optimizer; use super::{ - dominators::{Dominators, PostDominators}, + dominance::{Dominators, PostDominators}, liveness::Liveness, post_order::PostOrder, }; diff --git a/crates/cubecl-opt/src/analyses/dominance_frontiers.rs b/crates/cubecl-opt/src/analyses/dominance.rs similarity index 59% rename from crates/cubecl-opt/src/analyses/dominance_frontiers.rs rename to crates/cubecl-opt/src/analyses/dominance.rs index 726ff70f0..b58d05a4a 100644 --- a/crates/cubecl-opt/src/analyses/dominance_frontiers.rs +++ b/crates/cubecl-opt/src/analyses/dominance.rs @@ -1,11 +1,49 @@ -use crate::{NodeIndex, Optimizer}; use std::{ collections::{HashMap, HashSet}, ops::Deref, }; -use super::{dominators::Dominators, Analysis}; +use crate::{NodeIndex, Optimizer}; +use petgraph::algo::dominators; + +use super::Analysis; + +/// Dominator tree for the program graph +pub struct Dominators(dominators::Dominators); +/// Post dominator tree for the program graph +pub struct PostDominators(dominators::Dominators); + +impl Deref for Dominators { + type Target = dominators::Dominators; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Deref for PostDominators { + type Target = dominators::Dominators; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Analysis for Dominators { + fn init(opt: &mut crate::Optimizer) -> Self { + Dominators(dominators::simple_fast(&opt.program.graph, opt.entry())) + } +} + +impl Analysis for PostDominators { + fn init(opt: &mut crate::Optimizer) -> Self { + let mut reversed = opt.program.graph.clone(); + reversed.reverse(); + PostDominators(dominators::simple_fast(&reversed, opt.ret)) + } +} +/// Dominance frontiers for each block pub struct DomFrontiers { /// The dominance frontiers of each block (where phi nodes must be inserted). dom_frontiers: HashMap>, diff --git a/crates/cubecl-opt/src/analyses/dominators.rs b/crates/cubecl-opt/src/analyses/dominators.rs deleted file mode 100644 index b2d43618f..000000000 --- a/crates/cubecl-opt/src/analyses/dominators.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::ops::Deref; - -use crate::NodeIndex; -use petgraph::algo::dominators; - -use super::Analysis; - -pub struct Dominators(dominators::Dominators); -pub struct PostDominators(dominators::Dominators); - -impl Deref for Dominators { - type Target = dominators::Dominators; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl Deref for PostDominators { - type Target = dominators::Dominators; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl Analysis for Dominators { - fn init(opt: &mut crate::Optimizer) -> Self { - Dominators(dominators::simple_fast(&opt.program.graph, opt.entry())) - } -} - -impl Analysis for PostDominators { - fn init(opt: &mut crate::Optimizer) -> Self { - let mut reversed = opt.program.graph.clone(); - reversed.reverse(); - PostDominators(dominators::simple_fast(&reversed, opt.ret)) - } -} diff --git a/crates/cubecl-opt/src/analyses/mod.rs b/crates/cubecl-opt/src/analyses/mod.rs index 63ebdf924..19a77ff2e 100644 --- a/crates/cubecl-opt/src/analyses/mod.rs +++ b/crates/cubecl-opt/src/analyses/mod.rs @@ -1,6 +1,5 @@ mod base; -pub mod dominance_frontiers; -pub mod dominators; +pub mod dominance; pub mod liveness; pub mod post_order; pub mod writes; diff --git a/crates/cubecl-opt/src/gvn/analysis.rs b/crates/cubecl-opt/src/gvn/analysis.rs index 9af5cd24e..4aeb456d3 100644 --- a/crates/cubecl-opt/src/gvn/analysis.rs +++ b/crates/cubecl-opt/src/gvn/analysis.rs @@ -7,7 +7,7 @@ use crate::{analyses::Analysis, NodeIndex}; use smallvec::SmallVec; use crate::{ - analyses::dominators::{Dominators, PostDominators}, + analyses::dominance::{Dominators, PostDominators}, Optimizer, }; diff --git a/crates/cubecl-opt/src/gvn/apply.rs b/crates/cubecl-opt/src/gvn/apply.rs index a6c2baf3a..f02333802 100644 --- a/crates/cubecl-opt/src/gvn/apply.rs +++ b/crates/cubecl-opt/src/gvn/apply.rs @@ -4,7 +4,7 @@ use cubecl_core::ir::{self, Operation}; use petgraph::graph::NodeIndex; use crate::{ - analyses::dominators::Dominators, + analyses::dominance::Dominators, gvn::{convert::value_of_var, phi_translate}, version::PhiEntry, AtomicCounter, Optimizer, PhiInstruction, diff --git a/crates/cubecl-opt/src/lib.rs b/crates/cubecl-opt/src/lib.rs index b8b767fd5..d47dbda10 100644 --- a/crates/cubecl-opt/src/lib.rs +++ b/crates/cubecl-opt/src/lib.rs @@ -30,7 +30,7 @@ use std::{ sync::atomic::{AtomicUsize, Ordering}, }; -use analyses::{dominance_frontiers::DomFrontiers, liveness::Liveness, writes::Writes, Analyses}; +use analyses::{dominance::DomFrontiers, liveness::Liveness, writes::Writes, Analyses}; use cubecl_core::{ ir::{self as core, Branch, Operation, Operator, Variable, VariableKind}, CubeDim, @@ -144,6 +144,8 @@ struct Range { pub struct Optimizer { /// The overall program state program: Program, + /// Analyses with persistent state + analyses: Rc, /// The current block while parsing current_block: Option, /// The current loop's break target @@ -156,8 +158,6 @@ pub struct Optimizer { pub(crate) cube_dim: CubeDim, /// The execution mode, `Unchecked` skips bounds check optimizations. pub(crate) mode: ExecutionMode, - - analyses: Rc, } impl Default for Optimizer { diff --git a/crates/cubecl-opt/src/phi_frontiers.rs b/crates/cubecl-opt/src/phi_frontiers.rs index dbd2081b7..fdbdb173b 100644 --- a/crates/cubecl-opt/src/phi_frontiers.rs +++ b/crates/cubecl-opt/src/phi_frontiers.rs @@ -2,7 +2,7 @@ use cubecl_core::ir::{Item, Variable, VariableKind}; use petgraph::graph::NodeIndex; use crate::{ - analyses::{dominance_frontiers::DomFrontiers, liveness::Liveness, writes::Writes}, + analyses::{dominance::DomFrontiers, liveness::Liveness, writes::Writes}, Optimizer, }; From d9f8a2f71e1d820327d84ad4523bc8c10ee39d9c Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Tue, 7 Jan 2025 23:20:25 +0100 Subject: [PATCH 09/17] Migrate const length and range analyses --- crates/cubecl-opt/src/analyses/const_len.rs | 96 ++++++++ .../integer_range.rs} | 211 +++++++++--------- crates/cubecl-opt/src/analyses/mod.rs | 2 + crates/cubecl-opt/src/debug.rs | 11 +- crates/cubecl-opt/src/lib.rs | 54 +---- crates/cubecl-opt/src/passes/constant_prop.rs | 9 +- .../src/passes/find_safe_indexes.rs | 65 ++++++ .../src/passes/in_bounds_analysis.rs | 96 -------- crates/cubecl-opt/src/passes/mod.rs | 6 +- 9 files changed, 291 insertions(+), 259 deletions(-) create mode 100644 crates/cubecl-opt/src/analyses/const_len.rs rename crates/cubecl-opt/src/{passes/integer_range_analysis.rs => analyses/integer_range.rs} (50%) create mode 100644 crates/cubecl-opt/src/passes/find_safe_indexes.rs delete mode 100644 crates/cubecl-opt/src/passes/in_bounds_analysis.rs diff --git a/crates/cubecl-opt/src/analyses/const_len.rs b/crates/cubecl-opt/src/analyses/const_len.rs new file mode 100644 index 000000000..d32e77101 --- /dev/null +++ b/crates/cubecl-opt/src/analyses/const_len.rs @@ -0,0 +1,96 @@ +use std::{collections::HashMap, ops::Deref}; + +use cubecl_core::ir::{Operation, Operator, Variable, VariableKind}; + +use crate::Optimizer; + +use super::Analysis; + +#[derive(Debug, Clone)] +pub struct Slice { + pub start: Variable, + pub end: Variable, + pub end_op: Option, + pub const_len: Option, +} + +/// Try to find any constant length slices by cancelling common factors in `start` and `end` +#[derive(Default, Debug)] +pub struct Slices { + slices: HashMap<(u16, u8), Slice>, +} + +impl Deref for Slices { + type Target = HashMap<(u16, u8), Slice>; + + fn deref(&self) -> &Self::Target { + &self.slices + } +} + +impl Analysis for Slices { + fn init(opt: &mut Optimizer) -> Self { + let mut this = Slices::default(); + this.populate_slices(opt); + this.find_end_ops(opt); + this + } +} + +impl Slices { + fn populate_slices(&mut self, opt: &mut Optimizer) { + for block in opt.node_ids() { + let ops = opt.program[block].ops.clone(); + for operator in ops.borrow().values() { + let op = match &operator.operation { + Operation::Operator(op) => op, + _ => continue, + }; + let out = operator.out.as_ref(); + if let Operator::Slice(slice_op) = op { + let out_id = match out.unwrap().kind { + VariableKind::Slice { id, depth } => (id, depth), + _ => unreachable!(), + }; + let const_len = slice_op.start.as_const().zip(slice_op.end.as_const()); + let const_len = const_len.map(|(start, end)| end.as_u32() - start.as_u32()); + self.slices.insert( + out_id, + Slice { + start: slice_op.start, + end: slice_op.end, + end_op: None, + const_len, + }, + ); + }; + } + } + } + + fn find_end_ops(&mut self, opt: &mut Optimizer) { + for block in opt.node_ids() { + let ops = opt.program[block].ops.clone(); + for operator in ops.borrow().values() { + let op = match &operator.operation { + Operation::Operator(op) => op, + _ => continue, + }; + // Only handle the simplest cases for now + if let Operator::Add(op) = op { + let mut slices = self.slices.values_mut(); + let slice = + slices.find(|it| it.end == operator.out() && it.const_len.is_none()); + if let Some(slice) = slice { + slice.end_op = Some(Operator::Add(op.clone()).into()); + if op.lhs == slice.start && op.rhs.as_const().is_some() { + slice.const_len = Some(op.rhs.as_const().unwrap().as_u32()); + } else if op.rhs == slice.start && op.lhs.as_const().is_some() { + slice.const_len = Some(op.lhs.as_const().unwrap().as_u32()); + } + } + }; + } + } + } +} diff --git a/crates/cubecl-opt/src/passes/integer_range_analysis.rs b/crates/cubecl-opt/src/analyses/integer_range.rs similarity index 50% rename from crates/cubecl-opt/src/passes/integer_range_analysis.rs rename to crates/cubecl-opt/src/analyses/integer_range.rs index 2eced266e..b28f2b1dd 100644 --- a/crates/cubecl-opt/src/passes/integer_range_analysis.rs +++ b/crates/cubecl-opt/src/analyses/integer_range.rs @@ -1,22 +1,41 @@ -use std::ops::{Add, Mul, Sub}; +use std::{ + collections::HashMap, + ops::{Add, Mul, Sub}, +}; use cubecl_core::ir::{ - Builtin, ConstantScalarValue, Elem, Operation, Operator, UIntKind, Variable, VariableKind, + Builtin, ConstantScalarValue, Elem, Operation, Operator, Variable, VariableKind, }; -use crate::{AtomicCounter, Optimizer, Range}; +use crate::{Optimizer, VarId}; + +use super::Analysis; + +#[derive(Default, Clone, Copy, PartialEq, Eq, Debug)] +pub struct Range { + pub lower_bound: Option, + pub upper_bound: Option, +} -use super::OptimizerPass; +/// Perform analysis on the possible ranges of integer values and store the results for use in later +/// optimization passes. Reasons for integers being bounded but not constant might be: the modulo +/// operator (bounds it to `0..m`), or `UNIT_POS` (bounded by `CubeDim`). Bounds can be transferred +/// between simple arithmetic, so we can determine the possible range of a good number of variables. +/// This is currently only used in index bound analysis. +#[derive(Debug, Default)] +pub struct Ranges { + int_ranges: HashMap, +} impl Range { - fn constant(val: i64) -> Self { + fn constant(val: u64) -> Self { Self { lower_bound: Some(val), upper_bound: Some(val), } } - fn uint(upper: i64) -> Self { + fn uint(upper: u64) -> Self { Self { lower_bound: Some(0), upper_bound: Some(upper), @@ -24,16 +43,17 @@ impl Range { } } -/// Perform analysis on the possible ranges of integer values and store the results for use in later -/// optimization passes. Reasons for integers being bounded but not constant might be: the modulo -/// operator (bounds it to `0..m`), or `UNIT_POS` (bounded by `CubeDim`). Bounds can be transferred -/// between simple arithmetic, so we can determine the possible range of a good number of variables. -/// This is currently only used in index bound analysis. -#[derive(Default, Clone, Debug)] -pub struct IntegerRangeAnalysis; +impl Analysis for Ranges { + fn init(opt: &mut Optimizer) -> Self { + let mut this = Ranges::default(); + // Run fixed point iteration + while this.run_loop(opt) {} + this + } +} -impl OptimizerPass for IntegerRangeAnalysis { - fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { +impl Ranges { + fn run_loop(&mut self, opt: &mut Optimizer) -> bool { for block in opt.node_ids() { let ops = opt.program[block].ops.clone(); for inst in ops.borrow().values() { @@ -42,58 +62,58 @@ impl OptimizerPass for IntegerRangeAnalysis { _ => continue, }; match op { - Operator::Add(binop) if inst.item().elem().is_int() => { + Operator::Add(binop) if is_uint(inst.item().elem()) => { if let Some(out_id) = var_id(&inst.out()) { - let lhs_range = range_of(opt, &binop.lhs); - let rhs_range = range_of(opt, &binop.rhs); + let lhs_range = self.range_of(opt, &binop.lhs); + let rhs_range = self.range_of(opt, &binop.rhs); let out_range = lhs_range + rhs_range; - if Some(&out_range) != opt.program.int_ranges.get(&out_id) { - opt.program.int_ranges.insert(out_id, out_range); - changes.inc(); + if Some(&out_range) != self.int_ranges.get(&out_id) { + self.int_ranges.insert(out_id, out_range); + return true; } } } - Operator::Sub(binop) if inst.item().elem().is_int() => { + Operator::Sub(binop) if is_uint(inst.item().elem()) => { if let Some(out_id) = var_id(&inst.out()) { - let lhs_range = range_of(opt, &binop.lhs); - let rhs_range = range_of(opt, &binop.rhs); + let lhs_range = self.range_of(opt, &binop.lhs); + let rhs_range = self.range_of(opt, &binop.rhs); let out_range = lhs_range - rhs_range; - if Some(&out_range) != opt.program.int_ranges.get(&out_id) { - opt.program.int_ranges.insert(out_id, out_range); - changes.inc(); + if Some(&out_range) != self.int_ranges.get(&out_id) { + self.int_ranges.insert(out_id, out_range); + return true; } } } - Operator::Mul(binop) if inst.item().elem().is_int() => { + Operator::Mul(binop) if is_uint(inst.item().elem()) => { if let Some(out_id) = var_id(&inst.out()) { - let lhs_range = range_of(opt, &binop.lhs); - let rhs_range = range_of(opt, &binop.rhs); + let lhs_range = self.range_of(opt, &binop.lhs); + let rhs_range = self.range_of(opt, &binop.rhs); let out_range = lhs_range * rhs_range; - if Some(&out_range) != opt.program.int_ranges.get(&out_id) { - opt.program.int_ranges.insert(out_id, out_range); - changes.inc(); + if Some(&out_range) != self.int_ranges.get(&out_id) { + self.int_ranges.insert(out_id, out_range); + return true; } } } - Operator::Div(binop) if inst.item().elem().is_int() => { + Operator::Div(binop) if is_uint(inst.item().elem()) => { if let Some(out_id) = var_id(&inst.out()) { - let lhs_range = range_of(opt, &binop.lhs); - let rhs_range: Range = range_of(opt, &binop.rhs); + let lhs_range = self.range_of(opt, &binop.lhs); + let rhs_range = self.range_of(opt, &binop.rhs); let out_range = lhs_range / rhs_range; - if Some(&out_range) != opt.program.int_ranges.get(&out_id) { - opt.program.int_ranges.insert(out_id, out_range); - changes.inc(); + if Some(&out_range) != self.int_ranges.get(&out_id) { + self.int_ranges.insert(out_id, out_range); + return true; } } } - Operator::Modulo(binop) if inst.item().elem().is_int() => { + Operator::Modulo(binop) if is_uint(inst.item().elem()) => { if let Some(out_id) = var_id(&inst.out()) { - let lhs_range = range_of(opt, &binop.lhs); - let rhs_range = range_of(opt, &binop.rhs); + let lhs_range = self.range_of(opt, &binop.lhs); + let rhs_range = self.range_of(opt, &binop.rhs); let out_range = lhs_range % rhs_range; - if Some(&out_range) != opt.program.int_ranges.get(&out_id) { - opt.program.int_ranges.insert(out_id, out_range); - changes.inc(); + if Some(&out_range) != self.int_ranges.get(&out_id) { + self.int_ranges.insert(out_id, out_range); + return true; } } } @@ -101,63 +121,59 @@ impl OptimizerPass for IntegerRangeAnalysis { } } } + false } } -/// The possible range of values of any variable, if applicable. Returns unbounded range if no range -/// can be determined, or the type is not an integer. -pub(crate) fn range_of(opt: &Optimizer, var: &Variable) -> Range { - match var.kind { - VariableKind::Versioned { id, depth, version } - if var.item.elem() == Elem::UInt(UIntKind::U32) => - { - opt.program +fn is_uint(elem: Elem) -> bool { + matches!(elem, Elem::UInt(_)) +} + +impl Ranges { + /// The possible range of values of any variable, if applicable. Returns unbounded range if no range + /// can be determined, or the type is not an integer. + pub fn range_of(&self, opt: &Optimizer, var: &Variable) -> Range { + match var.kind { + VariableKind::Versioned { id, depth, version } if is_uint(var.item.elem()) => self .int_ranges .get(&(id, depth, version)) .copied() .unwrap_or(Range { lower_bound: Some(0), upper_bound: None, - }) - } - VariableKind::Versioned { id, depth, version } => opt - .program - .int_ranges - .get(&(id, depth, version)) - .copied() - .unwrap_or_default(), - VariableKind::LocalConst { id, depth } if var.item.elem() == Elem::UInt(UIntKind::U32) => { - opt.program + }), + VariableKind::Versioned { id, depth, version } => self + .int_ranges + .get(&(id, depth, version)) + .copied() + .unwrap_or_default(), + VariableKind::LocalConst { id, depth } if is_uint(var.item.elem()) => self .int_ranges .get(&(id, depth, 0)) .copied() .unwrap_or(Range { lower_bound: Some(0), upper_bound: None, - }) - } - VariableKind::LocalConst { id, depth } => opt - .program - .int_ranges - .get(&(id, depth, 0)) - .copied() - .unwrap_or_default(), - VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => Range::constant(val), - VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => { - Range::constant(val as i64) - } - VariableKind::Builtin(builtin) => match builtin { - Builtin::UnitPos => Range::uint(opt.cube_dim.num_elems() as i64 - 1), - Builtin::UnitPosX => Range::uint(opt.cube_dim.x as i64 - 1), - Builtin::UnitPosY => Range::uint(opt.cube_dim.y as i64 - 1), - Builtin::UnitPosZ => Range::uint(opt.cube_dim.z as i64 - 1), - Builtin::CubeCount => Range::constant(opt.cube_dim.num_elems() as i64), - Builtin::CubeCountX => Range::constant(opt.cube_dim.x as i64), - Builtin::CubeCountY => Range::constant(opt.cube_dim.y as i64), - Builtin::CubeCountZ => Range::constant(opt.cube_dim.z as i64), + }), + VariableKind::LocalConst { id, depth } => self + .int_ranges + .get(&(id, depth, 0)) + .copied() + .unwrap_or_default(), + VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => Range::constant(val), + VariableKind::Builtin(builtin) => match builtin { + Builtin::UnitPos => Range::uint(opt.cube_dim.num_elems() as u64 - 1), + Builtin::UnitPosX => Range::uint(opt.cube_dim.x as u64 - 1), + Builtin::UnitPosY => Range::uint(opt.cube_dim.y as u64 - 1), + Builtin::UnitPosZ => Range::uint(opt.cube_dim.z as u64 - 1), + Builtin::CubeCount => Range::constant(opt.cube_dim.num_elems() as u64), + Builtin::CubeCountX => Range::constant(opt.cube_dim.x as u64), + Builtin::CubeCountY => Range::constant(opt.cube_dim.y as u64), + Builtin::CubeCountZ => Range::constant(opt.cube_dim.z as u64), + _ => Default::default(), + }, _ => Default::default(), - }, - _ => Default::default(), + } } } @@ -233,28 +249,13 @@ mod range_ops { type Output = Range; fn rem(self, rhs: Self) -> Self::Output { - let min_neg = self.lower_bound.map(|it| it < 0).unwrap_or(true); - let max_pos = self.upper_bound.map(|it| it > 0).unwrap_or(true); if rhs.lower_bound.is_none() || rhs.upper_bound.is_none() { return self; } - let rhs_lower = rhs.lower_bound.unwrap().abs(); - let rhs_upper = rhs.upper_bound.unwrap().abs(); - let rhs_max = rhs_lower.max(rhs_upper); - match (min_neg, max_pos) { - (true, false) => Range { - lower_bound: Some(-(rhs_max - 1)), - upper_bound: Some(0), - }, - (true, true) => Range { - lower_bound: Some(-(rhs_max - 1)), - upper_bound: Some(rhs_max - 1), - }, - (false, true) => Range { - lower_bound: Some(0), - upper_bound: Some(rhs_max - 1), - }, - _ => self, + let rhs_upper = rhs.upper_bound.unwrap(); + Range { + lower_bound: Some(0), + upper_bound: Some(rhs_upper - 1), } } } diff --git a/crates/cubecl-opt/src/analyses/mod.rs b/crates/cubecl-opt/src/analyses/mod.rs index 19a77ff2e..654c1d2c5 100644 --- a/crates/cubecl-opt/src/analyses/mod.rs +++ b/crates/cubecl-opt/src/analyses/mod.rs @@ -1,5 +1,7 @@ mod base; +pub mod const_len; pub mod dominance; +pub mod integer_range; pub mod liveness; pub mod post_order; pub mod writes; diff --git a/crates/cubecl-opt/src/debug.rs b/crates/cubecl-opt/src/debug.rs index e956f44a6..8c092c091 100644 --- a/crates/cubecl-opt/src/debug.rs +++ b/crates/cubecl-opt/src/debug.rs @@ -4,9 +4,8 @@ use cubecl_core::ir::{FloatKind, IntKind, UIntKind}; use petgraph::visit::EdgeRef; use crate::{ - analyses::liveness::Liveness, + analyses::{const_len::Slices, integer_range::Ranges, liveness::Liveness}, gvn::{BlockSets, Constant, Expression, GvnState, Instruction, Local, OpId, Value, ValueTable}, - passes::var_id, ControlFlow, }; @@ -17,8 +16,11 @@ const DEBUG_GVN: bool = false; /// Debug display for the program state. impl Display for Optimizer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let slices = self.analyses.try_get::().unwrap_or_default(); + let ranges = self.analyses.try_get::().unwrap_or_default(); + f.write_str("Slices:\n")?; - for (var_id, slice) in self.program.slices.iter() { + for (var_id, slice) in slices.iter() { let end_op = slice.end_op.as_ref().map(|it| format!("{it}")); writeln!( f, @@ -76,8 +78,7 @@ impl Display for Optimizer { } for op in bb.ops.borrow_mut().values_mut() { - let id = op.out.and_then(|var| var_id(&var)); - let range = id.and_then(|id| self.program.int_ranges.get(&id)); + let range = op.out.map(|var| ranges.range_of(self, &var)); let range = range.map(|it| format!(" range: {it};")).unwrap_or_default(); writeln!(f, " {op};{range}")?; diff --git a/crates/cubecl-opt/src/lib.rs b/crates/cubecl-opt/src/lib.rs index 1fed23d90..c073b9263 100644 --- a/crates/cubecl-opt/src/lib.rs +++ b/crates/cubecl-opt/src/lib.rs @@ -43,9 +43,8 @@ use gvn::GvnPass; use passes::{ CompositeMerge, ConstEval, ConstOperandSimplify, CopyPropagateArray, CopyTransform, EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables, - EmptyBranchToSelect, FindConstSliceLen, InBoundsToUnchecked, InlineAssignments, - IntegerRangeAnalysis, MergeBlocks, MergeSameExpressions, OptimizerPass, ReduceStrength, - RemoveIndexScalar, + EmptyBranchToSelect, InBoundsToUnchecked, InlineAssignments, MergeBlocks, MergeSameExpressions, + OptimizerPass, ReduceStrength, RemoveIndexScalar, }; use petgraph::{prelude::StableDiGraph, visit::EdgeRef, Direction}; @@ -89,14 +88,6 @@ impl AtomicCounter { } } -#[derive(Debug, Clone)] -pub(crate) struct Slice { - pub(crate) start: Variable, - pub(crate) end: Variable, - pub(crate) end_op: Option, - pub(crate) const_len: Option, -} - #[derive(Debug, Clone)] pub struct ConstArray { pub id: u16, @@ -109,10 +100,8 @@ pub struct ConstArray { struct Program { pub const_arrays: Vec, pub variables: HashMap<(u16, u8), Item>, - pub(crate) slices: HashMap<(u16, u8), Slice>, pub graph: StableDiGraph, root: NodeIndex, - int_ranges: HashMap, temp_id: AtomicCounter, } @@ -133,12 +122,6 @@ impl DerefMut for Program { type VarId = (u16, u8, u16); -#[derive(Default, Clone, Copy, PartialEq, Eq, Debug)] -struct Range { - lower_bound: Option, - upper_bound: Option, -} - /// An optimizer that applies various analyses and optimization passes to the IR. #[derive(Debug, Clone)] pub struct Optimizer { @@ -270,15 +253,6 @@ impl Optimizer { Box::new(EliminateDeadBlocks), Box::new(EliminateDeadPhi), ]; - // Passes that only run if execution mode is checked - let checked_passes: Vec> = vec![ - Box::new(IntegerRangeAnalysis), - Box::new(FindConstSliceLen), - Box::new(InBoundsToUnchecked), - ]; - if matches!(self.mode, ExecutionMode::Checked) { - passes.extend(checked_passes); - } loop { let counter = AtomicCounter::default(); @@ -290,6 +264,11 @@ impl Optimizer { break; } } + + // Only replace indexing when checked, since all indexes are unchecked anyways in unchecked + if matches!(self.mode, ExecutionMode::Checked) { + InBoundsToUnchecked.apply_post_ssa(self, AtomicCounter::new(0)); + } } /// Remove non-constant index vectors from SSA transformation because they currently must be @@ -378,27 +357,8 @@ impl Optimizer { let is_break = processed.operations.contains(&Branch::Break.into()); for mut instruction in processed.operations { - let out = instruction.out; match &mut instruction.operation { Operation::Branch(branch) => self.parse_control_flow(branch.clone()), - Operation::Operator(Operator::Slice(slice_op)) => { - let out_id = match out.unwrap().kind { - VariableKind::Slice { id, depth } => (id, depth), - _ => unreachable!(), - }; - let const_len = slice_op.start.as_const().zip(slice_op.end.as_const()); - let const_len = const_len.map(|(start, end)| end.as_u32() - start.as_u32()); - self.program.slices.insert( - out_id, - Slice { - start: slice_op.start, - end: slice_op.end, - end_op: None, - const_len, - }, - ); - self.current_block_mut().ops.borrow_mut().push(instruction); - } _ => { self.current_block_mut().ops.borrow_mut().push(instruction); } diff --git a/crates/cubecl-opt/src/passes/constant_prop.rs b/crates/cubecl-opt/src/passes/constant_prop.rs index cbc49e6c9..e48a692de 100644 --- a/crates/cubecl-opt/src/passes/constant_prop.rs +++ b/crates/cubecl-opt/src/passes/constant_prop.rs @@ -3,7 +3,10 @@ use cubecl_core::ir::{ VariableKind, }; -use crate::{AtomicCounter, Optimizer, Slice}; +use crate::{ + analyses::const_len::{Slice, Slices}, + AtomicCounter, Optimizer, +}; use super::OptimizerPass; @@ -13,6 +16,8 @@ pub struct ConstOperandSimplify; impl OptimizerPass for ConstOperandSimplify { fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { + let slices = opt.analysis::(); + for node in opt.program.node_indices().collect::>() { let ops = opt.program[node].ops.borrow().indices().collect::>(); @@ -123,7 +128,7 @@ impl OptimizerPass for ConstOperandSimplify { changes.inc(); } VariableKind::Slice { id, depth } => { - let slice = opt.program.slices.get(&(id, depth)); + let slice = slices.get(&(id, depth)); if let Some(Slice { const_len: Some(len), .. diff --git a/crates/cubecl-opt/src/passes/find_safe_indexes.rs b/crates/cubecl-opt/src/passes/find_safe_indexes.rs new file mode 100644 index 000000000..9e9165c26 --- /dev/null +++ b/crates/cubecl-opt/src/passes/find_safe_indexes.rs @@ -0,0 +1,65 @@ +use cubecl_core::ir::{Operation, Operator, Variable, VariableKind}; + +use crate::{ + analyses::{const_len::Slices, integer_range::Ranges}, + AtomicCounter, Optimizer, +}; + +use super::OptimizerPass; + +/// Use the results from integer range analysis to find indexes that are always in bounds, then +/// transform them to unchecked indexes. +pub struct InBoundsToUnchecked; + +impl OptimizerPass for InBoundsToUnchecked { + fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { + let ranges = opt.analysis::(); + + for block in opt.node_ids() { + let ops = opt.program[block].ops.clone(); + for inst in ops.borrow_mut().values_mut() { + let op = match &inst.operation { + Operation::Operator(op) => op, + _ => continue, + }; + match op { + Operator::Index(op) => { + if let Some(const_len) = const_len(opt, &op.lhs) { + let range = ranges.range_of(opt, &op.rhs); + if let Some((_, upper)) = range.lower_bound.zip(range.upper_bound) { + if (upper as u32) < const_len { + inst.operation = Operator::UncheckedIndex(op.clone()).into(); + changes.inc(); + } + } + } + } + Operator::IndexAssign(op) => { + if let Some(const_len) = const_len(opt, &inst.out()) { + let range = ranges.range_of(opt, &op.lhs); + if let Some((_, upper)) = range.lower_bound.zip(range.upper_bound) { + if (upper as u32) < const_len { + inst.operation = + Operator::UncheckedIndexAssign(op.clone()).into(); + changes.inc(); + } + } + } + } + _ => {} + } + } + } + } +} + +fn const_len(opt: &mut Optimizer, var: &Variable) -> Option { + let slices = opt.analysis::(); + match var.kind { + VariableKind::ConstantArray { length, .. } => Some(length), + VariableKind::SharedMemory { length, .. } => Some(length), + VariableKind::LocalArray { length, .. } => Some(length), + VariableKind::Slice { id, depth } => slices.get(&(id, depth)).and_then(|it| it.const_len), + _ => None, + } +} diff --git a/crates/cubecl-opt/src/passes/in_bounds_analysis.rs b/crates/cubecl-opt/src/passes/in_bounds_analysis.rs deleted file mode 100644 index 2607c13a9..000000000 --- a/crates/cubecl-opt/src/passes/in_bounds_analysis.rs +++ /dev/null @@ -1,96 +0,0 @@ -use cubecl_core::ir::{Operation, Operator, Variable, VariableKind}; - -use crate::{AtomicCounter, Optimizer}; - -use super::{range_of, OptimizerPass}; - -/// Try to find any constant length slices by cancelling common factors in `start` and `end` -pub struct FindConstSliceLen; - -impl OptimizerPass for FindConstSliceLen { - fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { - for block in opt.node_ids() { - let ops = opt.program[block].ops.clone(); - for operator in ops.borrow().values() { - let op = match &operator.operation { - Operation::Operator(op) => op, - _ => continue, - }; - // Only handle the simplest cases for now - if let Operator::Add(op) = op { - let mut slices = opt.program.slices.values_mut(); - let slice = - slices.find(|it| it.end == operator.out() && it.const_len.is_none()); - if let Some(slice) = slice { - slice.end_op = Some(Operator::Add(op.clone()).into()); - if op.lhs == slice.start && op.rhs.as_const().is_some() { - slice.const_len = Some(op.rhs.as_const().unwrap().as_u32()); - changes.inc(); - } else if op.rhs == slice.start && op.lhs.as_const().is_some() { - slice.const_len = Some(op.lhs.as_const().unwrap().as_u32()); - changes.inc(); - } - } - } - } - } - } -} - -/// Use the results from integer range analysis to find indexes that are always in bounds, then -/// transform them to unchecked indexes. -pub struct InBoundsToUnchecked; - -impl OptimizerPass for InBoundsToUnchecked { - fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { - for block in opt.node_ids() { - let ops = opt.program[block].ops.clone(); - for inst in ops.borrow_mut().values_mut() { - let op = match &inst.operation { - Operation::Operator(op) => op, - _ => continue, - }; - match op { - Operator::Index(op) => { - if let Some(const_len) = const_len(opt, &op.lhs) { - let range = range_of(opt, &op.rhs); - if let Some((lower, upper)) = range.lower_bound.zip(range.upper_bound) { - if lower >= 0 && (upper as u32) < const_len { - inst.operation = Operator::UncheckedIndex(op.clone()).into(); - changes.inc(); - } - } - } - } - Operator::IndexAssign(op) => { - if let Some(const_len) = const_len(opt, &inst.out()) { - let range = range_of(opt, &op.lhs); - if let Some((lower, upper)) = range.lower_bound.zip(range.upper_bound) { - if lower >= 0 && (upper as u32) < const_len { - inst.operation = - Operator::UncheckedIndexAssign(op.clone()).into(); - changes.inc(); - } - } - } - } - _ => {} - } - } - } - } -} - -fn const_len(opt: &Optimizer, var: &Variable) -> Option { - match var.kind { - VariableKind::ConstantArray { length, .. } => Some(length), - VariableKind::SharedMemory { length, .. } => Some(length), - VariableKind::LocalArray { length, .. } => Some(length), - VariableKind::Slice { id, depth } => opt - .program - .slices - .get(&(id, depth)) - .and_then(|it| it.const_len), - _ => None, - } -} diff --git a/crates/cubecl-opt/src/passes/mod.rs b/crates/cubecl-opt/src/passes/mod.rs index 197fd3c46..a14720d3d 100644 --- a/crates/cubecl-opt/src/passes/mod.rs +++ b/crates/cubecl-opt/src/passes/mod.rs @@ -3,10 +3,9 @@ mod composite; mod constant_prop; mod dead_code; mod expression_merge; -mod in_bounds_analysis; +mod find_safe_indexes; mod index_merge; mod inlined_if_to_select; -mod integer_range_analysis; mod reduce_strength; pub use array_copy_propagate::*; @@ -14,10 +13,9 @@ pub use composite::*; pub use constant_prop::*; pub use dead_code::*; pub use expression_merge::*; -pub use in_bounds_analysis::*; +pub use find_safe_indexes::*; pub use index_merge::*; pub use inlined_if_to_select::*; -pub use integer_range_analysis::*; pub use reduce_strength::*; use crate::AtomicCounter; From 688c2e2ccd3b2cce39e23d6a2c5ce873eb702525 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Tue, 7 Jan 2025 23:26:45 +0100 Subject: [PATCH 10/17] Remove erroneous capability --- crates/cubecl-spirv/src/instruction.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/cubecl-spirv/src/instruction.rs b/crates/cubecl-spirv/src/instruction.rs index e3e7de615..798a0b3d1 100644 --- a/crates/cubecl-spirv/src/instruction.rs +++ b/crates/cubecl-spirv/src/instruction.rs @@ -440,7 +440,6 @@ impl SpirvCompiler { }); } Operator::ReverseBits(op) => { - self.capabilities.insert(Capability::BitInstructions); self.compile_unary_op(op, out, |b, _, ty, input, out| { b.bit_reverse(ty, Some(out), input).unwrap(); }); From 72a158b63e477f05537de4922a8fb265ecb314f5 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Thu, 9 Jan 2025 14:26:35 +0100 Subject: [PATCH 11/17] Fix for memory leak --- .../src/matmul/tests/cmma_matmul/matmul_test_launcher.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs b/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs index 0519d967b..551073535 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs @@ -104,6 +104,7 @@ pub fn test_matmul_algorithm( if A::check_availability::(&client, &config).is_err() { // Can't execute the test. println!("Skipped - not supported!"); + client.flush(); return; } From b343a5a6a83812e95727ff8204467c736fdd777b Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Thu, 9 Jan 2025 16:09:20 +0100 Subject: [PATCH 12/17] Fix local creation --- crates/cubecl-opt/src/gvn/apply.rs | 4 ++-- crates/cubecl-opt/src/lib.rs | 2 ++ crates/cubecl-opt/src/passes/reduce_strength.rs | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/crates/cubecl-opt/src/gvn/apply.rs b/crates/cubecl-opt/src/gvn/apply.rs index ae79af31a..7c4cd8140 100644 --- a/crates/cubecl-opt/src/gvn/apply.rs +++ b/crates/cubecl-opt/src/gvn/apply.rs @@ -83,7 +83,7 @@ impl GvnState { } let leaders = &mut self.block_sets.get_mut(&pred).unwrap().leaders; if !leaders.contains_key(&val) { - let new_temp = *opt.allocator.create_local_restricted(expr.item()); + let new_temp = *opt.allocator.create_local(expr.item()); let new_op = ir::Instruction::new(expr.to_operation(leaders), new_temp); opt.program[pred].ops.borrow_mut().push(new_op); leaders.insert(val, value_of_var(&new_temp).unwrap()); @@ -102,7 +102,7 @@ impl GvnState { let new_phis = new_phis .into_iter() .map(|entries| PhiInstruction { - out: *opt.allocator.create_local_restricted(entries[0].value.item), + out: *opt.allocator.create_local(entries[0].value.item), entries, }) .collect::>(); diff --git a/crates/cubecl-opt/src/lib.rs b/crates/cubecl-opt/src/lib.rs index 4543f9edd..927e04af4 100644 --- a/crates/cubecl-opt/src/lib.rs +++ b/crates/cubecl-opt/src/lib.rs @@ -195,6 +195,8 @@ impl Optimizer { self.apply_post_ssa_passes(); } + println!("{self}"); + let gvn_count = AtomicCounter::new(0); GvnPass.apply_post_ssa(self, gvn_count.clone()); ReduceStrength.apply_post_ssa(self, gvn_count.clone()); diff --git a/crates/cubecl-opt/src/passes/reduce_strength.rs b/crates/cubecl-opt/src/passes/reduce_strength.rs index 2fea2b013..88795b2f8 100644 --- a/crates/cubecl-opt/src/passes/reduce_strength.rs +++ b/crates/cubecl-opt/src/passes/reduce_strength.rs @@ -56,7 +56,7 @@ impl OptimizerPass for ReduceStrength { changes.inc(); } val if (val + 1).is_power_of_two() => { - let temp = *opt.allocator.create_local_restricted(inst.item()); + let temp = *opt.allocator.create_local(inst.item()); new_ops.push(Instruction::new( Operator::ShiftLeft(BinaryOperator { lhs: dyn_val, @@ -74,7 +74,7 @@ impl OptimizerPass for ReduceStrength { changes.inc(); } val if (val - 1).is_power_of_two() => { - let temp = *opt.allocator.create_local_restricted(inst.item()); + let temp = *opt.allocator.create_local(inst.item()); new_ops.push(Instruction::new( Operator::ShiftLeft(BinaryOperator { lhs: dyn_val, From 05e54fb5caae8816515eeaf2087b7a7ab16d9e31 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Thu, 9 Jan 2025 16:55:29 +0100 Subject: [PATCH 13/17] Remove leftover println --- crates/cubecl-opt/src/lib.rs | 2 -- crates/cubecl-wgpu/src/lib.rs | 7 ++++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/crates/cubecl-opt/src/lib.rs b/crates/cubecl-opt/src/lib.rs index 927e04af4..4543f9edd 100644 --- a/crates/cubecl-opt/src/lib.rs +++ b/crates/cubecl-opt/src/lib.rs @@ -195,8 +195,6 @@ impl Optimizer { self.apply_post_ssa_passes(); } - println!("{self}"); - let gvn_count = AtomicCounter::new(0); GvnPass.apply_post_ssa(self, gvn_count.clone()); ReduceStrength.apply_post_ssa(self, gvn_count.clone()); diff --git a/crates/cubecl-wgpu/src/lib.rs b/crates/cubecl-wgpu/src/lib.rs index 50dea369c..d2d390d55 100644 --- a/crates/cubecl-wgpu/src/lib.rs +++ b/crates/cubecl-wgpu/src/lib.rs @@ -39,8 +39,9 @@ mod tests_spirv { use half::f16; cubecl_core::testgen_all!(f32: [f16, flex32, f32, f64], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, u64]); - cubecl_linalg::testgen_matmul_plane!([f16, flex32, f32]); - cubecl_linalg::testgen_matmul_tiling2d!([f16, flex32, f32, f64]); - cubecl_linalg::testgen_matmul_simple!([flex32, f32]); + cubecl_linalg::testgen_matmul_plane!([f16, f32]); + cubecl_linalg::testgen_matmul_tiling2d!([f16, f32, f64]); + cubecl_linalg::testgen_matmul_simple!([f32]); cubecl_linalg::testgen_matmul_accelerated!([f16]); + cubecl_reduce::testgen_reduce!(); } From 855157e86a0105cccebbbb950f55e60fe06b8c59 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Thu, 9 Jan 2025 17:14:12 +0100 Subject: [PATCH 14/17] Fix lints --- crates/cubecl-cuda/src/lib.rs | 1 + crates/cubecl-macros/src/parse/helpers.rs | 6 +----- crates/cubecl-opt/src/gvn/analysis.rs | 6 +----- crates/cubecl-wgpu/src/lib.rs | 2 ++ 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/crates/cubecl-cuda/src/lib.rs b/crates/cubecl-cuda/src/lib.rs index 8b1e36e0b..1fb40d789 100644 --- a/crates/cubecl-cuda/src/lib.rs +++ b/crates/cubecl-cuda/src/lib.rs @@ -10,6 +10,7 @@ pub use device::*; pub use runtime::*; #[cfg(test)] +#[allow(unexpected_cfgs)] mod tests { pub type TestRuntime = crate::CudaRuntime; pub use half::{bf16, f16}; diff --git a/crates/cubecl-macros/src/parse/helpers.rs b/crates/cubecl-macros/src/parse/helpers.rs index 88f0ee9d1..83477e829 100644 --- a/crates/cubecl-macros/src/parse/helpers.rs +++ b/crates/cubecl-macros/src/parse/helpers.rs @@ -51,11 +51,7 @@ impl Unroll { pub value: Expr, } - let attr = attrs.iter().find(|attr| attr.path().is_ident("unroll")); - let attr = match attr { - Some(attr) => attr, - None => return None, - }; + let attr = attrs.iter().find(|attr| attr.path().is_ident("unroll"))?; match &attr.meta { syn::Meta::Path(_) => None, diff --git a/crates/cubecl-opt/src/gvn/analysis.rs b/crates/cubecl-opt/src/gvn/analysis.rs index 4aeb456d3..c3bec2255 100644 --- a/crates/cubecl-opt/src/gvn/analysis.rs +++ b/crates/cubecl-opt/src/gvn/analysis.rs @@ -166,11 +166,7 @@ impl GvnState { .map(|child| &self.block_sets[child].antic_in); // Only add expressions expected at all successors to this block's anticipated list for (val, expr) in potential_out { - if rest - .clone() - .map(|child| child.iter().any(|v| v.0 == *val)) - .all(|b| b) - { + if rest.clone().all(|child| child.iter().any(|v| v.0 == *val)) { result.push_back((*val, expr.clone())); } } diff --git a/crates/cubecl-wgpu/src/lib.rs b/crates/cubecl-wgpu/src/lib.rs index d2d390d55..bd447d1d0 100644 --- a/crates/cubecl-wgpu/src/lib.rs +++ b/crates/cubecl-wgpu/src/lib.rs @@ -21,6 +21,7 @@ pub use runtime::*; pub use compiler::spirv; #[cfg(test)] +#[allow(unexpected_cfgs)] mod tests { pub type TestRuntime = crate::WgpuRuntime; @@ -33,6 +34,7 @@ mod tests { } #[cfg(all(test, feature = "spirv"))] +#[allow(unexpected_cfgs)] mod tests_spirv { pub type TestRuntime = crate::WgpuRuntime; use cubecl_core::flex32; From 6a4db895bd1ec5450396bc16cf1d0084e53ed3c4 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Thu, 9 Jan 2025 23:50:02 +0100 Subject: [PATCH 15/17] Don't eliminate impure operations when looking for dead code --- crates/cubecl-core/src/ir/operation.rs | 21 +++++++++++++++++++++ crates/cubecl-opt/src/passes/dead_code.rs | 5 +++++ 2 files changed, 26 insertions(+) diff --git a/crates/cubecl-core/src/ir/operation.rs b/crates/cubecl-core/src/ir/operation.rs index acdf7febc..79e188a68 100644 --- a/crates/cubecl-core/src/ir/operation.rs +++ b/crates/cubecl-core/src/ir/operation.rs @@ -54,6 +54,27 @@ impl Instruction { } } +impl Operation { + /// Whether this operation is pure, aka has no side effects. Pure operations can be removed + /// if their output is not needed, impure operations must be kept since their execution can + /// affect things down the line. e.g. atomics. + /// + /// Operations that operate across multiple units are always considered impure. + pub fn is_pure(&self) -> bool { + match self { + Operation::Copy(_) => true, + Operation::Operator(_) => true, + Operation::Atomic(_) => false, + Operation::Metadata(_) => true, + Operation::Branch(_) => false, + Operation::Synchronization(_) => false, + Operation::Plane(_) => false, + Operation::CoopMma(_) => false, + Operation::NonSemantic(_) => false, + } + } +} + impl Display for Instruction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.operation { diff --git a/crates/cubecl-opt/src/passes/dead_code.rs b/crates/cubecl-opt/src/passes/dead_code.rs index 40d0f1da0..db0bd86c9 100644 --- a/crates/cubecl-opt/src/passes/dead_code.rs +++ b/crates/cubecl-opt/src/passes/dead_code.rs @@ -31,6 +31,11 @@ fn search_loop(opt: &mut Optimizer) -> bool { for idx in ops { let mut op = opt.program[node].ops.borrow()[idx].clone(); + // Assume operations and metadata are pure, and everything else might have side effects + // Technically not correct but much simpler than + if !op.operation.is_pure() { + continue; + } let mut out = None; let used = Rc::new(AtomicBool::new(false)); opt.visit_out(&mut op.out, |_, var| { From efaea0334de2c17e77580a1a5b651e9d187f99b8 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Fri, 10 Jan 2025 13:06:51 +0100 Subject: [PATCH 16/17] Fix comment --- crates/cubecl-opt/src/passes/dead_code.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/cubecl-opt/src/passes/dead_code.rs b/crates/cubecl-opt/src/passes/dead_code.rs index db0bd86c9..6ac3ee22d 100644 --- a/crates/cubecl-opt/src/passes/dead_code.rs +++ b/crates/cubecl-opt/src/passes/dead_code.rs @@ -31,8 +31,8 @@ fn search_loop(opt: &mut Optimizer) -> bool { for idx in ops { let mut op = opt.program[node].ops.borrow()[idx].clone(); - // Assume operations and metadata are pure, and everything else might have side effects - // Technically not correct but much simpler than + // Impure operations must be skipped because they can change things even if the output + // is unused if !op.operation.is_pure() { continue; } From f49e5dca815b02a95aa6afd7e2c5e71d8a38b17e Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Sat, 11 Jan 2025 20:12:38 +0100 Subject: [PATCH 17/17] Fixup --- crates/cubecl-opt/src/analyses/base.rs | 12 ++++++++---- crates/cubecl-opt/src/debug.rs | 11 +++++++---- crates/cubecl-opt/src/lib.rs | 8 +++++--- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/crates/cubecl-opt/src/analyses/base.rs b/crates/cubecl-opt/src/analyses/base.rs index 93a05e0d1..e705725c1 100644 --- a/crates/cubecl-opt/src/analyses/base.rs +++ b/crates/cubecl-opt/src/analyses/base.rs @@ -18,11 +18,11 @@ pub trait Analysis { } #[derive(Default, Clone, Debug)] -pub struct Analyses { +pub struct AnalysisCache { cache: Rc>, } -impl Analyses { +impl AnalysisCache { pub fn get(&self, opt: &mut Optimizer) -> Rc { let analysis = self.cache.borrow().get::>().cloned(); if let Some(analysis) = analysis { @@ -44,15 +44,19 @@ impl Analyses { } impl Optimizer { + /// Fetch an analysis if cached, or run it if not. pub fn analysis(&mut self) -> Rc { - let analyses = self.analyses.clone(); + let analyses = self.analysis_cache.clone(); analyses.get(self) } + /// Invalidate an analysis by removing it from the cache. The analysis is rerun when requested + /// again. pub fn invalidate_analysis(&self) { - self.analyses.invalidate::(); + self.analysis_cache.invalidate::(); } + /// Invalidate all analyses that rely on the structure of the control flow graph. pub fn invalidate_structure(&self) { self.invalidate_analysis::(); self.invalidate_analysis::(); diff --git a/crates/cubecl-opt/src/debug.rs b/crates/cubecl-opt/src/debug.rs index 37db0fea0..9ecd70aff 100644 --- a/crates/cubecl-opt/src/debug.rs +++ b/crates/cubecl-opt/src/debug.rs @@ -16,8 +16,8 @@ const DEBUG_GVN: bool = false; /// Debug display for the program state. impl Display for Optimizer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let slices = self.analyses.try_get::().unwrap_or_default(); - let ranges = self.analyses.try_get::().unwrap_or_default(); + let slices = self.analysis_cache.try_get::().unwrap_or_default(); + let ranges = self.analysis_cache.try_get::().unwrap_or_default(); f.write_str("Slices:\n")?; for (var_id, slice) in slices.iter() { @@ -33,9 +33,12 @@ impl Display for Optimizer { } f.write_str("\n\n")?; - let global_nums = self.analyses.try_get::().unwrap_or_default(); + let global_nums = self + .analysis_cache + .try_get::() + .unwrap_or_default(); let liveness = self - .analyses + .analysis_cache .try_get::() .unwrap_or_else(|| Rc::new(Liveness::empty(self))); diff --git a/crates/cubecl-opt/src/lib.rs b/crates/cubecl-opt/src/lib.rs index 4543f9edd..995fc619f 100644 --- a/crates/cubecl-opt/src/lib.rs +++ b/crates/cubecl-opt/src/lib.rs @@ -30,7 +30,7 @@ use std::{ sync::atomic::{AtomicUsize, Ordering}, }; -use analyses::{dominance::DomFrontiers, liveness::Liveness, writes::Writes, Analyses}; +use analyses::{dominance::DomFrontiers, liveness::Liveness, writes::Writes, AnalysisCache}; use cubecl_core::{ ir::{self as core, Allocator, Branch, Id, Operation, Operator, Variable, VariableKind}, CubeDim, @@ -128,7 +128,7 @@ pub struct Optimizer { /// Allocator for kernel pub allocator: Allocator, /// Analyses with persistent state - analyses: Rc, + analysis_cache: Rc, /// The current block while parsing current_block: Option, /// The current loop's break target @@ -154,7 +154,7 @@ impl Default for Optimizer { root_scope: Scope::root(), cube_dim: Default::default(), mode: Default::default(), - analyses: Default::default(), + analysis_cache: Default::default(), } } } @@ -222,6 +222,8 @@ impl Optimizer { if let Some(current_block) = self.current_block { self.program.add_edge(current_block, self.ret, ()); } + // Analyses shouldn't have run at this point, but just in case they have, invalidate + // all analyses that depend on the graph self.invalidate_structure(); }