From 6a4db895bd1ec5450396bc16cf1d0084e53ed3c4 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Thu, 9 Jan 2025 23:50:02 +0100 Subject: [PATCH] Don't eliminate impure operations when looking for dead code --- crates/cubecl-core/src/ir/operation.rs | 21 +++++++++++++++++++++ crates/cubecl-opt/src/passes/dead_code.rs | 5 +++++ 2 files changed, 26 insertions(+) diff --git a/crates/cubecl-core/src/ir/operation.rs b/crates/cubecl-core/src/ir/operation.rs index acdf7feb..79e188a6 100644 --- a/crates/cubecl-core/src/ir/operation.rs +++ b/crates/cubecl-core/src/ir/operation.rs @@ -54,6 +54,27 @@ impl Instruction { } } +impl Operation { + /// Whether this operation is pure, aka has no side effects. Pure operations can be removed + /// if their output is not needed, impure operations must be kept since their execution can + /// affect things down the line. e.g. atomics. + /// + /// Operations that operate across multiple units are always considered impure. + pub fn is_pure(&self) -> bool { + match self { + Operation::Copy(_) => true, + Operation::Operator(_) => true, + Operation::Atomic(_) => false, + Operation::Metadata(_) => true, + Operation::Branch(_) => false, + Operation::Synchronization(_) => false, + Operation::Plane(_) => false, + Operation::CoopMma(_) => false, + Operation::NonSemantic(_) => false, + } + } +} + impl Display for Instruction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.operation { diff --git a/crates/cubecl-opt/src/passes/dead_code.rs b/crates/cubecl-opt/src/passes/dead_code.rs index 40d0f1da..db0bd86c 100644 --- a/crates/cubecl-opt/src/passes/dead_code.rs +++ b/crates/cubecl-opt/src/passes/dead_code.rs @@ -31,6 +31,11 @@ fn search_loop(opt: &mut Optimizer) -> bool { for idx in ops { let mut op = opt.program[node].ops.borrow()[idx].clone(); + // Assume operations and metadata are pure, and everything else might have side effects + // Technically not correct but much simpler than + if !op.operation.is_pure() { + continue; + } let mut out = None; let used = Rc::new(AtomicBool::new(false)); opt.visit_out(&mut op.out, |_, var| {