diff --git a/crates/cubecl-core/src/ir/operation.rs b/crates/cubecl-core/src/ir/operation.rs index acdf7feb..79e188a6 100644 --- a/crates/cubecl-core/src/ir/operation.rs +++ b/crates/cubecl-core/src/ir/operation.rs @@ -54,6 +54,27 @@ impl Instruction { } } +impl Operation { + /// Whether this operation is pure, aka has no side effects. Pure operations can be removed + /// if their output is not needed, impure operations must be kept since their execution can + /// affect things down the line. e.g. atomics. + /// + /// Operations that operate across multiple units are always considered impure. + pub fn is_pure(&self) -> bool { + match self { + Operation::Copy(_) => true, + Operation::Operator(_) => true, + Operation::Atomic(_) => false, + Operation::Metadata(_) => true, + Operation::Branch(_) => false, + Operation::Synchronization(_) => false, + Operation::Plane(_) => false, + Operation::CoopMma(_) => false, + Operation::NonSemantic(_) => false, + } + } +} + impl Display for Instruction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.operation { diff --git a/crates/cubecl-cuda/src/lib.rs b/crates/cubecl-cuda/src/lib.rs index 8b1e36e0..1fb40d78 100644 --- a/crates/cubecl-cuda/src/lib.rs +++ b/crates/cubecl-cuda/src/lib.rs @@ -10,6 +10,7 @@ pub use device::*; pub use runtime::*; #[cfg(test)] +#[allow(unexpected_cfgs)] mod tests { pub type TestRuntime = crate::CudaRuntime; pub use half::{bf16, f16}; diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs b/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs index 0519d967..55107353 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs @@ -104,6 +104,7 @@ pub fn test_matmul_algorithm( if A::check_availability::(&client, &config).is_err() { // Can't execute the test. println!("Skipped - not supported!"); + client.flush(); return; } diff --git a/crates/cubecl-macros/src/parse/helpers.rs b/crates/cubecl-macros/src/parse/helpers.rs index 88f0ee9d..83477e82 100644 --- a/crates/cubecl-macros/src/parse/helpers.rs +++ b/crates/cubecl-macros/src/parse/helpers.rs @@ -51,11 +51,7 @@ impl Unroll { pub value: Expr, } - let attr = attrs.iter().find(|attr| attr.path().is_ident("unroll")); - let attr = match attr { - Some(attr) => attr, - None => return None, - }; + let attr = attrs.iter().find(|attr| attr.path().is_ident("unroll"))?; match &attr.meta { syn::Meta::Path(_) => None, diff --git a/crates/cubecl-opt/Cargo.toml b/crates/cubecl-opt/Cargo.toml index 752333df..2587641a 100644 --- a/crates/cubecl-opt/Cargo.toml +++ b/crates/cubecl-opt/Cargo.toml @@ -2,8 +2,8 @@ authors = ["Genna Wingert"] categories = ["algorithms"] description = "Compiler optimizations for CubeCL" -keywords = ["gpu", "compiler"] edition = "2021" +keywords = ["gpu", "compiler"] license.workspace = true name = "cubecl-opt" readme.workspace = true @@ -11,11 +11,7 @@ repository = "https://github.com/tracel-ai/cubecl/tree/main/cubecl-opt" version.workspace = true [features] -default = [ - "std", - "cubecl-common/default", - "cubecl-core/default", -] +default = ["std", "cubecl-common/default", "cubecl-core/default"] std = ["cubecl-common/std", "cubecl-core/std"] [dependencies] @@ -27,3 +23,4 @@ num = "0.4" petgraph = { version = "0.6" } smallvec = { version = "1", features = ["union", "const_generics"] } stable-vec = { version = "0.4" } +type-map = { version = "0.5" } diff --git a/crates/cubecl-opt/src/analyses/base.rs b/crates/cubecl-opt/src/analyses/base.rs new file mode 100644 index 00000000..93a05e0d --- /dev/null +++ b/crates/cubecl-opt/src/analyses/base.rs @@ -0,0 +1,62 @@ +use std::{any::Any, cell::RefCell, rc::Rc}; + +use type_map::TypeMap; + +use crate::Optimizer; + +use super::{ + dominance::{Dominators, PostDominators}, + liveness::Liveness, + post_order::PostOrder, +}; + +/// An analysis used by optimization passes. Unlike optimization passes, analyses can have state +/// and persist until they're invalidated. +pub trait Analysis { + /// Perform the analysis for the current optimizer state and return the persistent analysis state + fn init(opt: &mut Optimizer) -> Self; +} + +#[derive(Default, Clone, Debug)] +pub struct Analyses { + cache: Rc>, +} + +impl Analyses { + pub fn get(&self, opt: &mut Optimizer) -> Rc { + let analysis = self.cache.borrow().get::>().cloned(); + if let Some(analysis) = analysis { + analysis + } else { + let analysis = Rc::new(A::init(opt)); + self.cache.borrow_mut().insert(analysis.clone()); + analysis + } + } + + pub fn try_get(&self) -> Option> { + self.cache.borrow().get().cloned() + } + + pub fn invalidate(&self) { + self.cache.borrow_mut().remove::>(); + } +} + +impl Optimizer { + pub fn analysis(&mut self) -> Rc { + let analyses = self.analyses.clone(); + analyses.get(self) + } + + pub fn invalidate_analysis(&self) { + self.analyses.invalidate::(); + } + + pub fn invalidate_structure(&self) { + self.invalidate_analysis::(); + self.invalidate_analysis::(); + self.invalidate_analysis::(); + self.invalidate_analysis::(); + } +} diff --git a/crates/cubecl-opt/src/analyses/const_len.rs b/crates/cubecl-opt/src/analyses/const_len.rs new file mode 100644 index 00000000..b37beaba --- /dev/null +++ b/crates/cubecl-opt/src/analyses/const_len.rs @@ -0,0 +1,96 @@ +use std::{collections::HashMap, ops::Deref}; + +use cubecl_core::ir::{Id, Operation, Operator, Variable, VariableKind}; + +use crate::Optimizer; + +use super::Analysis; + +#[derive(Debug, Clone)] +pub struct Slice { + pub start: Variable, + pub end: Variable, + pub end_op: Option, + pub const_len: Option, +} + +/// Try to find any constant length slices by cancelling common factors in `start` and `end` +#[derive(Default, Debug)] +pub struct Slices { + slices: HashMap, +} + +impl Deref for Slices { + type Target = HashMap; + + fn deref(&self) -> &Self::Target { + &self.slices + } +} + +impl Analysis for Slices { + fn init(opt: &mut Optimizer) -> Self { + let mut this = Slices::default(); + this.populate_slices(opt); + this.find_end_ops(opt); + this + } +} + +impl Slices { + fn populate_slices(&mut self, opt: &mut Optimizer) { + for block in opt.node_ids() { + let ops = opt.program[block].ops.clone(); + for operator in ops.borrow().values() { + let op = match &operator.operation { + Operation::Operator(op) => op, + _ => continue, + }; + let out = operator.out.as_ref(); + if let Operator::Slice(slice_op) = op { + let out_id = match out.unwrap().kind { + VariableKind::Slice { id } => id, + _ => unreachable!(), + }; + let const_len = slice_op.start.as_const().zip(slice_op.end.as_const()); + let const_len = const_len.map(|(start, end)| end.as_u32() - start.as_u32()); + self.slices.insert( + out_id, + Slice { + start: slice_op.start, + end: slice_op.end, + end_op: None, + const_len, + }, + ); + }; + } + } + } + + fn find_end_ops(&mut self, opt: &mut Optimizer) { + for block in opt.node_ids() { + let ops = opt.program[block].ops.clone(); + for operator in ops.borrow().values() { + let op = match &operator.operation { + Operation::Operator(op) => op, + _ => continue, + }; + // Only handle the simplest cases for now + if let Operator::Add(op) = op { + let mut slices = self.slices.values_mut(); + let slice = + slices.find(|it| it.end == operator.out() && it.const_len.is_none()); + if let Some(slice) = slice { + slice.end_op = Some(Operator::Add(op.clone()).into()); + if op.lhs == slice.start && op.rhs.as_const().is_some() { + slice.const_len = Some(op.rhs.as_const().unwrap().as_u32()); + } else if op.rhs == slice.start && op.lhs.as_const().is_some() { + slice.const_len = Some(op.lhs.as_const().unwrap().as_u32()); + } + } + }; + } + } + } +} diff --git a/crates/cubecl-opt/src/analyses/dominance.rs b/crates/cubecl-opt/src/analyses/dominance.rs new file mode 100644 index 00000000..b58d05a4 --- /dev/null +++ b/crates/cubecl-opt/src/analyses/dominance.rs @@ -0,0 +1,87 @@ +use std::{ + collections::{HashMap, HashSet}, + ops::Deref, +}; + +use crate::{NodeIndex, Optimizer}; +use petgraph::algo::dominators; + +use super::Analysis; + +/// Dominator tree for the program graph +pub struct Dominators(dominators::Dominators); +/// Post dominator tree for the program graph +pub struct PostDominators(dominators::Dominators); + +impl Deref for Dominators { + type Target = dominators::Dominators; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Deref for PostDominators { + type Target = dominators::Dominators; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Analysis for Dominators { + fn init(opt: &mut crate::Optimizer) -> Self { + Dominators(dominators::simple_fast(&opt.program.graph, opt.entry())) + } +} + +impl Analysis for PostDominators { + fn init(opt: &mut crate::Optimizer) -> Self { + let mut reversed = opt.program.graph.clone(); + reversed.reverse(); + PostDominators(dominators::simple_fast(&reversed, opt.ret)) + } +} + +/// Dominance frontiers for each block +pub struct DomFrontiers { + /// The dominance frontiers of each block (where phi nodes must be inserted). + dom_frontiers: HashMap>, +} + +impl Deref for DomFrontiers { + type Target = HashMap>; + + fn deref(&self) -> &Self::Target { + &self.dom_frontiers + } +} + +impl DomFrontiers { + /// Find dominance frontiers for each block + pub fn new(opt: &mut Optimizer) -> Self { + let doms = opt.analysis::(); + let nodes = opt.node_ids().into_iter().map(|it| (it, HashSet::new())); + let mut dom_frontiers: HashMap> = nodes.collect(); + + for node in opt.node_ids() { + let predecessors = opt.predecessors(node); + if predecessors.len() >= 2 { + for predecessor in predecessors { + let mut runner = predecessor; + while runner != doms.immediate_dominator(node).unwrap() { + dom_frontiers.get_mut(&runner).unwrap().insert(node); + runner = doms.immediate_dominator(runner).unwrap(); + } + } + } + } + Self { dom_frontiers } + } +} + +impl Analysis for DomFrontiers { + fn init(opt: &mut Optimizer) -> Self { + DomFrontiers::new(opt) + } +} diff --git a/crates/cubecl-opt/src/passes/integer_range_analysis.rs b/crates/cubecl-opt/src/analyses/integer_range.rs similarity index 50% rename from crates/cubecl-opt/src/passes/integer_range_analysis.rs rename to crates/cubecl-opt/src/analyses/integer_range.rs index e1308137..3ab92841 100644 --- a/crates/cubecl-opt/src/passes/integer_range_analysis.rs +++ b/crates/cubecl-opt/src/analyses/integer_range.rs @@ -1,22 +1,41 @@ -use std::ops::{Add, Mul, Sub}; +use std::{ + collections::HashMap, + ops::{Add, Mul, Sub}, +}; use cubecl_core::ir::{ - Builtin, ConstantScalarValue, Elem, Id, Operation, Operator, UIntKind, Variable, VariableKind, + Builtin, ConstantScalarValue, Elem, Id, Operation, Operator, Variable, VariableKind, }; -use crate::{AtomicCounter, Optimizer, Range}; +use crate::{Optimizer, VarId}; + +use super::Analysis; -use super::OptimizerPass; +#[derive(Default, Clone, Copy, PartialEq, Eq, Debug)] +pub struct Range { + pub lower_bound: Option, + pub upper_bound: Option, +} + +/// Perform analysis on the possible ranges of integer values and store the results for use in later +/// optimization passes. Reasons for integers being bounded but not constant might be: the modulo +/// operator (bounds it to `0..m`), or `UNIT_POS` (bounded by `CubeDim`). Bounds can be transferred +/// between simple arithmetic, so we can determine the possible range of a good number of variables. +/// This is currently only used in index bound analysis. +#[derive(Debug, Default)] +pub struct Ranges { + int_ranges: HashMap, +} impl Range { - fn constant(val: i64) -> Self { + fn constant(val: u64) -> Self { Self { lower_bound: Some(val), upper_bound: Some(val), } } - fn uint(upper: i64) -> Self { + fn uint(upper: u64) -> Self { Self { lower_bound: Some(0), upper_bound: Some(upper), @@ -24,16 +43,17 @@ impl Range { } } -/// Perform analysis on the possible ranges of integer values and store the results for use in later -/// optimization passes. Reasons for integers being bounded but not constant might be: the modulo -/// operator (bounds it to `0..m`), or `UNIT_POS` (bounded by `CubeDim`). Bounds can be transferred -/// between simple arithmetic, so we can determine the possible range of a good number of variables. -/// This is currently only used in index bound analysis. -#[derive(Default, Clone, Debug)] -pub struct IntegerRangeAnalysis; +impl Analysis for Ranges { + fn init(opt: &mut Optimizer) -> Self { + let mut this = Ranges::default(); + // Run fixed point iteration + while this.run_loop(opt) {} + this + } +} -impl OptimizerPass for IntegerRangeAnalysis { - fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { +impl Ranges { + fn run_loop(&mut self, opt: &mut Optimizer) -> bool { for block in opt.node_ids() { let ops = opt.program[block].ops.clone(); for inst in ops.borrow().values() { @@ -42,58 +62,58 @@ impl OptimizerPass for IntegerRangeAnalysis { _ => continue, }; match op { - Operator::Add(binop) if inst.item().elem().is_int() => { + Operator::Add(binop) if is_uint(inst.item().elem()) => { if let Some(out_id) = var_id(&inst.out()) { - let lhs_range = range_of(opt, &binop.lhs); - let rhs_range = range_of(opt, &binop.rhs); + let lhs_range = self.range_of(opt, &binop.lhs); + let rhs_range = self.range_of(opt, &binop.rhs); let out_range = lhs_range + rhs_range; - if Some(&out_range) != opt.program.int_ranges.get(&out_id) { - opt.program.int_ranges.insert(out_id, out_range); - changes.inc(); + if Some(&out_range) != self.int_ranges.get(&out_id) { + self.int_ranges.insert(out_id, out_range); + return true; } } } - Operator::Sub(binop) if inst.item().elem().is_int() => { + Operator::Sub(binop) if is_uint(inst.item().elem()) => { if let Some(out_id) = var_id(&inst.out()) { - let lhs_range = range_of(opt, &binop.lhs); - let rhs_range = range_of(opt, &binop.rhs); + let lhs_range = self.range_of(opt, &binop.lhs); + let rhs_range = self.range_of(opt, &binop.rhs); let out_range = lhs_range - rhs_range; - if Some(&out_range) != opt.program.int_ranges.get(&out_id) { - opt.program.int_ranges.insert(out_id, out_range); - changes.inc(); + if Some(&out_range) != self.int_ranges.get(&out_id) { + self.int_ranges.insert(out_id, out_range); + return true; } } } - Operator::Mul(binop) if inst.item().elem().is_int() => { + Operator::Mul(binop) if is_uint(inst.item().elem()) => { if let Some(out_id) = var_id(&inst.out()) { - let lhs_range = range_of(opt, &binop.lhs); - let rhs_range = range_of(opt, &binop.rhs); + let lhs_range = self.range_of(opt, &binop.lhs); + let rhs_range = self.range_of(opt, &binop.rhs); let out_range = lhs_range * rhs_range; - if Some(&out_range) != opt.program.int_ranges.get(&out_id) { - opt.program.int_ranges.insert(out_id, out_range); - changes.inc(); + if Some(&out_range) != self.int_ranges.get(&out_id) { + self.int_ranges.insert(out_id, out_range); + return true; } } } - Operator::Div(binop) if inst.item().elem().is_int() => { + Operator::Div(binop) if is_uint(inst.item().elem()) => { if let Some(out_id) = var_id(&inst.out()) { - let lhs_range = range_of(opt, &binop.lhs); - let rhs_range: Range = range_of(opt, &binop.rhs); + let lhs_range = self.range_of(opt, &binop.lhs); + let rhs_range = self.range_of(opt, &binop.rhs); let out_range = lhs_range / rhs_range; - if Some(&out_range) != opt.program.int_ranges.get(&out_id) { - opt.program.int_ranges.insert(out_id, out_range); - changes.inc(); + if Some(&out_range) != self.int_ranges.get(&out_id) { + self.int_ranges.insert(out_id, out_range); + return true; } } } - Operator::Modulo(binop) if inst.item().elem().is_int() => { + Operator::Modulo(binop) if is_uint(inst.item().elem()) => { if let Some(out_id) = var_id(&inst.out()) { - let lhs_range = range_of(opt, &binop.lhs); - let rhs_range = range_of(opt, &binop.rhs); + let lhs_range = self.range_of(opt, &binop.lhs); + let rhs_range = self.range_of(opt, &binop.rhs); let out_range = lhs_range % rhs_range; - if Some(&out_range) != opt.program.int_ranges.get(&out_id) { - opt.program.int_ranges.insert(out_id, out_range); - changes.inc(); + if Some(&out_range) != self.int_ranges.get(&out_id) { + self.int_ranges.insert(out_id, out_range); + return true; } } } @@ -101,60 +121,55 @@ impl OptimizerPass for IntegerRangeAnalysis { } } } + false } } -/// The possible range of values of any variable, if applicable. Returns unbounded range if no range -/// can be determined, or the type is not an integer. -pub(crate) fn range_of(opt: &Optimizer, var: &Variable) -> Range { - match var.kind { - VariableKind::Versioned { id, version } if var.item.elem() == Elem::UInt(UIntKind::U32) => { - opt.program +fn is_uint(elem: Elem) -> bool { + matches!(elem, Elem::UInt(_)) +} + +impl Ranges { + /// The possible range of values of any variable, if applicable. Returns unbounded range if no range + /// can be determined, or the type is not an integer. + pub fn range_of(&self, opt: &Optimizer, var: &Variable) -> Range { + match var.kind { + VariableKind::Versioned { id, version } if is_uint(var.item.elem()) => self .int_ranges .get(&(id, version)) .copied() .unwrap_or(Range { lower_bound: Some(0), upper_bound: None, + }), + VariableKind::Versioned { id, version } => self + .int_ranges + .get(&(id, version)) + .copied() + .unwrap_or_default(), + VariableKind::LocalConst { id } if is_uint(var.item.elem()) => { + self.int_ranges.get(&(id, 0)).copied().unwrap_or(Range { + lower_bound: Some(0), + upper_bound: None, }) - } - VariableKind::Versioned { id, version } => opt - .program - .int_ranges - .get(&(id, version)) - .copied() - .unwrap_or_default(), - VariableKind::LocalConst { id } if var.item.elem() == Elem::UInt(UIntKind::U32) => opt - .program - .int_ranges - .get(&(id, 0)) - .copied() - .unwrap_or(Range { - lower_bound: Some(0), - upper_bound: None, - }), - VariableKind::LocalConst { id } => opt - .program - .int_ranges - .get(&(id, 0)) - .copied() - .unwrap_or_default(), - VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => Range::constant(val), - VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => { - Range::constant(val as i64) - } - VariableKind::Builtin(builtin) => match builtin { - Builtin::UnitPos => Range::uint(opt.cube_dim.num_elems() as i64 - 1), - Builtin::UnitPosX => Range::uint(opt.cube_dim.x as i64 - 1), - Builtin::UnitPosY => Range::uint(opt.cube_dim.y as i64 - 1), - Builtin::UnitPosZ => Range::uint(opt.cube_dim.z as i64 - 1), - Builtin::CubeCount => Range::constant(opt.cube_dim.num_elems() as i64), - Builtin::CubeCountX => Range::constant(opt.cube_dim.x as i64), - Builtin::CubeCountY => Range::constant(opt.cube_dim.y as i64), - Builtin::CubeCountZ => Range::constant(opt.cube_dim.z as i64), + } + VariableKind::LocalConst { id } => { + self.int_ranges.get(&(id, 0)).copied().unwrap_or_default() + } + VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => Range::constant(val), + VariableKind::Builtin(builtin) => match builtin { + Builtin::UnitPos => Range::uint(opt.cube_dim.num_elems() as u64 - 1), + Builtin::UnitPosX => Range::uint(opt.cube_dim.x as u64 - 1), + Builtin::UnitPosY => Range::uint(opt.cube_dim.y as u64 - 1), + Builtin::UnitPosZ => Range::uint(opt.cube_dim.z as u64 - 1), + Builtin::CubeCount => Range::constant(opt.cube_dim.num_elems() as u64), + Builtin::CubeCountX => Range::constant(opt.cube_dim.x as u64), + Builtin::CubeCountY => Range::constant(opt.cube_dim.y as u64), + Builtin::CubeCountZ => Range::constant(opt.cube_dim.z as u64), + _ => Default::default(), + }, _ => Default::default(), - }, - _ => Default::default(), + } } } @@ -230,28 +245,13 @@ mod range_ops { type Output = Range; fn rem(self, rhs: Self) -> Self::Output { - let min_neg = self.lower_bound.map(|it| it < 0).unwrap_or(true); - let max_pos = self.upper_bound.map(|it| it > 0).unwrap_or(true); if rhs.lower_bound.is_none() || rhs.upper_bound.is_none() { return self; } - let rhs_lower = rhs.lower_bound.unwrap().abs(); - let rhs_upper = rhs.upper_bound.unwrap().abs(); - let rhs_max = rhs_lower.max(rhs_upper); - match (min_neg, max_pos) { - (true, false) => Range { - lower_bound: Some(-(rhs_max - 1)), - upper_bound: Some(0), - }, - (true, true) => Range { - lower_bound: Some(-(rhs_max - 1)), - upper_bound: Some(rhs_max - 1), - }, - (false, true) => Range { - lower_bound: Some(0), - upper_bound: Some(rhs_max - 1), - }, - _ => self, + let rhs_upper = rhs.upper_bound.unwrap(); + Range { + lower_bound: Some(0), + upper_bound: Some(rhs_upper - 1), } } } diff --git a/crates/cubecl-opt/src/analyses/liveness.rs b/crates/cubecl-opt/src/analyses/liveness.rs new file mode 100644 index 00000000..ec94ac60 --- /dev/null +++ b/crates/cubecl-opt/src/analyses/liveness.rs @@ -0,0 +1,106 @@ +use std::collections::{HashMap, HashSet, VecDeque}; + +use cubecl_core::ir::Id; +use petgraph::graph::NodeIndex; + +use crate::{analyses::post_order::PostOrder, Optimizer}; + +use super::Analysis; + +pub struct Liveness { + live_vars: HashMap>, +} + +#[derive(Clone)] +struct BlockSets { + gen: HashSet, + kill: HashSet, +} + +struct State { + worklist: VecDeque, + block_sets: HashMap, +} + +impl Analysis for Liveness { + fn init(opt: &mut Optimizer) -> Self { + let mut this = Self::empty(opt); + this.analyze_liveness(opt); + this + } +} + +impl Liveness { + pub fn empty(opt: &Optimizer) -> Self { + let live_vars = opt + .node_ids() + .iter() + .map(|it| (*it, HashSet::new())) + .collect(); + Self { live_vars } + } + + pub fn at_block(&self, block: NodeIndex) -> &HashSet { + &self.live_vars[&block] + } + + pub fn is_dead(&self, node: NodeIndex, var: Id) -> bool { + !self.at_block(node).contains(&var) + } + + /// Do a conservative block level liveness analysis + pub fn analyze_liveness(&mut self, opt: &mut Optimizer) { + let mut state = State { + worklist: VecDeque::from(opt.analysis::().forward()), + block_sets: HashMap::new(), + }; + while let Some(block) = state.worklist.pop_front() { + self.analyze_block(opt, block, &mut state); + } + } + + fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) { + let BlockSets { gen, kill } = block_sets(opt, block, state); + + let mut live_vars = gen.clone(); + + for successor in opt.successors(block) { + let successor = &self.live_vars[&successor]; + live_vars.extend(successor.difference(kill)); + } + + if live_vars != self.live_vars[&block] { + state.worklist.extend(opt.predecessors(block)); + self.live_vars.insert(block, live_vars); + } + } +} + +fn block_sets<'a>(opt: &mut Optimizer, block: NodeIndex, state: &'a mut State) -> &'a BlockSets { + let block_sets = state.block_sets.entry(block); + block_sets.or_insert_with(|| calculate_block_sets(opt, block)) +} + +fn calculate_block_sets(opt: &mut Optimizer, block: NodeIndex) -> BlockSets { + let mut gen = HashSet::new(); + let mut kill = HashSet::new(); + + let ops = opt.program[block].ops.clone(); + + for op in ops.borrow_mut().values_mut().rev() { + // Reads must be tracked after writes + opt.visit_out(&mut op.out, |opt, var| { + if let Some(id) = opt.local_variable_id(var) { + kill.insert(id); + gen.remove(&id); + } + }); + opt.visit_operation(&mut op.operation, |opt, var| { + if let Some(id) = opt.local_variable_id(var) { + gen.insert(id); + } + }); + } + + BlockSets { gen, kill } +} diff --git a/crates/cubecl-opt/src/analyses/mod.rs b/crates/cubecl-opt/src/analyses/mod.rs new file mode 100644 index 00000000..654c1d2c --- /dev/null +++ b/crates/cubecl-opt/src/analyses/mod.rs @@ -0,0 +1,9 @@ +mod base; +pub mod const_len; +pub mod dominance; +pub mod integer_range; +pub mod liveness; +pub mod post_order; +pub mod writes; + +pub use base::*; diff --git a/crates/cubecl-opt/src/analyses/post_order.rs b/crates/cubecl-opt/src/analyses/post_order.rs new file mode 100644 index 00000000..6b589015 --- /dev/null +++ b/crates/cubecl-opt/src/analyses/post_order.rs @@ -0,0 +1,23 @@ +use crate::NodeIndex; +use petgraph::visit::{DfsPostOrder, Walker}; + +use super::Analysis; + +pub struct PostOrder(Vec); + +impl Analysis for PostOrder { + fn init(opt: &mut crate::Optimizer) -> Self { + let po = DfsPostOrder::new(&opt.program.graph, opt.entry()); + PostOrder(po.iter(&opt.program.graph).collect()) + } +} + +impl PostOrder { + pub fn forward(&self) -> Vec { + self.0.clone() + } + + pub fn reverse(&self) -> Vec { + self.0.clone().into_iter().rev().collect() + } +} diff --git a/crates/cubecl-opt/src/analyses/writes.rs b/crates/cubecl-opt/src/analyses/writes.rs new file mode 100644 index 00000000..4932c018 --- /dev/null +++ b/crates/cubecl-opt/src/analyses/writes.rs @@ -0,0 +1,45 @@ +use std::{ + collections::{HashMap, HashSet}, + ops::Deref, +}; + +use cubecl_core::ir::Id; + +use crate::{NodeIndex, Optimizer}; + +use super::Analysis; + +pub struct Writes { + /// The variables written to by each block. + writes: HashMap>, +} + +impl Deref for Writes { + type Target = HashMap>; + + fn deref(&self) -> &Self::Target { + &self.writes + } +} + +impl Writes { + pub fn new(opt: &mut Optimizer) -> Self { + let nodes = opt.node_ids().into_iter().map(|it| (it, HashSet::new())); + let mut writes: HashMap> = nodes.collect(); + for block in opt.node_ids() { + let ops = opt.program[block].ops.clone(); + for inst in ops.borrow().values() { + if let Some(id) = inst.out.as_ref().and_then(|it| opt.local_variable_id(it)) { + writes.get_mut(&block).unwrap().insert(id); + } + } + } + Writes { writes } + } +} + +impl Analysis for Writes { + fn init(opt: &mut crate::Optimizer) -> Self { + Writes::new(opt) + } +} diff --git a/crates/cubecl-opt/src/block.rs b/crates/cubecl-opt/src/block.rs index 9e53c886..3766fdaa 100644 --- a/crates/cubecl-opt/src/block.rs +++ b/crates/cubecl-opt/src/block.rs @@ -1,10 +1,9 @@ -use std::{cell::RefCell, collections::HashSet, rc::Rc}; +use std::{cell::RefCell, rc::Rc}; -use cubecl_core::ir::{Id, Instruction, Variable}; -use petgraph::graph::NodeIndex; +use cubecl_core::ir::{Instruction, Variable}; use stable_vec::StableVec; -use crate::{version::PhiInstruction, ControlFlow, Optimizer, Program}; +use crate::{version::PhiInstruction, ControlFlow, Optimizer}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BlockUse { @@ -19,12 +18,6 @@ pub struct BasicBlock { pub(crate) block_use: Vec, /// The phi nodes that are required to be generated at the start of this block. pub phi_nodes: Rc>>, - /// The variables written to by this block. Only set during the SSA transformation. - pub(crate) writes: HashSet, - /// The live variables at the start of this block. Used for pruning phi nodes. - pub(crate) live_vars: HashSet, - /// The dominance frontiers of this block (where phi nodes must be inserted). - pub(crate) dom_frontiers: HashSet, /// A stable list of operations performed in this block. pub ops: Rc>>, /// The control flow that terminates this block. @@ -61,12 +54,3 @@ impl Optimizer { } } } - -impl Program { - /// Check whether a variable is dead at the start of this block. Note that `false` does not mean - /// the variable is definitely live - just that it *may* be live and must be treated as such. - #[track_caller] - pub fn is_dead(&self, node: NodeIndex, var: Id) -> bool { - !self[node].live_vars.contains(&var) - } -} diff --git a/crates/cubecl-opt/src/control_flow.rs b/crates/cubecl-opt/src/control_flow.rs index b4515a64..7f375d06 100644 --- a/crates/cubecl-opt/src/control_flow.rs +++ b/crates/cubecl-opt/src/control_flow.rs @@ -259,8 +259,7 @@ impl Optimizer { let i = range_loop.i; self.program.variables.insert(i_id, i.item); - let mut assign = Instruction::new(Operation::Copy(range_loop.start), i); - self.visit_out(&mut assign.out, |opt, var| opt.write_var(var)); + let assign = Instruction::new(Operation::Copy(range_loop.start), i); self.current_block_mut().ops.borrow_mut().push(assign); let current_block = self.current_block.unwrap(); @@ -302,7 +301,7 @@ impl Optimizer { self.current_block = Some(next); // For loop constructs - self.program.insert_phi(header, i_id, range_loop.start.item); + self.insert_phi(header, i_id, range_loop.start.item); { let op = match range_loop.inclusive { true => Operator::LowerEqual, @@ -351,6 +350,7 @@ impl Optimizer { let new_block = self.program.add_node(BasicBlock::default()); self.program.add_edge(block, new_block, ()); self.program.add_edge(new_block, *successor, ()); + self.invalidate_structure(); update_phi(self, *successor, block, new_block); update_control_flow(self, block, *successor, new_block); } diff --git a/crates/cubecl-opt/src/debug.rs b/crates/cubecl-opt/src/debug.rs index 9cb9a1da..37db0fea 100644 --- a/crates/cubecl-opt/src/debug.rs +++ b/crates/cubecl-opt/src/debug.rs @@ -1,11 +1,11 @@ -use std::fmt::Display; +use std::{fmt::Display, rc::Rc}; use cubecl_core::ir::{FloatKind, IntKind, UIntKind}; use petgraph::visit::EdgeRef; use crate::{ - gvn::{BlockSets, Constant, Expression, Instruction, Local, OpId, Value, ValueTable}, - passes::var_id, + analyses::{const_len::Slices, integer_range::Ranges, liveness::Liveness}, + gvn::{BlockSets, Constant, Expression, GvnState, Instruction, Local, OpId, Value, ValueTable}, ControlFlow, }; @@ -16,12 +16,11 @@ const DEBUG_GVN: bool = false; /// Debug display for the program state. impl Display for Optimizer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let post_order = self.post_order.iter().map(|it| format!("bb{}", it.index())); - let post_order = post_order.collect::>(); - writeln!(f, "Post Order: {}", post_order.join(", "))?; - writeln!(f)?; + let slices = self.analyses.try_get::().unwrap_or_default(); + let ranges = self.analyses.try_get::().unwrap_or_default(); + f.write_str("Slices:\n")?; - for (var_id, slice) in self.program.slices.iter() { + for (var_id, slice) in slices.iter() { let end_op = slice.end_op.as_ref().map(|it| format!("{it}")); writeln!( f, @@ -34,7 +33,11 @@ impl Display for Optimizer { } f.write_str("\n\n")?; - let global_nums = self.gvn.borrow(); + let global_nums = self.analyses.try_get::().unwrap_or_default(); + let liveness = self + .analyses + .try_get::() + .unwrap_or_else(|| Rc::new(Liveness::empty(self))); if DEBUG_GVN { writeln!(f, "# Value Table:")?; @@ -57,7 +60,7 @@ impl Display for Optimizer { if !bb.block_use.is_empty() { writeln!(f, " Uses: {:?}", bb.block_use)?; } - let live_vars = bb.live_vars.iter(); + let live_vars = liveness.at_block(node).iter(); let live_vars = live_vars.map(|it| format!("local({})", it)); let live_vars = live_vars.collect::>(); writeln!(f, " Live variables: [{}]\n", live_vars.join(", "))?; @@ -75,8 +78,7 @@ impl Display for Optimizer { } for op in bb.ops.borrow_mut().values_mut() { - let id = op.out.and_then(|var| var_id(&var)); - let range = id.and_then(|id| self.program.int_ranges.get(&id)); + let range = op.out.map(|var| ranges.range_of(self, &var)); let range = range.map(|it| format!(" range: {it};")).unwrap_or_default(); writeln!(f, " {op};{range}")?; diff --git a/crates/cubecl-opt/src/gvn/analysis.rs b/crates/cubecl-opt/src/gvn/analysis.rs index 4e54f64e..c3bec225 100644 --- a/crates/cubecl-opt/src/gvn/analysis.rs +++ b/crates/cubecl-opt/src/gvn/analysis.rs @@ -1,18 +1,36 @@ -use std::collections::{HashMap, HashSet, LinkedList}; - -use petgraph::{ - algo::dominators::{self}, - graph::NodeIndex, - Graph, +use std::{ + cell::RefCell, + collections::{HashMap, HashSet, LinkedList}, }; + +use crate::{analyses::Analysis, NodeIndex}; use smallvec::SmallVec; -use crate::{BasicBlock, Optimizer}; +use crate::{ + analyses::dominance::{Dominators, PostDominators}, + Optimizer, +}; -use super::{convert::value_of_var, Expression, GvnPass, Value, ValueTable}; +use super::{convert::value_of_var, Expression, Value, ValueTable}; const MAX_SET_PASSES: usize = 10; +pub struct GlobalValues(pub RefCell); + +#[derive(Debug, Clone, Default)] +pub struct GvnState { + pub values: ValueTable, + pub block_sets: HashMap, +} + +impl Analysis for GlobalValues { + fn init(opt: &mut Optimizer) -> Self { + let mut this = GvnState::default(); + this.build_sets(opt); + GlobalValues(RefCell::new(this)) + } +} + /// The set annotations for a given block #[derive(Debug, Clone, Default)] pub struct BlockSets { @@ -32,32 +50,15 @@ pub struct BlockSets { pub antic_in: LinkedList<(u32, Expression)>, } -/// Needed for Optimizer `Default`, which is required for SpirvCompiler to implement `Default` -/// (which is required for the `Compiler` implementation) -impl Default for GvnPass { - fn default() -> Self { - let mut dummy = Graph::::new(); - let root = dummy.add_node(BasicBlock::default()); - Self { - values: Default::default(), - block_sets: Default::default(), - dominators: dominators::simple_fast(&dummy, root), - post_doms: dominators::simple_fast(&dummy, root), - } - } -} - -impl GvnPass { +impl GvnState { /// Build set annotations for each block. Executes two steps: /// 1. Forward DFA that generates the available expressions, values and leaders for each block /// 2. Backward fixed-point DFA that generates the anticipated expressions/antileaders for each /// block - pub fn build_sets(&mut self, opt: &Optimizer) { - self.build_block_sets_fwd(opt, self.dominators.root(), HashMap::new()); + pub fn build_sets(&mut self, opt: &mut Optimizer) { + self.build_block_sets_fwd(opt, opt.entry(), HashMap::new()); let mut build_passes = 0; - while self.build_block_sets_bckwd(opt, self.post_doms.root()) - && build_passes < MAX_SET_PASSES - { + while self.build_block_sets_bckwd(opt, opt.ret) && build_passes < MAX_SET_PASSES { build_passes += 1; } @@ -86,7 +87,7 @@ impl GvnPass { /// variables that represent them are also available there. fn build_block_sets_fwd( &mut self, - opt: &Optimizer, + opt: &mut Optimizer, block: NodeIndex, mut leaders: HashMap, ) { @@ -100,6 +101,8 @@ impl GvnPass { // Values already added in this block. Used to deduplicate locally. let mut added_exprs = HashSet::new(); + let dominators = opt.analysis::(); + // Number phi outputs and add the out var as a leader for that value for phi in opt.program[block].phi_nodes.borrow().iter() { let (num, val) = self.values.lookup_or_add_phi(phi); @@ -135,7 +138,7 @@ impl GvnPass { antic_in: Default::default(), }; self.block_sets.insert(block, sets); - let successors: Vec<_> = self.dominators.immediately_dominated_by(block).collect(); + let successors: Vec<_> = dominators.immediately_dominated_by(block).collect(); for successor in successors { // Work around dominator bug if successor != block { @@ -147,8 +150,9 @@ impl GvnPass { /// Do a fixed point data backward flow analysis to find expected expressions at any given /// program point. Iterates through the post-dominator tree because it's the fastest way to /// converge. - fn build_block_sets_bckwd(&mut self, opt: &Optimizer, current: NodeIndex) -> bool { + fn build_block_sets_bckwd(&mut self, opt: &mut Optimizer, current: NodeIndex) -> bool { let mut changed = false; + let post_doms = opt.analysis::(); let successors = opt.successors(current); // Since we have no critical edges, if successors > 1 then they must have only one entry, @@ -162,11 +166,7 @@ impl GvnPass { .map(|child| &self.block_sets[child].antic_in); // Only add expressions expected at all successors to this block's anticipated list for (val, expr) in potential_out { - if rest - .clone() - .map(|child| child.iter().any(|v| v.0 == *val)) - .all(|b| b) - { + if rest.clone().all(|child| child.iter().any(|v| v.0 == *val)) { result.push_back((*val, expr.clone())); } } @@ -224,7 +224,7 @@ impl GvnPass { } self.block_sets.get_mut(¤t).unwrap().antic_in = result; - let predecessors: Vec<_> = self.post_doms.immediately_dominated_by(current).collect(); + let predecessors: Vec<_> = post_doms.immediately_dominated_by(current).collect(); for predecessor in predecessors { // Work around dominator bug if predecessor != current { diff --git a/crates/cubecl-opt/src/gvn/apply.rs b/crates/cubecl-opt/src/gvn/apply.rs index 8b2c7d22..7c4cd814 100644 --- a/crates/cubecl-opt/src/gvn/apply.rs +++ b/crates/cubecl-opt/src/gvn/apply.rs @@ -4,14 +4,15 @@ use cubecl_core::ir::{self, Operation}; use petgraph::graph::NodeIndex; use crate::{ + analyses::dominance::Dominators, gvn::{convert::value_of_var, phi_translate}, version::PhiEntry, AtomicCounter, Optimizer, PhiInstruction, }; -use super::GvnPass; +use super::GvnState; -impl GvnPass { +impl GvnState { /// Find places where an expression is partially but not fully available, and hoist the /// computation into the blocks that do not currently have the value available to make the /// expression fully redundant @@ -21,7 +22,7 @@ impl GvnPass { let mut new_expr = HashMap::new(); - while self.insert_block(opt, self.dominators.root(), &mut new_expr, changes) { + while self.insert_block(opt, opt.entry(), &mut new_expr, changes) { loops += 1; } let inserted = changes.get() - changes_pre; @@ -37,6 +38,7 @@ impl GvnPass { changes: &AtomicCounter, ) -> bool { let mut changed = false; + let dominators = opt.analysis::(); let predecessors = opt.predecessors(current); if predecessors.len() > 1 { @@ -81,7 +83,7 @@ impl GvnPass { } let leaders = &mut self.block_sets.get_mut(&pred).unwrap().leaders; if !leaders.contains_key(&val) { - let new_temp = *opt.allocator.create_local_restricted(expr.item()); + let new_temp = *opt.allocator.create_local(expr.item()); let new_op = ir::Instruction::new(expr.to_operation(leaders), new_temp); opt.program[pred].ops.borrow_mut().push(new_op); leaders.insert(val, value_of_var(&new_temp).unwrap()); @@ -100,7 +102,7 @@ impl GvnPass { let new_phis = new_phis .into_iter() .map(|entries| PhiInstruction { - out: *opt.allocator.create_local_restricted(entries[0].value.item), + out: *opt.allocator.create_local(entries[0].value.item), entries, }) .collect::>(); @@ -120,8 +122,7 @@ impl GvnPass { opt.program[current].phi_nodes.borrow_mut().extend(new_phis); } - let children = self - .dominators + let children = dominators .immediately_dominated_by(current) .collect::>(); for child in children { diff --git a/crates/cubecl-opt/src/gvn/base.rs b/crates/cubecl-opt/src/gvn/base.rs index 2bc33dff..ae3044a5 100644 --- a/crates/cubecl-opt/src/gvn/base.rs +++ b/crates/cubecl-opt/src/gvn/base.rs @@ -2,16 +2,15 @@ use std::collections::HashMap; use cubecl_core::ir::{Builtin, ConstantScalarValue, Elem, FloatKind, Id, IntKind, Item, UIntKind}; use float_ord::FloatOrd; -use petgraph::{ - algo::dominators::{self, Dominators}, - graph::NodeIndex, - visit::{DfsPostOrder, Walker as _}, -}; +use petgraph::graph::NodeIndex; use smallvec::SmallVec; use crate::{passes::OptimizerPass, AtomicCounter, Optimizer, PhiInstruction}; -use super::{convert::value_of_var, BlockSets}; +use super::{convert::value_of_var, GlobalValues}; + +#[derive(Debug, Clone, Default)] +pub struct GvnPass; impl OptimizerPass for GvnPass { fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { @@ -19,14 +18,6 @@ impl OptimizerPass for GvnPass { } } -#[derive(Debug, Clone)] -pub struct GvnPass { - pub values: ValueTable, - pub block_sets: HashMap, - pub dominators: Dominators, - pub post_doms: Dominators, -} - impl GvnPass { /// Run the GVN-PRE algorithm /// 1. Build forward and backward dominator trees @@ -36,26 +27,10 @@ impl GvnPass { /// 4. Replace fully redundant expressions with simple assignments from the leader of that /// expression to `out` pub fn run(&mut self, opt: &mut Optimizer, changes: &AtomicCounter) { - self.build_dominators(opt); - self.build_sets(opt); - self.insert(opt, changes); - self.eliminate(opt, changes); - } - - fn build_dominators(&mut self, opt: &mut Optimizer) { - let post_order = DfsPostOrder::new(&opt.program.graph, opt.entry()) - .iter(&opt.program.graph) - .collect::>(); - for node in opt.node_ids() { - if !post_order.contains(&node) { - opt.program.remove_node(node); - } - } + let analysis = opt.analysis::(); - self.dominators = dominators::simple_fast(&opt.program.graph, opt.entry()); - let mut rev_graph = opt.program.graph.clone(); - rev_graph.reverse(); - self.post_doms = dominators::simple_fast(&rev_graph, opt.ret); + analysis.0.borrow_mut().insert(opt, changes); + analysis.0.borrow_mut().eliminate(opt, changes); } } diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index d2f57811..45804a54 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -259,10 +259,4 @@ impl Optimizer { visit_read(self, &mut binop.lhs); visit_read(self, &mut binop.rhs); } - - pub fn write_var(&mut self, var: &mut Variable) { - if let Some(id) = self.local_variable_id(var) { - self.current_block_mut().writes.insert(id); - } - } } diff --git a/crates/cubecl-opt/src/lib.rs b/crates/cubecl-opt/src/lib.rs index 525a564b..4543f9ed 100644 --- a/crates/cubecl-opt/src/lib.rs +++ b/crates/cubecl-opt/src/lib.rs @@ -24,13 +24,13 @@ //! use std::{ - cell::RefCell, - collections::{HashMap, HashSet, VecDeque}, + collections::{HashMap, VecDeque}, ops::{Deref, DerefMut}, rc::Rc, sync::atomic::{AtomicUsize, Ordering}, }; +use analyses::{dominance::DomFrontiers, liveness::Liveness, writes::Writes, Analyses}; use cubecl_core::{ ir::{self as core, Allocator, Branch, Id, Operation, Operator, Variable, VariableKind}, CubeDim, @@ -43,12 +43,12 @@ use gvn::GvnPass; use passes::{ CompositeMerge, ConstEval, ConstOperandSimplify, CopyPropagateArray, CopyTransform, EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables, - EmptyBranchToSelect, FindConstSliceLen, InBoundsToUnchecked, InlineAssignments, - IntegerRangeAnalysis, MergeBlocks, MergeSameExpressions, OptimizerPass, ReduceStrength, - RemoveIndexScalar, + EmptyBranchToSelect, InBoundsToUnchecked, InlineAssignments, MergeBlocks, MergeSameExpressions, + OptimizerPass, ReduceStrength, RemoveIndexScalar, }; use petgraph::{prelude::StableDiGraph, visit::EdgeRef, Direction}; +mod analyses; mod block; mod control_flow; mod debug; @@ -88,14 +88,6 @@ impl AtomicCounter { } } -#[derive(Debug, Clone)] -pub(crate) struct Slice { - pub(crate) start: Variable, - pub(crate) end: Variable, - pub(crate) end_op: Option, - pub(crate) const_len: Option, -} - #[derive(Debug, Clone)] pub struct ConstArray { pub id: Id, @@ -108,10 +100,8 @@ pub struct ConstArray { struct Program { pub const_arrays: Vec, pub variables: HashMap, - pub(crate) slices: HashMap, pub graph: StableDiGraph, root: NodeIndex, - int_ranges: HashMap, } impl Deref for Program { @@ -130,19 +120,15 @@ impl DerefMut for Program { type VarId = (Id, u16); -#[derive(Default, Clone, Copy, PartialEq, Eq, Debug)] -struct Range { - lower_bound: Option, - upper_bound: Option, -} - /// An optimizer that applies various analyses and optimization passes to the IR. #[derive(Debug, Clone)] pub struct Optimizer { /// The overall program state program: Program, - /// The post order of the graph for traversal - post_order: Vec, + /// Allocator for kernel + pub allocator: Allocator, + /// Analyses with persistent state + analyses: Rc, /// The current block while parsing current_block: Option, /// The current loop's break target @@ -155,24 +141,20 @@ pub struct Optimizer { pub(crate) cube_dim: CubeDim, /// The execution mode, `Unchecked` skips bounds check optimizations. pub(crate) mode: ExecutionMode, - pub(crate) gvn: Rc>, - - pub allocator: Allocator, } impl Default for Optimizer { fn default() -> Self { Self { program: Default::default(), + allocator: Default::default(), current_block: Default::default(), loop_break: Default::default(), ret: Default::default(), root_scope: Scope::root(), cube_dim: Default::default(), mode: Default::default(), - post_order: Default::default(), - gvn: Default::default(), - allocator: Default::default(), + analyses: Default::default(), } } } @@ -197,8 +179,6 @@ impl Optimizer { fn run_opt(&mut self, expand: Scope) { self.parse_graph(expand); self.split_critical_edges(); - self.determine_postorder(self.entry(), &mut HashSet::new()); - self.analyze_liveness(); self.apply_pre_ssa_passes(); self.exempt_index_assign_locals(); self.ssa_transform(); @@ -210,14 +190,13 @@ impl Optimizer { let arrays_prop = AtomicCounter::new(0); CopyPropagateArray.apply_post_ssa(self, arrays_prop.clone()); if arrays_prop.get() > 0 { - self.analyze_liveness(); + self.invalidate_analysis::(); self.ssa_transform(); self.apply_post_ssa_passes(); } let gvn_count = AtomicCounter::new(0); - let gvn = self.gvn.clone(); - gvn.borrow_mut().apply_post_ssa(self, gvn_count.clone()); + GvnPass.apply_post_ssa(self, gvn_count.clone()); ReduceStrength.apply_post_ssa(self, gvn_count.clone()); CopyTransform.apply_post_ssa(self, gvn_count.clone()); @@ -243,24 +222,7 @@ impl Optimizer { if let Some(current_block) = self.current_block { self.program.add_edge(current_block, self.ret, ()); } - } - - fn determine_postorder(&mut self, block: NodeIndex, visited: &mut HashSet) { - for successor in self.successors(block) { - if !visited.contains(&successor) { - visited.insert(successor); - self.determine_postorder(successor, visited); - } - } - self.post_order.push(block); - } - - pub fn post_order(&self) -> Vec { - self.post_order.clone() - } - - pub fn reverse_post_order(&self) -> Vec { - self.post_order.iter().rev().copied().collect() + self.invalidate_structure(); } fn apply_pre_ssa_passes(&mut self) { @@ -293,15 +255,6 @@ impl Optimizer { Box::new(EliminateDeadBlocks), Box::new(EliminateDeadPhi), ]; - // Passes that only run if execution mode is checked - let checked_passes: Vec> = vec![ - Box::new(IntegerRangeAnalysis), - Box::new(FindConstSliceLen), - Box::new(InBoundsToUnchecked), - ]; - if matches!(self.mode, ExecutionMode::Checked) { - passes.extend(checked_passes); - } loop { let counter = AtomicCounter::default(); @@ -313,6 +266,11 @@ impl Optimizer { break; } } + + // Only replace indexing when checked, since all indexes are unchecked anyways in unchecked + if matches!(self.mode, ExecutionMode::Checked) { + InBoundsToUnchecked.apply_post_ssa(self, AtomicCounter::new(0)); + } } /// Remove non-constant index vectors from SSA transformation because they currently must be @@ -336,14 +294,11 @@ impl Optimizer { } fn ssa_transform(&mut self) { - self.program.fill_dom_frontiers(); - self.program.place_phi_nodes(); + self.place_phi_nodes(); self.version_program(); self.program.variables.clear(); - for block in self.node_ids() { - self.program[block].writes.clear(); - self.program[block].dom_frontiers.clear(); - } + self.invalidate_analysis::(); + self.invalidate_analysis::(); } /// Mutable reference to the current basic block @@ -404,30 +359,9 @@ impl Optimizer { let is_break = processed.operations.contains(&Branch::Break.into()); for mut instruction in processed.operations { - let out = instruction.out; match &mut instruction.operation { Operation::Branch(branch) => self.parse_control_flow(branch.clone()), - Operation::Operator(Operator::Slice(slice_op)) => { - let out_id = match out.unwrap().kind { - VariableKind::Slice { id } => id, - _ => unreachable!(), - }; - let const_len = slice_op.start.as_const().zip(slice_op.end.as_const()); - let const_len = const_len.map(|(start, end)| end.as_u32() - start.as_u32()); - self.program.slices.insert( - out_id, - Slice { - start: slice_op.start, - end: slice_op.end, - end_op: None, - const_len, - }, - ); - self.visit_out(&mut instruction.out, |opt, var| opt.write_var(var)); - self.current_block_mut().ops.borrow_mut().push(instruction); - } _ => { - self.visit_out(&mut instruction.out, |opt, var| opt.write_var(var)); self.current_block_mut().ops.borrow_mut().push(instruction); } } @@ -449,6 +383,7 @@ impl Optimizer { let new_ret = self.program.add_node(BasicBlock::default()); self.program.add_edge(new_ret, self.ret, ()); self.ret = new_ret; + self.invalidate_structure(); new_ret } else { self.ret diff --git a/crates/cubecl-opt/src/passes/array_copy_propagate.rs b/crates/cubecl-opt/src/passes/array_copy_propagate.rs index 5ec3a602..8a5d89b8 100644 --- a/crates/cubecl-opt/src/passes/array_copy_propagate.rs +++ b/crates/cubecl-opt/src/passes/array_copy_propagate.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use cubecl_core::ir::{Id, Instruction, Item, Operation, Operator, Variable, VariableKind}; -use crate::{AtomicCounter, Optimizer}; +use crate::{analyses::writes::Writes, AtomicCounter, Optimizer}; use super::OptimizerPass; @@ -111,9 +111,8 @@ fn replace_const_arrays(opt: &mut Optimizer, arr_id: Id, vars: &[Variable]) { if id == arr_id { let const_index = assign.lhs.as_const().unwrap().as_i64() as usize; let out = vars[const_index]; - let out_id = opt.local_variable_id(&out).unwrap(); *op = Instruction::new(Operation::Copy(assign.rhs), out); - opt.program[block].writes.insert(out_id); + opt.invalidate_analysis::(); } } } diff --git a/crates/cubecl-opt/src/passes/constant_prop.rs b/crates/cubecl-opt/src/passes/constant_prop.rs index 84352a5e..55ae811d 100644 --- a/crates/cubecl-opt/src/passes/constant_prop.rs +++ b/crates/cubecl-opt/src/passes/constant_prop.rs @@ -3,7 +3,10 @@ use cubecl_core::ir::{ VariableKind, }; -use crate::{AtomicCounter, Optimizer, Slice}; +use crate::{ + analyses::const_len::{Slice, Slices}, + AtomicCounter, Optimizer, +}; use super::OptimizerPass; @@ -13,6 +16,8 @@ pub struct ConstOperandSimplify; impl OptimizerPass for ConstOperandSimplify { fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { + let slices = opt.analysis::(); + for node in opt.program.node_indices().collect::>() { let ops = opt.program[node].ops.borrow().indices().collect::>(); @@ -123,7 +128,7 @@ impl OptimizerPass for ConstOperandSimplify { changes.inc(); } VariableKind::Slice { id } => { - let slice = opt.program.slices.get(&id); + let slice = slices.get(&id); if let Some(Slice { const_len: Some(len), .. diff --git a/crates/cubecl-opt/src/passes/dead_code.rs b/crates/cubecl-opt/src/passes/dead_code.rs index d2d1f4fb..6ac3ee22 100644 --- a/crates/cubecl-opt/src/passes/dead_code.rs +++ b/crates/cubecl-opt/src/passes/dead_code.rs @@ -7,7 +7,10 @@ use std::{ use cubecl_core::ir::{ConstantScalarValue, Instruction, Operation, VariableKind}; use petgraph::{graph::NodeIndex, visit::EdgeRef}; -use crate::{visit_noop, AtomicCounter, BasicBlock, BlockUse, ControlFlow, Optimizer}; +use crate::{ + analyses::{liveness::Liveness, post_order::PostOrder}, + visit_noop, AtomicCounter, BasicBlock, BlockUse, ControlFlow, Optimizer, +}; use super::OptimizerPass; @@ -28,6 +31,11 @@ fn search_loop(opt: &mut Optimizer) -> bool { for idx in ops { let mut op = opt.program[node].ops.borrow()[idx].clone(); + // Impure operations must be skipped because they can change things even if the output + // is unused + if !op.operation.is_pure() { + continue; + } let mut out = None; let used = Rc::new(AtomicBool::new(false)); opt.visit_out(&mut op.out, |_, var| { @@ -93,6 +101,7 @@ impl OptimizerPass for EliminateConstBranches { } *control_flow.borrow_mut() = ControlFlow::None; + opt.invalidate_structure(); changes.inc(); } ControlFlow::Switch { @@ -116,6 +125,7 @@ impl OptimizerPass for EliminateConstBranches { opt.program.remove_edge(edge); } *control_flow.borrow_mut() = ControlFlow::None; + opt.invalidate_structure(); changes.inc(); } _ => {} @@ -129,26 +139,14 @@ pub struct EliminateDeadBlocks; impl OptimizerPass for EliminateDeadBlocks { fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { - while search_dead_blocks(opt) { - changes.inc(); - } - } -} - -fn search_dead_blocks(opt: &mut Optimizer) -> bool { - for block in opt.node_ids() { - if block != opt.entry() && opt.predecessors(block).is_empty() { - let edges: Vec<_> = opt.program.edges(block).map(|it| it.id()).collect(); - for edge in edges { - opt.program.remove_edge(edge); + let post_order = opt.analysis::().forward(); + for node in opt.node_ids() { + if !post_order.contains(&node) { + opt.program.remove_node(node); + changes.inc(); } - opt.program.remove_node(block); - opt.post_order.retain(|it| *it != block); - return true; } } - - false } /// Eliminates invalid phi nodes left over from other optimizations like branch elimination. @@ -203,7 +201,7 @@ impl OptimizerPass for MergeBlocks { } fn merge_blocks(opt: &mut Optimizer) -> bool { - for block_idx in opt.reverse_post_order() { + for block_idx in opt.analysis::().reverse() { let successors = opt.successors(block_idx); if successors.len() == 1 && can_merge(opt, block_idx, successors[0]) { let mut new_block = BasicBlock::default(); @@ -214,8 +212,6 @@ fn merge_blocks(opt: &mut Optimizer) -> bool { let b_ops = block.ops.borrow().values().cloned().collect::>(); let s_ops = successor.ops.borrow().values().cloned().collect::>(); - new_block.live_vars.extend(block.live_vars.clone()); - new_block.live_vars.extend(successor.live_vars.clone()); new_block.phi_nodes.borrow_mut().extend(b_phi); new_block.phi_nodes.borrow_mut().extend(s_phi); new_block.ops.borrow_mut().extend(b_ops); @@ -237,7 +233,8 @@ fn merge_blocks(opt: &mut Optimizer) -> bool { } *opt.program.node_weight_mut(block_idx).unwrap() = new_block; opt.program.remove_node(successors[0]); - opt.post_order.retain(|it| *it != successors[0]); + opt.invalidate_structure(); + opt.invalidate_analysis::(); update_references(opt, successors[0], block_idx); return true; } diff --git a/crates/cubecl-opt/src/passes/find_safe_indexes.rs b/crates/cubecl-opt/src/passes/find_safe_indexes.rs new file mode 100644 index 00000000..7a9456a7 --- /dev/null +++ b/crates/cubecl-opt/src/passes/find_safe_indexes.rs @@ -0,0 +1,65 @@ +use cubecl_core::ir::{Operation, Operator, Variable, VariableKind}; + +use crate::{ + analyses::{const_len::Slices, integer_range::Ranges}, + AtomicCounter, Optimizer, +}; + +use super::OptimizerPass; + +/// Use the results from integer range analysis to find indexes that are always in bounds, then +/// transform them to unchecked indexes. +pub struct InBoundsToUnchecked; + +impl OptimizerPass for InBoundsToUnchecked { + fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { + let ranges = opt.analysis::(); + + for block in opt.node_ids() { + let ops = opt.program[block].ops.clone(); + for inst in ops.borrow_mut().values_mut() { + let op = match &inst.operation { + Operation::Operator(op) => op, + _ => continue, + }; + match op { + Operator::Index(op) => { + if let Some(const_len) = const_len(opt, &op.lhs) { + let range = ranges.range_of(opt, &op.rhs); + if let Some((_, upper)) = range.lower_bound.zip(range.upper_bound) { + if (upper as u32) < const_len { + inst.operation = Operator::UncheckedIndex(op.clone()).into(); + changes.inc(); + } + } + } + } + Operator::IndexAssign(op) => { + if let Some(const_len) = const_len(opt, &inst.out()) { + let range = ranges.range_of(opt, &op.lhs); + if let Some((_, upper)) = range.lower_bound.zip(range.upper_bound) { + if (upper as u32) < const_len { + inst.operation = + Operator::UncheckedIndexAssign(op.clone()).into(); + changes.inc(); + } + } + } + } + _ => {} + } + } + } + } +} + +fn const_len(opt: &mut Optimizer, var: &Variable) -> Option { + let slices = opt.analysis::(); + match var.kind { + VariableKind::ConstantArray { length, .. } => Some(length), + VariableKind::SharedMemory { length, .. } => Some(length), + VariableKind::LocalArray { length, .. } => Some(length), + VariableKind::Slice { id } => slices.get(&id).and_then(|it| it.const_len), + _ => None, + } +} diff --git a/crates/cubecl-opt/src/passes/in_bounds_analysis.rs b/crates/cubecl-opt/src/passes/in_bounds_analysis.rs deleted file mode 100644 index 61c914cc..00000000 --- a/crates/cubecl-opt/src/passes/in_bounds_analysis.rs +++ /dev/null @@ -1,92 +0,0 @@ -use cubecl_core::ir::{Operation, Operator, Variable, VariableKind}; - -use crate::{AtomicCounter, Optimizer}; - -use super::{range_of, OptimizerPass}; - -/// Try to find any constant length slices by cancelling common factors in `start` and `end` -pub struct FindConstSliceLen; - -impl OptimizerPass for FindConstSliceLen { - fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { - for block in opt.node_ids() { - let ops = opt.program[block].ops.clone(); - for operator in ops.borrow().values() { - let op = match &operator.operation { - Operation::Operator(op) => op, - _ => continue, - }; - // Only handle the simplest cases for now - if let Operator::Add(op) = op { - let mut slices = opt.program.slices.values_mut(); - let slice = - slices.find(|it| it.end == operator.out() && it.const_len.is_none()); - if let Some(slice) = slice { - slice.end_op = Some(Operator::Add(op.clone()).into()); - if op.lhs == slice.start && op.rhs.as_const().is_some() { - slice.const_len = Some(op.rhs.as_const().unwrap().as_u32()); - changes.inc(); - } else if op.rhs == slice.start && op.lhs.as_const().is_some() { - slice.const_len = Some(op.lhs.as_const().unwrap().as_u32()); - changes.inc(); - } - } - } - } - } - } -} - -/// Use the results from integer range analysis to find indexes that are always in bounds, then -/// transform them to unchecked indexes. -pub struct InBoundsToUnchecked; - -impl OptimizerPass for InBoundsToUnchecked { - fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) { - for block in opt.node_ids() { - let ops = opt.program[block].ops.clone(); - for inst in ops.borrow_mut().values_mut() { - let op = match &inst.operation { - Operation::Operator(op) => op, - _ => continue, - }; - match op { - Operator::Index(op) => { - if let Some(const_len) = const_len(opt, &op.lhs) { - let range = range_of(opt, &op.rhs); - if let Some((lower, upper)) = range.lower_bound.zip(range.upper_bound) { - if lower >= 0 && (upper as u32) < const_len { - inst.operation = Operator::UncheckedIndex(op.clone()).into(); - changes.inc(); - } - } - } - } - Operator::IndexAssign(op) => { - if let Some(const_len) = const_len(opt, &inst.out()) { - let range = range_of(opt, &op.lhs); - if let Some((lower, upper)) = range.lower_bound.zip(range.upper_bound) { - if lower >= 0 && (upper as u32) < const_len { - inst.operation = - Operator::UncheckedIndexAssign(op.clone()).into(); - changes.inc(); - } - } - } - } - _ => {} - } - } - } - } -} - -fn const_len(opt: &Optimizer, var: &Variable) -> Option { - match var.kind { - VariableKind::ConstantArray { length, .. } => Some(length), - VariableKind::SharedMemory { length, .. } => Some(length), - VariableKind::LocalArray { length, .. } => Some(length), - VariableKind::Slice { id } => opt.program.slices.get(&id).and_then(|it| it.const_len), - _ => None, - } -} diff --git a/crates/cubecl-opt/src/passes/inlined_if_to_select.rs b/crates/cubecl-opt/src/passes/inlined_if_to_select.rs index 514c14c1..6e6e4721 100644 --- a/crates/cubecl-opt/src/passes/inlined_if_to_select.rs +++ b/crates/cubecl-opt/src/passes/inlined_if_to_select.rs @@ -79,12 +79,11 @@ impl OptimizerPass for EmptyBranchToSelect { opt.program.remove_node(then); opt.program.remove_node(or_else); opt.program.remove_node(merge); - opt.post_order - .retain(|it| *it != then && *it != or_else && *it != merge); for merge_successor in merge_successors { opt.program.add_edge(block, merge_successor, ()); } *opt.program[block].control_flow.borrow_mut() = merge_control; + opt.invalidate_structure(); update_references(opt, merge, block); return true; } diff --git a/crates/cubecl-opt/src/passes/liveness.rs b/crates/cubecl-opt/src/passes/liveness.rs deleted file mode 100644 index df398abe..00000000 --- a/crates/cubecl-opt/src/passes/liveness.rs +++ /dev/null @@ -1,75 +0,0 @@ -use std::collections::{HashMap, HashSet, VecDeque}; - -use cubecl_core::ir::Id; -use petgraph::graph::NodeIndex; - -use crate::Optimizer; - -#[derive(Clone)] -struct BlockSets { - gen: HashSet, - kill: HashSet, -} - -struct State { - worklist: VecDeque, - block_sets: HashMap, -} - -impl Optimizer { - /// Do a conservative block level liveness analysis - pub fn analyze_liveness(&mut self) { - let mut state = State { - worklist: VecDeque::from(self.post_order()), - block_sets: HashMap::new(), - }; - while let Some(block) = state.worklist.pop_front() { - self.analyze_block(block, &mut state); - } - } - - fn analyze_block(&mut self, block: NodeIndex, state: &mut State) { - let BlockSets { gen, kill } = self.block_sets(block, state); - - let mut live_vars = gen.clone(); - - for successor in self.successors(block) { - let successor = &self.program[successor].live_vars; - live_vars.extend(successor.difference(kill)); - } - - if live_vars != self.program[block].live_vars { - state.worklist.extend(self.predecessors(block)); - self.program[block].live_vars = live_vars; - } - } - - fn block_sets<'a>(&mut self, block: NodeIndex, state: &'a mut State) -> &'a BlockSets { - let block_sets = state.block_sets.entry(block); - block_sets.or_insert_with(|| self.calculate_block_sets(block)) - } - - fn calculate_block_sets(&mut self, block: NodeIndex) -> BlockSets { - let mut gen = HashSet::new(); - let mut kill = HashSet::new(); - - let ops = self.program[block].ops.clone(); - - for op in ops.borrow_mut().values_mut().rev() { - // Reads must be tracked after writes - self.visit_out(&mut op.out, |opt, var| { - if let Some(id) = opt.local_variable_id(var) { - kill.insert(id); - gen.remove(&id); - } - }); - self.visit_operation(&mut op.operation, |opt, var| { - if let Some(id) = opt.local_variable_id(var) { - gen.insert(id); - } - }); - } - - BlockSets { gen, kill } - } -} diff --git a/crates/cubecl-opt/src/passes/mod.rs b/crates/cubecl-opt/src/passes/mod.rs index c52ad656..a14720d3 100644 --- a/crates/cubecl-opt/src/passes/mod.rs +++ b/crates/cubecl-opt/src/passes/mod.rs @@ -3,11 +3,9 @@ mod composite; mod constant_prop; mod dead_code; mod expression_merge; -mod in_bounds_analysis; +mod find_safe_indexes; mod index_merge; mod inlined_if_to_select; -mod integer_range_analysis; -mod liveness; mod reduce_strength; pub use array_copy_propagate::*; @@ -15,10 +13,9 @@ pub use composite::*; pub use constant_prop::*; pub use dead_code::*; pub use expression_merge::*; -pub use in_bounds_analysis::*; +pub use find_safe_indexes::*; pub use index_merge::*; pub use inlined_if_to_select::*; -pub use integer_range_analysis::*; pub use reduce_strength::*; use crate::AtomicCounter; diff --git a/crates/cubecl-opt/src/passes/reduce_strength.rs b/crates/cubecl-opt/src/passes/reduce_strength.rs index 2fea2b01..88795b2f 100644 --- a/crates/cubecl-opt/src/passes/reduce_strength.rs +++ b/crates/cubecl-opt/src/passes/reduce_strength.rs @@ -56,7 +56,7 @@ impl OptimizerPass for ReduceStrength { changes.inc(); } val if (val + 1).is_power_of_two() => { - let temp = *opt.allocator.create_local_restricted(inst.item()); + let temp = *opt.allocator.create_local(inst.item()); new_ops.push(Instruction::new( Operator::ShiftLeft(BinaryOperator { lhs: dyn_val, @@ -74,7 +74,7 @@ impl OptimizerPass for ReduceStrength { changes.inc(); } val if (val - 1).is_power_of_two() => { - let temp = *opt.allocator.create_local_restricted(inst.item()); + let temp = *opt.allocator.create_local(inst.item()); new_ops.push(Instruction::new( Operator::ShiftLeft(BinaryOperator { lhs: dyn_val, diff --git a/crates/cubecl-opt/src/phi_frontiers.rs b/crates/cubecl-opt/src/phi_frontiers.rs index 2b1a4b3e..469f6227 100644 --- a/crates/cubecl-opt/src/phi_frontiers.rs +++ b/crates/cubecl-opt/src/phi_frontiers.rs @@ -1,49 +1,37 @@ use cubecl_core::ir::{Id, Item, Variable, VariableKind}; -use petgraph::{algo::dominators::simple_fast, graph::NodeIndex, visit::EdgeRef, Direction}; +use petgraph::graph::NodeIndex; -use super::{ - version::{PhiEntry, PhiInstruction}, - Program, +use crate::{ + analyses::{dominance::DomFrontiers, liveness::Liveness, writes::Writes}, + Optimizer, }; -impl Program { - /// Find dominance frontiers for each block - pub fn fill_dom_frontiers(&mut self) { - let doms = simple_fast(&self.graph, self.root); - for node in self.node_indices().collect::>() { - let predecessors: Vec<_> = self - .edges_directed(node, Direction::Incoming) - .map(|it| it.source()) - .collect(); - if predecessors.len() >= 2 { - for predecessor in predecessors { - let mut runner = predecessor; - while runner != doms.immediate_dominator(node).unwrap() { - self[runner].dom_frontiers.insert(node); - runner = doms.immediate_dominator(runner).unwrap(); - } - } - } - } - } +use super::version::{PhiEntry, PhiInstruction}; +impl Optimizer { /// Places a phi node for each live variable at each frontier pub fn place_phi_nodes(&mut self) { - let keys: Vec<_> = self.variables.keys().cloned().collect(); + let keys: Vec<_> = self.program.variables.keys().cloned().collect(); + let writes = self.analysis::(); + let liveness = self.analysis::(); + let dom_frontiers = self.analysis::(); + for var in keys { let mut workset: Vec<_> = self - .node_indices() - .filter(|index| self[*index].writes.contains(&var)) + .node_ids() + .iter() + .filter(|index| writes[*index].contains(&var)) + .copied() .collect(); let mut considered = workset.clone(); let mut already_inserted = Vec::new(); while let Some(node) = workset.pop() { - for frontier in self[node].dom_frontiers.clone() { - if already_inserted.contains(&frontier) || self.is_dead(frontier, var) { + for frontier in dom_frontiers[&node].clone() { + if already_inserted.contains(&frontier) || liveness.is_dead(frontier, var) { continue; } - self.insert_phi(frontier, var, self.variables[&var]); + self.insert_phi(frontier, var, self.program.variables[&var]); already_inserted.push(frontier); if !considered.contains(&frontier) { workset.push(frontier); @@ -57,17 +45,14 @@ impl Program { /// Insert a phi node for variable `id` at `block` pub fn insert_phi(&mut self, block: NodeIndex, id: Id, item: Item) { let var = Variable::new(VariableKind::Versioned { id, version: 0 }, item); - let entries = self - .edges_directed(block, Direction::Incoming) - .map(|edge| edge.source()) - .map(|pred| PhiEntry { - block: pred, - value: var, - }); + let entries = self.predecessors(block).into_iter().map(|pred| PhiEntry { + block: pred, + value: var, + }); let phi = PhiInstruction { out: var, entries: entries.collect(), }; - self[block].phi_nodes.borrow_mut().push(phi); + self.program[block].phi_nodes.borrow_mut().push(phi); } } diff --git a/crates/cubecl-spirv/src/debug.rs b/crates/cubecl-spirv/src/debug.rs index 5ea78486..8b0cd75a 100644 --- a/crates/cubecl-spirv/src/debug.rs +++ b/crates/cubecl-spirv/src/debug.rs @@ -442,7 +442,7 @@ impl SpirvCompiler { } // Declare entry - let entry_name = debug_info.name_str.clone(); + let entry_name = self.debug_info().name_str.clone(); let entry_def = self.definitions().functions[&entry_name]; let args = self.debug_string(""); let signature = self.debug_string(SIGNATURE); diff --git a/crates/cubecl-spirv/src/instruction.rs b/crates/cubecl-spirv/src/instruction.rs index c4ee906d..d50caf4a 100644 --- a/crates/cubecl-spirv/src/instruction.rs +++ b/crates/cubecl-spirv/src/instruction.rs @@ -439,7 +439,6 @@ impl SpirvCompiler { }); } Operator::ReverseBits(op) => { - self.capabilities.insert(Capability::BitInstructions); self.compile_unary_op(op, out, |b, _, ty, input, out| { b.bit_reverse(ty, Some(out), input).unwrap(); }); diff --git a/crates/cubecl-wgpu/src/lib.rs b/crates/cubecl-wgpu/src/lib.rs index 50dea369..bd447d1d 100644 --- a/crates/cubecl-wgpu/src/lib.rs +++ b/crates/cubecl-wgpu/src/lib.rs @@ -21,6 +21,7 @@ pub use runtime::*; pub use compiler::spirv; #[cfg(test)] +#[allow(unexpected_cfgs)] mod tests { pub type TestRuntime = crate::WgpuRuntime; @@ -33,14 +34,16 @@ mod tests { } #[cfg(all(test, feature = "spirv"))] +#[allow(unexpected_cfgs)] mod tests_spirv { pub type TestRuntime = crate::WgpuRuntime; use cubecl_core::flex32; use half::f16; cubecl_core::testgen_all!(f32: [f16, flex32, f32, f64], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, u64]); - cubecl_linalg::testgen_matmul_plane!([f16, flex32, f32]); - cubecl_linalg::testgen_matmul_tiling2d!([f16, flex32, f32, f64]); - cubecl_linalg::testgen_matmul_simple!([flex32, f32]); + cubecl_linalg::testgen_matmul_plane!([f16, f32]); + cubecl_linalg::testgen_matmul_tiling2d!([f16, f32, f64]); + cubecl_linalg::testgen_matmul_simple!([f32]); cubecl_linalg::testgen_matmul_accelerated!([f16]); + cubecl_reduce::testgen_reduce!(); }