Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor optimizer for maintainability/extensability #393

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions crates/cubecl-core/src/ir/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,27 @@ impl Instruction {
}
}

impl Operation {
/// Whether this operation is pure, aka has no side effects. Pure operations can be removed
/// if their output is not needed, impure operations must be kept since their execution can
/// affect things down the line. e.g. atomics.
///
/// Operations that operate across multiple units are always considered impure.
pub fn is_pure(&self) -> bool {
match self {
Operation::Copy(_) => true,
Operation::Operator(_) => true,
Operation::Atomic(_) => false,
Operation::Metadata(_) => true,
Operation::Branch(_) => false,
Operation::Synchronization(_) => false,
Operation::Plane(_) => false,
Operation::CoopMma(_) => false,
Operation::NonSemantic(_) => false,
}
}
}

impl Display for Instruction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.operation {
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-cuda/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub use device::*;
pub use runtime::*;

#[cfg(test)]
#[allow(unexpected_cfgs)]
mod tests {
pub type TestRuntime = crate::CudaRuntime;
pub use half::{bf16, f16};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ pub fn test_matmul_algorithm<A, EG, ES, R>(
if A::check_availability::<R, (EG, ES, f32)>(&client, &config).is_err() {
// Can't execute the test.
println!("Skipped - not supported!");
client.flush();
return;
}

Expand Down
6 changes: 1 addition & 5 deletions crates/cubecl-macros/src/parse/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,7 @@ impl Unroll {
pub value: Expr,
}

let attr = attrs.iter().find(|attr| attr.path().is_ident("unroll"));
let attr = match attr {
Some(attr) => attr,
None => return None,
};
let attr = attrs.iter().find(|attr| attr.path().is_ident("unroll"))?;

match &attr.meta {
syn::Meta::Path(_) => None,
Expand Down
9 changes: 3 additions & 6 deletions crates/cubecl-opt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,16 @@
authors = ["Genna Wingert"]
categories = ["algorithms"]
description = "Compiler optimizations for CubeCL"
keywords = ["gpu", "compiler"]
edition = "2021"
keywords = ["gpu", "compiler"]
license.workspace = true
name = "cubecl-opt"
readme.workspace = true
repository = "https://github.com/tracel-ai/cubecl/tree/main/cubecl-opt"
version.workspace = true

[features]
default = [
"std",
"cubecl-common/default",
"cubecl-core/default",
]
default = ["std", "cubecl-common/default", "cubecl-core/default"]
std = ["cubecl-common/std", "cubecl-core/std"]

[dependencies]
Expand All @@ -27,3 +23,4 @@ num = "0.4"
petgraph = { version = "0.6" }
smallvec = { version = "1", features = ["union", "const_generics"] }
stable-vec = { version = "0.4" }
type-map = { version = "0.5" }
62 changes: 62 additions & 0 deletions crates/cubecl-opt/src/analyses/base.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use std::{any::Any, cell::RefCell, rc::Rc};

use type_map::TypeMap;

use crate::Optimizer;

use super::{
dominance::{Dominators, PostDominators},
liveness::Liveness,
post_order::PostOrder,
};

/// An analysis used by optimization passes. Unlike optimization passes, analyses can have state
/// and persist until they're invalidated.
pub trait Analysis {
/// Perform the analysis for the current optimizer state and return the persistent analysis state
fn init(opt: &mut Optimizer) -> Self;
}

#[derive(Default, Clone, Debug)]
pub struct Analyses {
cache: Rc<RefCell<TypeMap>>,
}

impl Analyses {
pub fn get<A: Analysis + Any>(&self, opt: &mut Optimizer) -> Rc<A> {
let analysis = self.cache.borrow().get::<Rc<A>>().cloned();
if let Some(analysis) = analysis {
analysis
} else {
let analysis = Rc::new(A::init(opt));
self.cache.borrow_mut().insert(analysis.clone());
analysis
}
}

pub fn try_get<A: Any>(&self) -> Option<Rc<A>> {
self.cache.borrow().get().cloned()
}

pub fn invalidate<A: Analysis + Any>(&self) {
self.cache.borrow_mut().remove::<Rc<A>>();
}
}

impl Optimizer {
pub fn analysis<A: Analysis + Any>(&mut self) -> Rc<A> {
let analyses = self.analyses.clone();
analyses.get(self)
}

pub fn invalidate_analysis<A: Analysis + Any>(&self) {
self.analyses.invalidate::<A>();
}

pub fn invalidate_structure(&self) {
self.invalidate_analysis::<PostOrder>();
self.invalidate_analysis::<Dominators>();
self.invalidate_analysis::<PostDominators>();
self.invalidate_analysis::<Liveness>();
}
}
96 changes: 96 additions & 0 deletions crates/cubecl-opt/src/analyses/const_len.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use std::{collections::HashMap, ops::Deref};

use cubecl_core::ir::{Id, Operation, Operator, Variable, VariableKind};

use crate::Optimizer;

use super::Analysis;

#[derive(Debug, Clone)]
pub struct Slice {
pub start: Variable,
pub end: Variable,
pub end_op: Option<Operation>,
pub const_len: Option<u32>,
}

/// Try to find any constant length slices by cancelling common factors in `start` and `end`
#[derive(Default, Debug)]
pub struct Slices {
slices: HashMap<Id, Slice>,
}

impl Deref for Slices {
type Target = HashMap<Id, Slice>;

fn deref(&self) -> &Self::Target {
&self.slices
}
}

impl Analysis for Slices {
fn init(opt: &mut Optimizer) -> Self {
let mut this = Slices::default();
this.populate_slices(opt);
this.find_end_ops(opt);
this
}
}

impl Slices {
fn populate_slices(&mut self, opt: &mut Optimizer) {
for block in opt.node_ids() {
let ops = opt.program[block].ops.clone();
for operator in ops.borrow().values() {
let op = match &operator.operation {
Operation::Operator(op) => op,
_ => continue,
};
let out = operator.out.as_ref();
if let Operator::Slice(slice_op) = op {
let out_id = match out.unwrap().kind {
VariableKind::Slice { id } => id,
_ => unreachable!(),
};
let const_len = slice_op.start.as_const().zip(slice_op.end.as_const());
let const_len = const_len.map(|(start, end)| end.as_u32() - start.as_u32());
self.slices.insert(
out_id,
Slice {
start: slice_op.start,
end: slice_op.end,
end_op: None,
const_len,
},
);
};
}
}
}

fn find_end_ops(&mut self, opt: &mut Optimizer) {
for block in opt.node_ids() {
let ops = opt.program[block].ops.clone();
for operator in ops.borrow().values() {
let op = match &operator.operation {
Operation::Operator(op) => op,
_ => continue,
};
// Only handle the simplest cases for now
if let Operator::Add(op) = op {
let mut slices = self.slices.values_mut();
let slice =
slices.find(|it| it.end == operator.out() && it.const_len.is_none());
if let Some(slice) = slice {
slice.end_op = Some(Operator::Add(op.clone()).into());
if op.lhs == slice.start && op.rhs.as_const().is_some() {
slice.const_len = Some(op.rhs.as_const().unwrap().as_u32());
} else if op.rhs == slice.start && op.lhs.as_const().is_some() {
slice.const_len = Some(op.lhs.as_const().unwrap().as_u32());
}
}
};
}
}
}
}
87 changes: 87 additions & 0 deletions crates/cubecl-opt/src/analyses/dominance.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use std::{
collections::{HashMap, HashSet},
ops::Deref,
};

use crate::{NodeIndex, Optimizer};
use petgraph::algo::dominators;

use super::Analysis;

/// Dominator tree for the program graph
pub struct Dominators(dominators::Dominators<NodeIndex>);
/// Post dominator tree for the program graph
pub struct PostDominators(dominators::Dominators<NodeIndex>);

impl Deref for Dominators {
type Target = dominators::Dominators<NodeIndex>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl Deref for PostDominators {
type Target = dominators::Dominators<NodeIndex>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl Analysis for Dominators {
fn init(opt: &mut crate::Optimizer) -> Self {
Dominators(dominators::simple_fast(&opt.program.graph, opt.entry()))
}
}

impl Analysis for PostDominators {
fn init(opt: &mut crate::Optimizer) -> Self {
let mut reversed = opt.program.graph.clone();
reversed.reverse();
PostDominators(dominators::simple_fast(&reversed, opt.ret))
}
}

/// Dominance frontiers for each block
pub struct DomFrontiers {
/// The dominance frontiers of each block (where phi nodes must be inserted).
dom_frontiers: HashMap<NodeIndex, HashSet<NodeIndex>>,
}

impl Deref for DomFrontiers {
type Target = HashMap<NodeIndex, HashSet<NodeIndex>>;

fn deref(&self) -> &Self::Target {
&self.dom_frontiers
}
}

impl DomFrontiers {
/// Find dominance frontiers for each block
pub fn new(opt: &mut Optimizer) -> Self {
let doms = opt.analysis::<Dominators>();
let nodes = opt.node_ids().into_iter().map(|it| (it, HashSet::new()));
let mut dom_frontiers: HashMap<NodeIndex, HashSet<NodeIndex>> = nodes.collect();

for node in opt.node_ids() {
let predecessors = opt.predecessors(node);
if predecessors.len() >= 2 {
for predecessor in predecessors {
let mut runner = predecessor;
while runner != doms.immediate_dominator(node).unwrap() {
dom_frontiers.get_mut(&runner).unwrap().insert(node);
runner = doms.immediate_dominator(runner).unwrap();
}
}
}
}
Self { dom_frontiers }
}
}

impl Analysis for DomFrontiers {
fn init(opt: &mut Optimizer) -> Self {
DomFrontiers::new(opt)
}
}
Loading