diff --git a/crates/cubecl-core/src/codegen/compiler.rs b/crates/cubecl-core/src/codegen/compiler.rs index 756c8ac69..be26dad60 100644 --- a/crates/cubecl-core/src/codegen/compiler.rs +++ b/crates/cubecl-core/src/codegen/compiler.rs @@ -1,4 +1,4 @@ -use crate::ir::{Elem, KernelDefinition, LocalAllocator}; +use crate::ir::{Allocator, Elem, KernelDefinition}; use cubecl_runtime::ExecutionMode; use std::fmt::Display; @@ -22,7 +22,7 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug { ) -> Self::Representation; /// The size of the given element in bytes. fn elem_size(elem: Elem) -> usize; - fn local_allocator() -> impl LocalAllocator; + fn local_allocator() -> Allocator; /// The maximal size of a shared memory, in bytes fn max_shared_memory_size() -> usize; } diff --git a/crates/cubecl-core/src/codegen/integrator.rs b/crates/cubecl-core/src/codegen/integrator.rs index 88a3d2398..866e1d3d9 100644 --- a/crates/cubecl-core/src/codegen/integrator.rs +++ b/crates/cubecl-core/src/codegen/integrator.rs @@ -412,7 +412,7 @@ impl KernelIntegrator { }); self.expansion.scope.write_global( Variable::new( - VariableKind::Local { + VariableKind::LocalMut { id: local, depth: self.expansion.scope.depth, @@ -432,7 +432,7 @@ impl KernelIntegrator { } => { self.expansion.scope.write_global( Variable::new( - VariableKind::Local { + VariableKind::LocalMut { id: local, depth: self.expansion.scope.depth, }, diff --git a/crates/cubecl-core/src/compute/builder.rs b/crates/cubecl-core/src/compute/builder.rs index 6cb0ed7f2..156c4af8d 100644 --- a/crates/cubecl-core/src/compute/builder.rs +++ b/crates/cubecl-core/src/compute/builder.rs @@ -1,4 +1,4 @@ -use crate::ir::{Elem, Item, LocalAllocator, ReusingAllocator, Visibility}; +use crate::ir::{Allocator, Elem, Item, Visibility}; use crate::prelude::KernelDefinition; use crate::KernelSettings; use crate::{ @@ -117,7 +117,7 @@ impl KernelBuilder { .integrate(settings) } - pub fn with_local_allocator(allocator: impl LocalAllocator + 'static) -> Self { + pub fn with_local_allocator(allocator: Allocator) -> Self { Self { context: CubeContext::root(allocator), inputs: Vec::new(), @@ -131,6 +131,6 @@ impl KernelBuilder { impl Default for KernelBuilder { fn default() -> Self { - Self::with_local_allocator(ReusingAllocator::default()) + Self::with_local_allocator(Allocator::new()) } } diff --git a/crates/cubecl-core/src/frontend/branch.rs b/crates/cubecl-core/src/frontend/branch.rs index 3c5c9103c..8c1095766 100644 --- a/crates/cubecl-core/src/frontend/branch.rs +++ b/crates/cubecl-core/src/frontend/branch.rs @@ -101,7 +101,7 @@ impl Iterable for RangeExpand { ) { let mut child = context.child(); let index_ty = Item::new(I::as_elem(context)); - let i = child.create_local_undeclared(index_ty); + let i = child.create_local_restricted(index_ty); body(&mut child, i.clone().into()); @@ -131,7 +131,7 @@ impl> Iterable for SteppedRangeExpand { ) { let mut child = context.child(); let index_ty = Item::new(I::as_elem(context)); - let i = child.create_local_undeclared(index_ty); + let i = child.create_local_restricted(index_ty); body(&mut child, i.clone().into()); @@ -396,7 +396,7 @@ pub fn if_else_expr_expand( None => { let mut then_child = context.child(); let ret = then_block(&mut then_child); - let out: ExpandElementTyped = context.create_local_variable(ret.expand.item).into(); + let out: ExpandElementTyped = context.create_local_mut(ret.expand.item).into(); assign::expand(&mut then_child, ret, out.clone()); IfElseExprExpand::Runtime { @@ -501,7 +501,7 @@ pub fn switch_expand_expr( ) -> SwitchExpandExpr { let mut default_child = context.child(); let default = default_block(&mut default_child); - let out: ExpandElementTyped = context.create_local_variable(default.expand.item).into(); + let out: ExpandElementTyped = context.create_local_mut(default.expand.item).into(); assign::expand(&mut default_child, default, out.clone()); SwitchExpandExpr { diff --git a/crates/cubecl-core/src/frontend/container/array/base.rs b/crates/cubecl-core/src/frontend/container/array/base.rs index 580ab2dc1..7a075723b 100644 --- a/crates/cubecl-core/src/frontend/container/array/base.rs +++ b/crates/cubecl-core/src/frontend/container/array/base.rs @@ -180,7 +180,7 @@ mod vectorization { let item = Item::vectorized(var.item.elem(), NonZero::new(factor as u8)); let new_var = if factor == 1 { - let new_var = context.create_local_binding(item); + let new_var = context.create_local(item); let element = index::expand( context, self.clone(), @@ -189,7 +189,7 @@ mod vectorization { assign::expand(context, element, new_var.clone().into()); new_var } else { - let new_var = context.create_local_variable(item); + let new_var = context.create_local_mut(item); for i in 0..factor { let expand: Self = self.expand.clone().into(); let element = @@ -230,7 +230,7 @@ mod metadata { impl ExpandElementTyped> { // Expand method of [len](Array::len). pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped { - let out = context.create_local_binding(Item::new(u32::as_elem(context))); + let out = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new( Metadata::Length { var: self.expand.into(), @@ -245,7 +245,7 @@ mod metadata { self, context: &mut CubeContext, ) -> ExpandElementTyped { - let out = context.create_local_binding(Item::new(u32::as_elem(context))); + let out = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new( Metadata::BufferLength { var: self.expand.into(), @@ -298,7 +298,7 @@ mod indexation { context: &mut CubeContext, i: ExpandElementTyped, ) -> ExpandElementTyped { - let out = context.create_local_binding(self.expand.item); + let out = context.create_local(self.expand.item); context.register(Instruction::new( Operator::UncheckedIndex(BinaryOperator { lhs: *self.expand, diff --git a/crates/cubecl-core/src/frontend/container/iter.rs b/crates/cubecl-core/src/frontend/container/iter.rs index e65d3b563..f53d2074a 100644 --- a/crates/cubecl-core/src/frontend/container/iter.rs +++ b/crates/cubecl-core/src/frontend/container/iter.rs @@ -31,7 +31,7 @@ impl Iterable for ExpandElementTyped { let len: ExpandElement = T::len(&self.expand, context); let mut child = context.child(); - let i = child.create_local_undeclared(index_ty); + let i = child.create_local_restricted(index_ty); let item = index::expand(&mut child, self, i.clone().into()); body(&mut child, item); diff --git a/crates/cubecl-core/src/frontend/container/line/base.rs b/crates/cubecl-core/src/frontend/container/line/base.rs index 3c901bbef..f2fd51675 100644 --- a/crates/cubecl-core/src/frontend/container/line/base.rs +++ b/crates/cubecl-core/src/frontend/container/line/base.rs @@ -86,8 +86,7 @@ mod fill { value: ExpandElementTyped

, ) -> Self { let length = self.expand.item.vectorization; - let output = - context.create_local_binding(Item::vectorized(P::as_elem(context), length)); + let output = context.create_local(Item::vectorized(P::as_elem(context), length)); cast::expand::

(context, value, output.clone().into()); @@ -128,7 +127,7 @@ mod empty { None => None, }; context - .create_local_variable(Item::vectorized(Self::as_elem(context), length)) + .create_local_mut(Item::vectorized(Self::as_elem(context), length)) .into() } } @@ -216,7 +215,7 @@ macro_rules! impl_line_comparison { let lhs = self.expand.into(); let rhs = rhs.expand.into(); - let output = context.create_local_binding(Item::vectorized(bool::as_elem(context), size)); + let output = context.create_local_mut(Item::vectorized(bool::as_elem(context), size)); context.register(Instruction::new( Operator::$operator(BinaryOperator { lhs, rhs }), diff --git a/crates/cubecl-core/src/frontend/container/shared_memory.rs b/crates/cubecl-core/src/frontend/container/shared_memory.rs index f016d9b20..07372eec5 100644 --- a/crates/cubecl-core/src/frontend/container/shared_memory.rs +++ b/crates/cubecl-core/src/frontend/container/shared_memory.rs @@ -135,7 +135,7 @@ mod indexation { context: &mut CubeContext, i: ExpandElementTyped, ) -> ExpandElementTyped { - let out = context.create_local_binding(self.expand.item); + let out = context.create_local(self.expand.item); context.register(Instruction::new( Operator::UncheckedIndex(BinaryOperator { lhs: *self.expand, diff --git a/crates/cubecl-core/src/frontend/container/slice.rs b/crates/cubecl-core/src/frontend/container/slice.rs index 7badaad4d..33ac040c4 100644 --- a/crates/cubecl-core/src/frontend/container/slice.rs +++ b/crates/cubecl-core/src/frontend/container/slice.rs @@ -177,7 +177,7 @@ mod indexation { context: &mut CubeContext, i: ExpandElementTyped, ) -> ExpandElementTyped { - let out = context.create_local_binding(self.expand.item); + let out = context.create_local(self.expand.item); context.register(Instruction::new( Operator::UncheckedIndex(BinaryOperator { lhs: *self.expand, @@ -195,7 +195,7 @@ mod indexation { context: &mut CubeContext, i: ExpandElementTyped, ) -> ExpandElementTyped { - let out = context.create_local_binding(self.expand.item); + let out = context.create_local(self.expand.item); context.register(Instruction::new( Operator::UncheckedIndex(BinaryOperator { lhs: *self.expand, diff --git a/crates/cubecl-core/src/frontend/container/tensor/base.rs b/crates/cubecl-core/src/frontend/container/tensor/base.rs index 737b25a5c..409883dd0 100644 --- a/crates/cubecl-core/src/frontend/container/tensor/base.rs +++ b/crates/cubecl-core/src/frontend/container/tensor/base.rs @@ -130,7 +130,7 @@ mod metadata { dim: ExpandElementTyped, ) -> ExpandElementTyped { let dim: ExpandElement = dim.into(); - let out = context.create_local_binding(Item::new(u32::as_elem(context))); + let out = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new( Metadata::Stride { dim: *dim, @@ -148,7 +148,7 @@ mod metadata { dim: ExpandElementTyped, ) -> ExpandElementTyped { let dim: ExpandElement = dim.into(); - let out = context.create_local_binding(Item::new(u32::as_elem(context))); + let out = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new( Metadata::Shape { dim: *dim, @@ -171,7 +171,7 @@ mod metadata { let shape = self.clone().__expand_shape_method(context, dim.clone()); // Compute `num_strides = index / stride`. - let num_strides = context.create_local_binding(Item::new(u32::as_elem(context))); + let num_strides = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new( Operator::Div(BinaryOperator { lhs: *index, @@ -181,7 +181,7 @@ mod metadata { )); // Compute `coordinate = num_strides % shape `. - let coordinate = context.create_local_binding(Item::new(u32::as_elem(context))); + let coordinate = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new( Operator::Modulo(BinaryOperator { lhs: *num_strides, @@ -210,7 +210,7 @@ mod metadata { // Expand method of [rank](Tensor::rank). pub fn __expand_rank_method(self, context: &mut CubeContext) -> ExpandElementTyped { - let out = context.create_local_binding(Item::new(u32::as_elem(context))); + let out = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out)); out.into() } @@ -258,7 +258,7 @@ mod indexation { context: &mut CubeContext, i: ExpandElementTyped, ) -> ExpandElementTyped { - let out = context.create_local_binding(self.expand.item); + let out = context.create_local(self.expand.item); context.register(Instruction::new( Operator::UncheckedIndex(BinaryOperator { lhs: *self.expand, diff --git a/crates/cubecl-core/src/frontend/context.rs b/crates/cubecl-core/src/frontend/context.rs index 8624742c8..0abb6b031 100644 --- a/crates/cubecl-core/src/frontend/context.rs +++ b/crates/cubecl-core/src/frontend/context.rs @@ -1,5 +1,5 @@ -use crate::ir::{self, Elem, Instruction, Item, ReusingAllocator, Scope, Variable, VariableKind}; -use crate::{frontend::ExpandElement, ir::LocalAllocator}; +use crate::frontend::ExpandElement; +use crate::ir::{self, Allocator, Elem, Instruction, Item, Scope, Variable, VariableKind}; use alloc::rc::Rc; use core::cell::RefCell; use cubecl_runtime::debug::DebugLogger; @@ -9,14 +9,14 @@ use std::collections::HashMap; pub struct CubeContext { pub root: Rc>, pub scope: Rc>, - pub local_allocator: Rc, + pub allocator: Allocator, pub debug_enabled: bool, pub typemap: Rc>>, } impl Default for CubeContext { fn default() -> Self { - Self::root(ReusingAllocator::default()) + Self::root(Allocator::new()) } } @@ -25,13 +25,13 @@ impl CubeContext { /// A root scope is at the root of a compute shader /// Therefore there is one cube context per shader /// The allocator will define the strategy for creating local intermediates and mutable variables - pub fn root(allocator: impl LocalAllocator + 'static) -> CubeContext { + pub fn root(allocator: Allocator) -> CubeContext { let root = Rc::new(RefCell::new(Scope::root())); let typemap = Rc::new(RefCell::new(HashMap::new())); let scope = root.clone(); Self { - local_allocator: Rc::new(allocator), + allocator, scope, root, debug_enabled: DebugLogger::default().is_activated(), @@ -64,7 +64,7 @@ impl CubeContext { Self { scope: Rc::new(RefCell::new(scope)), root: self.root.clone(), - local_allocator: self.local_allocator.clone(), + allocator: self.allocator.clone(), debug_enabled: self.debug_enabled, typemap: self.typemap.clone(), } @@ -78,23 +78,23 @@ impl CubeContext { .into_inner() } - /// Create a new mutable local variable - pub fn create_local_variable(&mut self, item: Item) -> ExpandElement { - self.local_allocator - .create_local_variable(self.root.clone(), self.scope.clone(), item) + /// Create a new mutable local variable. + pub fn create_local_mut(&mut self, item: Item) -> ExpandElement { + self.allocator + .create_local_mut(&mut self.root.borrow_mut(), item) } - /// Create a new immutable local binding - pub fn create_local_binding(&mut self, item: Item) -> ExpandElement { - self.local_allocator - .create_local_binding(self.root.clone(), self.scope.clone(), item) + /// Create a new immutable local variable. + pub fn create_local(&mut self, item: Item) -> ExpandElement { + self.allocator + .create_local(&mut self.scope.borrow_mut(), item) } /// Create a new immutable local binding that must never be a reused variable, regardless of /// allocator - pub fn create_local_undeclared(&mut self, item: Item) -> ExpandElement { - self.local_allocator - .create_local_undeclared(self.root.clone(), self.scope.clone(), item) + pub fn create_local_restricted(&mut self, item: Item) -> ExpandElement { + self.allocator + .create_local_restricted(&mut self.scope.borrow_mut(), item) } /// Create a new matrix element. diff --git a/crates/cubecl-core/src/frontend/element/atomic.rs b/crates/cubecl-core/src/frontend/element/atomic.rs index 45814835d..8fafa2a58 100644 --- a/crates/cubecl-core/src/frontend/element/atomic.rs +++ b/crates/cubecl-core/src/frontend/element/atomic.rs @@ -143,7 +143,7 @@ where pointer: ::ExpandType, ) -> ::ExpandType { let pointer: ExpandElement = pointer.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem(context))); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Load(UnaryOperator { input: *pointer }), *new_var, @@ -171,7 +171,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem(context))); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Swap(BinaryOperator { lhs: *ptr, @@ -191,7 +191,7 @@ where let pointer: ExpandElement = pointer.into(); let cmp: ExpandElement = cmp.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem(context))); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::CompareAndSwap(CompareAndSwapOperator { input: *pointer, @@ -210,7 +210,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem(context))); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Add(BinaryOperator { lhs: *ptr, @@ -228,7 +228,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem(context))); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Sub(BinaryOperator { lhs: *ptr, @@ -246,7 +246,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem(context))); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Max(BinaryOperator { lhs: *ptr, @@ -264,7 +264,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem(context))); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Min(BinaryOperator { lhs: *ptr, @@ -282,7 +282,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem(context))); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::And(BinaryOperator { lhs: *ptr, @@ -300,7 +300,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem(context))); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Or(BinaryOperator { lhs: *ptr, @@ -318,7 +318,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem(context))); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Xor(BinaryOperator { lhs: *ptr, diff --git a/crates/cubecl-core/src/frontend/element/base.rs b/crates/cubecl-core/src/frontend/element/base.rs index b6745cb75..b7fbc8584 100644 --- a/crates/cubecl-core/src/frontend/element/base.rs +++ b/crates/cubecl-core/src/frontend/element/base.rs @@ -327,7 +327,7 @@ impl ExpandElement { pub fn can_mut(&self) -> bool { match self { ExpandElement::Managed(var) => { - if let VariableKind::Local { .. } = var.as_ref().kind { + if let VariableKind::LocalMut { .. } = var.as_ref().kind { Rc::strong_count(var) <= 2 } else { false @@ -379,9 +379,9 @@ pub(crate) fn init_expand_element>( match elem.kind { VariableKind::GlobalScalar { .. } => init(elem), VariableKind::ConstantScalar { .. } => init(elem), - VariableKind::Local { .. } => init(elem), + VariableKind::LocalMut { .. } => init(elem), VariableKind::Versioned { .. } => init(elem), - VariableKind::LocalBinding { .. } => init(elem), + VariableKind::LocalConst { .. } => init(elem), // Constant should be initialized since the new variable can be mutated afterward. // And it is assumed those values are cloned. VariableKind::Builtin(_) => init(elem), diff --git a/crates/cubecl-core/src/frontend/element/cast.rs b/crates/cubecl-core/src/frontend/element/cast.rs index 67371baea..2c7366276 100644 --- a/crates/cubecl-core/src/frontend/element/cast.rs +++ b/crates/cubecl-core/src/frontend/element/cast.rs @@ -18,8 +18,7 @@ pub trait Cast: CubePrimitive { if core::any::TypeId::of::() == core::any::TypeId::of::() { return value.expand.into(); } - - let new_var = context.create_local_binding(Item::vectorized( + let new_var = context.create_local(Item::vectorized( ::as_elem(context), value.expand.item.vectorization, )); @@ -49,7 +48,7 @@ pub trait BitCast: CubePrimitive { ) -> ::ExpandType { let value: ExpandElement = value.into(); let var: Variable = *value; - let new_var = context.create_local_binding(Item::vectorized( + let new_var = context.create_local(Item::vectorized( ::as_elem(context), var.item.vectorization, )); diff --git a/crates/cubecl-core/src/frontend/element/numeric.rs b/crates/cubecl-core/src/frontend/element/numeric.rs index d61320828..6ef446542 100644 --- a/crates/cubecl-core/src/frontend/element/numeric.rs +++ b/crates/cubecl-core/src/frontend/element/numeric.rs @@ -95,7 +95,7 @@ pub trait Numeric: context: &mut CubeContext, vec: [u32; D], ) -> ::ExpandType { - let new_var = context.create_local_binding(Item::vectorized( + let new_var = context.create_local(Item::vectorized( Self::as_elem(context), NonZero::new(vec.len() as u8), )); diff --git a/crates/cubecl-core/src/frontend/operation/assignation.rs b/crates/cubecl-core/src/frontend/operation/assignation.rs index 8ca4b3a16..51b273a80 100644 --- a/crates/cubecl-core/src/frontend/operation/assignation.rs +++ b/crates/cubecl-core/src/frontend/operation/assignation.rs @@ -143,7 +143,7 @@ pub mod index { let array: ExpandElement = array.into(); let var: Variable = *array; let var = match var.kind { - VariableKind::Local { .. } | VariableKind::LocalBinding { .. } => { + VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => { binary_expand_no_vec(context, array, index, Operator::Index) } _ => binary_expand(context, array, index, Operator::Index), diff --git a/crates/cubecl-core/src/frontend/operation/base.rs b/crates/cubecl-core/src/frontend/operation/base.rs index 59ca59772..2c1cf308b 100644 --- a/crates/cubecl-core/src/frontend/operation/base.rs +++ b/crates/cubecl-core/src/frontend/operation/base.rs @@ -2,7 +2,7 @@ use std::num::NonZeroU8; use crate::ir::{ BinaryOperator, Elem, Instruction, Item, Operation, Operator, UnaryOperator, Variable, - Vectorization, + VariableKind, Vectorization, }; use crate::prelude::{CubeType, ExpandElementTyped}; use crate::{ @@ -29,7 +29,7 @@ where let item = Item::vectorized(item_lhs.elem, vectorization); - let output = context.create_local_binding(item); + let output = context.create_local(item); let out = *output; let op = func(BinaryOperator { lhs, rhs }); @@ -52,7 +52,7 @@ where let lhs_var = lhs.consume(); let rhs_var = rhs.consume(); - let out = context.create_local_binding(out_item); + let out = context.create_local(out_item); let out_var = *out; @@ -82,7 +82,7 @@ where let item = Item::new(item_lhs.elem); - let output = context.create_local_binding(item); + let output = context.create_local(item); let out = *output; let op = func(BinaryOperator { lhs, rhs }); @@ -112,7 +112,7 @@ where vectorization: item.vectorization, }; - let out = context.create_local_binding(out_item); + let out = context.create_local(out_item); let out_var = *out; let op = func(BinaryOperator { lhs, rhs }); @@ -150,7 +150,7 @@ where let input = input.consume(); let item = input.item; - let out = context.create_local_binding(item); + let out = context.create_local(item); let out_var = *out; let op = func(UnaryOperator { input }); @@ -170,7 +170,7 @@ where F: Fn(UnaryOperator) -> Operator, { let input = input.consume(); - let output = context.create_local_binding(out_item); + let output = context.create_local(out_item); let out = *output; let op = func(UnaryOperator { input }); @@ -191,7 +191,7 @@ where let input_var: Variable = *input; let item = input.item; - let out = context.create_local_variable(item); + let out = context.create_local_mut(item); let out_var = *out; let op = func(input_var); @@ -239,7 +239,12 @@ pub fn array_assign_binary_op_expand< let index: ExpandElement = index.into(); let value: ExpandElement = value.into(); - let array_value = context.create_local_binding(array.item); + let array_item = match array.kind { + // In that case, the array is a line. + VariableKind::LocalMut { .. } => array.item.vectorize(None), + _ => array.item, + }; + let array_value = context.create_local(array_item); let read = Instruction::new( Operator::Index(BinaryOperator { @@ -249,7 +254,7 @@ pub fn array_assign_binary_op_expand< *array_value, ); let array_value = array_value.consume(); - let op_out = context.create_local_binding(array.item); + let op_out = context.create_local(array.item); let calculate = Instruction::new( func(BinaryOperator { lhs: array_value, @@ -262,7 +267,6 @@ pub fn array_assign_binary_op_expand< lhs: *index, rhs: op_out.consume(), }); - context.register(read); context.register(calculate); context.register(Instruction::new(write, *array)); diff --git a/crates/cubecl-core/src/frontend/operation/branch.rs b/crates/cubecl-core/src/frontend/operation/branch.rs index a0c836a35..00a110010 100644 --- a/crates/cubecl-core/src/frontend/operation/branch.rs +++ b/crates/cubecl-core/src/frontend/operation/branch.rs @@ -50,7 +50,7 @@ pub mod select { let vf = Ord::max(vf, then.vectorization_factor()); let vf = Ord::max(vf, or_else.vectorization_factor()); - let output = context.create_local_binding(then.item.vectorize(NonZero::new(vf))); + let output = context.create_local(then.item.vectorize(NonZero::new(vf))); let out = *output; let select = Operator::Select(Select { diff --git a/crates/cubecl-core/src/frontend/operation/fma.rs b/crates/cubecl-core/src/frontend/operation/fma.rs index d90f879e0..874070ea0 100644 --- a/crates/cubecl-core/src/frontend/operation/fma.rs +++ b/crates/cubecl-core/src/frontend/operation/fma.rs @@ -18,7 +18,7 @@ pub fn fma_expand( b: ExpandElement, c: ExpandElement, ) -> ExpandElement { - let output = context.create_local_binding(a.item); + let output = context.create_local(a.item); let out = *output; let a = *a; diff --git a/crates/cubecl-core/src/frontend/plane.rs b/crates/cubecl-core/src/frontend/plane.rs index 2ad489556..a01590889 100644 --- a/crates/cubecl-core/src/frontend/plane.rs +++ b/crates/cubecl-core/src/frontend/plane.rs @@ -17,7 +17,7 @@ pub mod plane_elect { /// Expand method of [plane_elect()]. pub fn expand(context: &mut CubeContext) -> ExpandElementTyped { - let output = context.create_local_binding(Item::new(Elem::Bool)); + let output = context.create_local(Item::new(Elem::Bool)); let out = *output; context.register(Instruction::new(Plane::Elect, out)); @@ -44,7 +44,7 @@ pub mod plane_broadcast { value: ExpandElementTyped, id: ExpandElementTyped, ) -> ExpandElementTyped { - let output = context.create_local_binding(value.expand.item); + let output = context.create_local(value.expand.item); let out = *output; let lhs = *value.expand; let rhs = *id.expand; @@ -74,7 +74,7 @@ pub mod plane_sum { elem: ExpandElementTyped, ) -> ExpandElementTyped { let elem: ExpandElement = elem.into(); - let output = context.create_local_binding(elem.item); + let output = context.create_local(elem.item); let out = *output; let input = *elem; @@ -100,7 +100,7 @@ pub mod plane_prod { elem: ExpandElementTyped, ) -> ExpandElementTyped { let elem: ExpandElement = elem.into(); - let output = context.create_local_binding(elem.item); + let output = context.create_local(elem.item); let out = *output; let input = *elem; @@ -126,7 +126,7 @@ pub mod plane_max { elem: ExpandElementTyped, ) -> ExpandElementTyped { let elem: ExpandElement = elem.into(); - let output = context.create_local_binding(elem.item); + let output = context.create_local(elem.item); let out = *output; let input = *elem; @@ -152,7 +152,7 @@ pub mod plane_min { elem: ExpandElementTyped, ) -> ExpandElementTyped { let elem: ExpandElement = elem.into(); - let output = context.create_local_binding(elem.item); + let output = context.create_local(elem.item); let out = *output; let input = *elem; @@ -179,7 +179,7 @@ pub mod plane_all { elem: ExpandElementTyped, ) -> ExpandElementTyped { let elem: ExpandElement = elem.into(); - let output = context.create_local_binding(elem.item); + let output = context.create_local(elem.item); let out = *output; let input = *elem; @@ -206,7 +206,7 @@ pub mod plane_any { elem: ExpandElementTyped, ) -> ExpandElementTyped { let elem: ExpandElement = elem.into(); - let output = context.create_local_binding(elem.item); + let output = context.create_local(elem.item); let out = *output; let input = *elem; diff --git a/crates/cubecl-core/src/ir/allocator.rs b/crates/cubecl-core/src/ir/allocator.rs new file mode 100644 index 000000000..d6f4edfd4 --- /dev/null +++ b/crates/cubecl-core/src/ir/allocator.rs @@ -0,0 +1,89 @@ +use std::{collections::HashMap, rc::Rc}; + +use crate::prelude::ExpandElement; + +use super::{Item, Scope, Variable}; + +/// An allocator for local variables of a kernel. +/// +/// A local variable is unique to a unit. That is, each unit have their own copy of a local variable. +/// There are three types of local variables based on their capabilities. +/// - An immutable local variable is obtained by calling [Allocator::create_local]. +/// - A mutable local variable is obtained by calling [Allocator::create_local_mut]. The allocator will reuse +/// previously defined mutable variables if possible. +/// - A restricted mutable local variable is obtained by calling [Allocator::create_local_restricted]. This a is +/// mutable variable that cannot be reused. This is mostly used for loop indices. +/// +/// # Performance tips +/// +/// In order, prefer immutable local variables, then mutable, then restricted. +/// +/// To enable many compiler optimizations, it is prefered to use the [static single-assignment] strategy for immutable variables. +/// That is, each variable must be declared and used exactly once. +/// +/// [static single-assignment](https://en.wikipedia.org/wiki/Static_single-assignment_form) +#[derive(Clone)] +pub struct Allocator { + local_mut_pool: HashMap>, +} + +impl Default for Allocator { + fn default() -> Self { + Self::new() + } +} + +impl Allocator { + /// Create a new allocator. + pub fn new() -> Self { + Self { + local_mut_pool: HashMap::new(), + } + } + + /// Create a new immutable local variable of type specified by `item` for the given `scope`. + pub fn create_local(&self, scope: &mut Scope, item: Item) -> ExpandElement { + ExpandElement::Plain(scope.create_local(item)) + } + + /// Create a new mutable local variable of type specified by `item` for the given `scope`. + /// Try to reuse a previously defined but unused mutable variable in the current scope if possible. + /// Else, this define a new variable. + pub fn create_local_mut(&mut self, scope: &mut Scope, item: Item) -> ExpandElement { + if item.elem.is_atomic() { + ExpandElement::Plain(scope.create_local_restricted(item)) + } else { + self.reuse_local_mut(item) + .unwrap_or_else(|| ExpandElement::Managed(self.add_local_mut(scope, item))) + } + } + + /// Create a new mutable restricted local variable of type specified by `item` into the given `scope`. + pub fn create_local_restricted(&self, scope: &mut Scope, item: Item) -> ExpandElement { + ExpandElement::Plain(scope.create_local_restricted(item)) + } + + // Try to return a reusable mutable variable for the given `item` or `None` otherwise. + fn reuse_local_mut(&self, item: Item) -> Option { + // Among the candidates, take a variable if it's only referenced by the pool. + // Arbitrarily takes the first it finds in reversed order. + self.local_mut_pool.get(&item).and_then(|vars| { + vars.iter() + .rev() + .find(|var| matches!(var, ExpandElement::Managed(v) if Rc::strong_count(v) == 1)) + .cloned() + }) + } + + /// Add a new variable to the pool with type specified by `item` for the given `scope`. + fn add_local_mut(&mut self, scope: &mut Scope, item: Item) -> Rc { + let var = Rc::new(scope.create_local_mut(item)); + let expand = ExpandElement::Managed(var.clone()); + if let Some(variables) = self.local_mut_pool.get_mut(&item) { + variables.push(expand); + } else { + self.local_mut_pool.insert(var.item, vec![expand]); + } + var + } +} diff --git a/crates/cubecl-core/src/ir/branch.rs b/crates/cubecl-core/src/ir/branch.rs index 5dfea149c..85fd6d120 100644 --- a/crates/cubecl-core/src/ir/branch.rs +++ b/crates/cubecl-core/src/ir/branch.rs @@ -150,7 +150,7 @@ impl RangeLoop { ) { let mut scope = parent_scope.child(); let index_ty = Item::new(Elem::UInt(UIntKind::U32)); - let i = scope.create_local_undeclared(index_ty); + let i = scope.create_local_restricted(index_ty); func(i, &mut scope); diff --git a/crates/cubecl-core/src/ir/local_allocator.rs b/crates/cubecl-core/src/ir/local_allocator.rs deleted file mode 100644 index 5e25febff..000000000 --- a/crates/cubecl-core/src/ir/local_allocator.rs +++ /dev/null @@ -1,143 +0,0 @@ -use std::{ - cell::RefCell, - collections::HashMap, - rc::Rc, - sync::atomic::{AtomicU16, Ordering}, -}; - -use crate::prelude::ExpandElement; - -use super::{Item, Scope, Variable, VariableKind}; - -type ScopeRef = Rc>; - -/// Defines a local variable allocation strategy (i.e. reused mutable variables, SSA) -pub trait LocalAllocator { - /// Creates a local variable that can be (re)assigned - fn create_local_variable(&self, root: ScopeRef, scope: ScopeRef, item: Item) -> ExpandElement; - /// Creates an immutable local binding for intermediates - fn create_local_binding(&self, root: ScopeRef, scope: ScopeRef, item: Item) -> ExpandElement; - /// Creates an undeclared local binding that must not be reused regardless of allocator - fn create_local_undeclared(&self, root: ScopeRef, scope: ScopeRef, item: Item) - -> ExpandElement; -} - -#[derive(Default, Clone)] -pub struct VariablePool { - map: Rc>>>, -} - -impl VariablePool { - /// Returns an old, not used anymore variable, if there exists one. - pub fn reuse(&self, item: Item) -> Option { - let map = self.map.borrow(); - - // Filter for candidate variables of the same Item - let variables = map.get(&item)?; - - // Among the candidates, take a variable if it's only referenced by the map - // Arbitrarily takes the first it finds in reverse order. - for variable in variables.iter().rev() { - match variable { - ExpandElement::Managed(var) => { - if Rc::strong_count(var) == 1 { - return Some(variable.clone()); - } - } - ExpandElement::Plain(_) => (), - } - } - - // If no candidate was found, a new var will be needed - None - } - - /// Insert a new variable in the map, which is classified by Item - pub fn insert(&self, var: ExpandElement) { - let mut map = self.map.borrow_mut(); - let item = var.item; - - if let Some(variables) = map.get_mut(&item) { - variables.push(var.clone()); - } else { - map.insert(var.item, vec![var.clone()]); - } - } -} - -/// Reusing allocator, assigns all intermediates to a set of mutable variables that get continuously -/// reused. -#[derive(Default)] -pub struct ReusingAllocator { - pool: VariablePool, -} - -impl LocalAllocator for ReusingAllocator { - fn create_local_variable(&self, root: ScopeRef, scope: ScopeRef, item: Item) -> ExpandElement { - if item.elem.is_atomic() { - let new = scope.borrow_mut().create_local_undeclared(item); - return ExpandElement::Plain(new); - } - - // Reuse an old variable if possible - if let Some(var) = self.pool.reuse(item) { - return var; - } - - // Create a new variable at the root scope - // Insert it in the variable pool for potential reuse - let new = ExpandElement::Managed(Rc::new(root.borrow_mut().create_local(item))); - self.pool.insert(new.clone()); - - new - } - - fn create_local_binding(&self, root: ScopeRef, scope: ScopeRef, item: Item) -> ExpandElement { - self.create_local_variable(root, scope, item) - } - - fn create_local_undeclared( - &self, - _root: ScopeRef, - scope: ScopeRef, - item: Item, - ) -> ExpandElement { - ExpandElement::Plain(scope.borrow_mut().create_local_undeclared(item)) - } -} - -/// Hybrid allocator. Creates immutable local bindings for intermediates, and falls back to -/// [`ReusingAllocator`] for mutable variables. -#[derive(Default)] -pub struct HybridAllocator { - variable_allocator: ReusingAllocator, - ssa_index: AtomicU16, -} - -impl LocalAllocator for HybridAllocator { - fn create_local_variable(&self, root: ScopeRef, scope: ScopeRef, item: Item) -> ExpandElement { - self.ssa_index.fetch_add(1, Ordering::AcqRel); - self.variable_allocator - .create_local_variable(root, scope, item) - } - - fn create_local_binding(&self, _root: ScopeRef, scope: ScopeRef, item: Item) -> ExpandElement { - let id = self.ssa_index.fetch_add(1, Ordering::AcqRel); - let depth = scope.borrow().depth; - ExpandElement::Plain(Variable::new( - VariableKind::LocalBinding { id, depth }, - item, - )) - } - - fn create_local_undeclared( - &self, - _root: ScopeRef, - scope: ScopeRef, - item: Item, - ) -> ExpandElement { - let id = self.ssa_index.fetch_add(1, Ordering::AcqRel); - let depth = scope.borrow().depth; - ExpandElement::Plain(Variable::new(VariableKind::Local { id, depth }, item)) - } -} diff --git a/crates/cubecl-core/src/ir/mod.rs b/crates/cubecl-core/src/ir/mod.rs index 820d80586..281488139 100644 --- a/crates/cubecl-core/src/ir/mod.rs +++ b/crates/cubecl-core/src/ir/mod.rs @@ -1,9 +1,9 @@ +mod allocator; mod branch; mod cmma; mod comment; mod debug; mod kernel; -mod local_allocator; mod macros; mod operation; mod plane; @@ -13,12 +13,12 @@ mod synchronization; mod variable; pub use super::frontend::AtomicOp; +pub use allocator::*; pub use branch::*; pub use cmma::*; pub use comment::*; pub use debug::*; pub use kernel::*; -pub use local_allocator::*; pub use operation::*; pub use plane::*; pub use scope::*; diff --git a/crates/cubecl-core/src/ir/scope.rs b/crates/cubecl-core/src/ir/scope.rs index 4bcfd3c23..9ed3923bc 100644 --- a/crates/cubecl-core/src/ir/scope.rs +++ b/crates/cubecl-core/src/ir/scope.rs @@ -67,7 +67,7 @@ impl Scope { /// Create a variable initialized at zero. pub fn zero>(&mut self, item: I) -> Variable { - let local = self.create_local(item); + let local = self.create_local(item.into()); let zero: Variable = 0u32.into(); cpa!(self, local = zero); local @@ -123,12 +123,12 @@ impl Scope { variable } - /// Create a local variable of the given [item type](Item). - pub fn create_local>(&mut self, item: I) -> Variable { + /// Create a mutable variable of the given [item type](Item). + pub fn create_local_mut>(&mut self, item: I) -> Variable { let item = item.into(); let index = self.new_local_index(); let local = Variable::new( - VariableKind::Local { + VariableKind::LocalMut { id: index, depth: self.depth, }, @@ -138,12 +138,11 @@ impl Scope { local } - /// Create a new undeclared local, but doesn't perform the declaration. - /// + /// Create a new restricted variable. The variable is /// Useful for _for loops_ and other algorithms that require the control over initialization. - pub fn create_local_undeclared(&mut self, item: Item) -> Variable { + pub fn create_local_restricted(&mut self, item: Item) -> Variable { let index = self.new_local_index(); - let local = VariableKind::Local { + let local = VariableKind::LocalMut { id: index, depth: self.depth, }; @@ -151,12 +150,10 @@ impl Scope { Variable::new(local, item) } - /// Create a new undeclared local binding, but doesn't perform the declaration. - /// - /// Useful for temporaries and other algorithms that require the control over initialization. - pub fn create_local_binding(&mut self, item: Item) -> Variable { + /// Create a new immutable variable. + pub fn create_local(&mut self, item: Item) -> Variable { let index = self.new_local_index(); - let local = VariableKind::LocalBinding { + let local = VariableKind::LocalConst { id: index, depth: self.depth, }; @@ -181,7 +178,7 @@ impl Scope { /// The index refers to the scalar position for the same [element](Elem) type. pub fn read_scalar(&mut self, index: u16, elem: Elem) -> Variable { let local = Variable::new( - VariableKind::LocalBinding { + VariableKind::LocalConst { id: self.new_local_index(), depth: self.depth, }, @@ -344,7 +341,7 @@ impl Scope { let input = Variable::new(VariableKind::GlobalInputArray(index), item_global); let index = self.new_local_index(); let local = Variable::new( - VariableKind::Local { + VariableKind::LocalMut { id: index, depth: self.depth, }, diff --git a/crates/cubecl-core/src/ir/variable.rs b/crates/cubecl-core/src/ir/variable.rs index eb22d3058..41317acab 100644 --- a/crates/cubecl-core/src/ir/variable.rs +++ b/crates/cubecl-core/src/ir/variable.rs @@ -41,13 +41,13 @@ pub enum VariableKind { GlobalInputArray(Id), GlobalOutputArray(Id), GlobalScalar(Id), - Local { id: Id, depth: u8 }, + LocalArray { id: Id, depth: u8, length: u32 }, + LocalMut { id: Id, depth: u8 }, + LocalConst { id: Id, depth: u8 }, Versioned { id: Id, depth: u8, version: u16 }, - LocalBinding { id: Id, depth: u8 }, ConstantScalar(ConstantScalarValue), ConstantArray { id: Id, length: u32 }, SharedMemory { id: Id, length: u32 }, - LocalArray { id: Id, depth: u8, length: u32 }, Matrix { id: Id, mat: Matrix, depth: u8 }, Slice { id: Id, depth: u8 }, Builtin(Builtin), @@ -85,7 +85,7 @@ impl Variable { pub fn is_immutable(&self) -> bool { match self.kind { VariableKind::GlobalOutputArray { .. } => false, - VariableKind::Local { .. } => false, + VariableKind::LocalMut { .. } => false, VariableKind::SharedMemory { .. } => false, VariableKind::Matrix { .. } => false, VariableKind::Slice { .. } => false, @@ -93,7 +93,7 @@ impl Variable { VariableKind::GlobalInputArray { .. } => false, VariableKind::GlobalScalar { .. } => true, VariableKind::Versioned { .. } => true, - VariableKind::LocalBinding { .. } => true, + VariableKind::LocalConst { .. } => true, VariableKind::ConstantScalar(_) => true, VariableKind::ConstantArray { .. } => true, VariableKind::Builtin(_) => true, @@ -369,21 +369,33 @@ impl Variable { pub fn vectorization_factor(&self) -> u8 { self.item.vectorization.map(NonZero::get).unwrap_or(1u8) } - pub fn index(&self) -> Option { + + pub fn index(&self) -> Option { match self.kind { - VariableKind::GlobalInputArray(id) => Some(id), - VariableKind::GlobalScalar(id) => Some(id), - VariableKind::Local { id, .. } => Some(id), - VariableKind::Versioned { id, .. } => Some(id), - VariableKind::LocalBinding { id, .. } => Some(id), - VariableKind::Slice { id, .. } => Some(id), - VariableKind::GlobalOutputArray(id) => Some(id), - VariableKind::ConstantScalar(_) => None, - VariableKind::ConstantArray { id, .. } => Some(id), - VariableKind::SharedMemory { id, .. } => Some(id), - VariableKind::LocalArray { id, .. } => Some(id), - VariableKind::Matrix { id, .. } => Some(id), - VariableKind::Builtin(_) => None, + VariableKind::GlobalInputArray(id) + | VariableKind::GlobalScalar(id) + | VariableKind::LocalMut { id, .. } + | VariableKind::Versioned { id, .. } + | VariableKind::LocalConst { id, .. } + | VariableKind::Slice { id, .. } + | VariableKind::GlobalOutputArray(id) + | VariableKind::ConstantArray { id, .. } + | VariableKind::SharedMemory { id, .. } + | VariableKind::LocalArray { id, .. } + | VariableKind::Matrix { id, .. } => Some(id), + _ => None, + } + } + + pub fn depth(&self) -> Option { + match self.kind { + VariableKind::LocalMut { depth, .. } + | VariableKind::Versioned { depth, .. } + | VariableKind::LocalConst { depth, .. } + | VariableKind::LocalArray { depth, .. } + | VariableKind::Matrix { depth, .. } + | VariableKind::Slice { depth, .. } => Some(depth), + _ => None, } } @@ -402,11 +414,11 @@ impl Display for Variable { VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"), VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"), VariableKind::ConstantScalar(constant) => write!(f, "{constant}"), - VariableKind::Local { id, depth } => write!(f, "local({id}, {depth})"), + VariableKind::LocalMut { id, depth } => write!(f, "local({id}, {depth})"), VariableKind::Versioned { id, depth, version } => { write!(f, "local({id}, {depth}).v{version}") } - VariableKind::LocalBinding { id, depth } => write!(f, "binding({id}, {depth})"), + VariableKind::LocalConst { id, depth } => write!(f, "binding({id}, {depth})"), VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"), VariableKind::SharedMemory { id, .. } => write!(f, "shared({id})"), VariableKind::LocalArray { id, .. } => write!(f, "array({id})"), diff --git a/crates/cubecl-core/src/runtime_tests/line.rs b/crates/cubecl-core/src/runtime_tests/line.rs index 8f9370630..6454cd36b 100644 --- a/crates/cubecl-core/src/runtime_tests/line.rs +++ b/crates/cubecl-core/src/runtime_tests/line.rs @@ -66,6 +66,44 @@ pub fn test_line_index_assign( } } +#[cube(launch_unchecked)] +pub fn kernel_line_loop_unroll(output: &mut Array>, #[comptime] line_size: u32) { + if UNIT_POS == 0 { + let mut line = output[0]; + #[unroll] + for k in 0..line_size { + line[k] += F::cast_from(k); + } + output[0] = line; + } +} + +pub fn test_line_loop_unroll( + client: ComputeClient, +) { + for line_size in R::line_size_elem(&F::as_elem(&Default::default())) { + let handle = client.create(F::as_bytes(&vec![F::new(0.0); line_size as usize])); + unsafe { + kernel_line_loop_unroll::launch_unchecked::( + &client, + CubeCount::new_single(), + CubeDim::new_single(), + ArrayArg::from_raw_parts::(&handle, 1, line_size), + line_size as u32, + ); + } + + let actual = client.read_one(handle.binding()); + let actual = F::from_bytes(&actual); + + let expected = (0..line_size as i64) + .map(|x| F::from_int(x)) + .collect::>(); + + assert_eq!(actual, expected); + } +} + macro_rules! impl_line_comparison { ($cmp:ident, $expected:expr) => { ::paste::paste! { @@ -134,6 +172,14 @@ macro_rules! testgen_line { ); } + #[test] + fn test_line_loop_unroll() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::line::test_line_loop_unroll::( + client, + ); + } + #[test] fn test_line_equal() { let client = TestRuntime::client(&Default::default()); diff --git a/crates/cubecl-cpp/src/cuda/dialect.rs b/crates/cubecl-cpp/src/cuda/dialect.rs index 4da8c82bc..136243b73 100644 --- a/crates/cubecl-cpp/src/cuda/dialect.rs +++ b/crates/cubecl-cpp/src/cuda/dialect.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use crate::shared::{Dialect, IndexedVariable, Variable, WmmaCompiler}; +use crate::shared::{Dialect, WmmaCompiler}; #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] pub struct CudaDialect { @@ -74,20 +74,19 @@ impl> Dialect for CudaDialect { fn bfloat162_type_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str("__nv_bfloat162") } - - fn warp_shuffle(input: &IndexedVariable, id: &Variable) -> String { - format!("__shfl_sync(-1, {input}, {id})") + fn warp_shuffle(var: &str, source: &str) -> String { + format!("__shfl_sync(-1, {var}, {source})") } - fn warp_shuffle_xor(out: &IndexedVariable) -> String { - format!("__shfl_xor_sync(-1, {out}, offset)") + fn warp_shuffle_xor(var: &str, offset: &str) -> String { + format!("__shfl_xor_sync(-1, {var}, {offset})") } - fn warp_shuffle_down(out: &IndexedVariable) -> String { - format!("__shfl_down_sync(-1, {out}, offset)") + fn warp_shuffle_down(var: &str, offset: &str) -> String { + format!("__shfl_down_sync(-1, {var}, {offset})") } - fn warp_all(out: &IndexedVariable) -> String { - format!("__all_sync(-1, {out})") + fn warp_all(var: &str) -> String { + format!("__all_sync(-1, {var})") } - fn warp_any(out: &IndexedVariable) -> String { - format!("__any_sync(-1, {out})") + fn warp_any(var: &str) -> String { + format!("__any_sync(-1, {var})") } } diff --git a/crates/cubecl-cpp/src/hip/dialect.rs b/crates/cubecl-cpp/src/hip/dialect.rs index 38d0c4d5e..2f5ad02eb 100644 --- a/crates/cubecl-cpp/src/hip/dialect.rs +++ b/crates/cubecl-cpp/src/hip/dialect.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use crate::shared::{Dialect, IndexedVariable, Variable, WmmaCompiler}; +use crate::shared::{Dialect, WmmaCompiler}; #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] pub struct HipDialect { @@ -75,20 +75,19 @@ impl> Dialect for HipDialect { // "hip_bfloat16.h" has no "hip_bfloat162" type f.write_str("hip_bfloat16") } - - fn warp_shuffle(input: &IndexedVariable, id: &Variable) -> String { - format!("__shfl({input}, {id})") + fn warp_shuffle(var: &str, source: &str) -> String { + format!("__shfl({var}, {source})") } - fn warp_shuffle_xor(out: &IndexedVariable) -> String { - format!("__shfl_xor({out}, offset)") + fn warp_shuffle_xor(var: &str, offset: &str) -> String { + format!("__shfl_xor_sync({var}, {offset})") } - fn warp_shuffle_down(out: &IndexedVariable) -> String { - format!("__shfl_down({out}, offset)") + fn warp_shuffle_down(var: &str, offset: &str) -> String { + format!("__shfl_down_sync({var}, {offset})") } - fn warp_all(out: &IndexedVariable) -> String { - format!("__all({out})") + fn warp_all(var: &str) -> String { + format!("__all({var})") } - fn warp_any(out: &IndexedVariable) -> String { + fn warp_any(out: &str) -> String { format!("__any({out})") } } diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index 830dd12ef..2be8a197f 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -1,7 +1,7 @@ use std::hash::Hash; use std::{collections::HashSet, fmt::Debug, num::NonZero}; -use cubecl_core::ir::expand_checked_index_assign; +use cubecl_core::ir::{expand_checked_index_assign, Allocator}; use cubecl_core::{ ir::{self as gpu}, Compiler, Feature, @@ -10,8 +10,8 @@ use cubecl_runtime::{DeviceProperties, ExecutionMode}; use super::{ AtomicKind, BinaryInstruction, Binding, Body, ComputeKernel, ConstArray, Elem, Fragment, - FragmentIdent, FragmentLayout, IndexedVariable, Instruction, Item, LocalArray, SharedMemory, - UnaryInstruction, Variable, VariableSettings, WarpInstruction, WmmaCompiler, WmmaInstruction, + FragmentIdent, FragmentLayout, Instruction, Item, LocalArray, SharedMemory, UnaryInstruction, + Variable, VariableSettings, WarpInstruction, WmmaCompiler, WmmaInstruction, }; pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 = @@ -28,11 +28,11 @@ pub trait Dialect: fn bfloat16_type_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result; fn bfloat162_type_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result; // warp instructions (all threads participating) - fn warp_shuffle(input: &IndexedVariable, id: &Variable) -> String; - fn warp_shuffle_xor(out: &IndexedVariable) -> String; - fn warp_shuffle_down(out: &IndexedVariable) -> String; - fn warp_all(out: &IndexedVariable) -> String; - fn warp_any(out: &IndexedVariable) -> String; + fn warp_shuffle(var: &str, source: &str) -> String; + fn warp_shuffle_xor(var: &str, offset: &str) -> String; + fn warp_shuffle_down(var: &str, offset: &str) -> String; + fn warp_all(var: &str) -> String; + fn warp_any(var: &str) -> String; } #[derive(Clone, Debug)] @@ -94,8 +94,8 @@ impl Compiler for CppCompiler { 49152 } - fn local_allocator() -> impl gpu::LocalAllocator { - gpu::ReusingAllocator::default() + fn local_allocator() -> Allocator { + Allocator::new() } } @@ -546,8 +546,7 @@ impl CppCompiler { if matches!(self.strategy, ExecutionMode::Checked) && op.input.has_length() { let input = op.input; let input_len = - scope.create_local(gpu::Item::new(gpu::Elem::UInt(gpu::UIntKind::U32))); - + scope.create_local_mut(gpu::Item::new(gpu::Elem::UInt(gpu::UIntKind::U32))); instructions.extend(self.compile_scope(scope)); let length = match input.has_buffer_length() { @@ -822,17 +821,17 @@ impl CppCompiler { gpu::VariableKind::GlobalScalar(id) => { Variable::GlobalScalar(id, self.compile_item(item).elem, item.elem) } - gpu::VariableKind::Local { id, depth } => Variable::Local { + gpu::VariableKind::LocalMut { id, depth } => Variable::LocalMut { id, item: self.compile_item(item), depth, }, - gpu::VariableKind::Versioned { id, depth, .. } => Variable::Local { + gpu::VariableKind::Versioned { id, depth, .. } => Variable::LocalMut { id, item: self.compile_item(item), depth, }, - gpu::VariableKind::LocalBinding { id, depth } => Variable::ConstLocal { + gpu::VariableKind::LocalConst { id, depth } => Variable::LocalConst { id, item: self.compile_item(item), depth, diff --git a/crates/cubecl-cpp/src/shared/binary.rs b/crates/cubecl-cpp/src/shared/binary.rs index 29d2f2349..f8e8d8512 100644 --- a/crates/cubecl-cpp/src/shared/binary.rs +++ b/crates/cubecl-cpp/src/shared/binary.rs @@ -276,7 +276,7 @@ impl Binary for IndexAssign { rhs: &Variable, out: &Variable, ) -> std::fmt::Result { - if matches!(out, Variable::Local { .. } | Variable::ConstLocal { .. }) { + if matches!(out, Variable::LocalMut { .. } | Variable::LocalConst { .. }) { return IndexAssignVector::format(f, lhs, rhs, out); }; @@ -299,7 +299,7 @@ impl Binary for Index { rhs: &Variable, out: &Variable, ) -> std::fmt::Result { - if matches!(lhs, Variable::Local { .. } | Variable::ConstLocal { .. }) { + if matches!(lhs, Variable::LocalMut { .. } | Variable::LocalConst { .. }) { return IndexVector::format(f, lhs, rhs, out); } @@ -386,8 +386,12 @@ impl IndexVector { Variable::ConstantScalar(value, _elem) => value.as_usize(), _ => { let elem = out.elem(); + let qualifier = out.const_qualifier(); let out = out.fmt_left(); - return writeln!(f, "{out} = reinterpret_cast<{elem}*>(&{lhs})[{rhs}];"); + return writeln!( + f, + "{out} = reinterpret_cast<{elem}{qualifier}*>(&{lhs})[{rhs}];" + ); } }; diff --git a/crates/cubecl-cpp/src/shared/element.rs b/crates/cubecl-cpp/src/shared/element.rs index cd80b2b70..d6a28d0f4 100644 --- a/crates/cubecl-cpp/src/shared/element.rs +++ b/crates/cubecl-cpp/src/shared/element.rs @@ -104,7 +104,7 @@ impl Component for IndexedVariable { } fn is_const(&self) -> bool { - matches!(self.var, Variable::ConstLocal { .. }) + matches!(self.var, Variable::LocalConst { .. }) } } @@ -119,8 +119,9 @@ impl Component for Variable { Variable::GlobalOutputArray(_, e) => *e, Variable::SharedMemory(_, e, _) => *e, Variable::ConstantArray(_, e, _) => *e, - Variable::Local { item, .. } => *item, - Variable::ConstLocal { item, .. } => *item, + Variable::LocalMut { item, .. } => *item, + Variable::LocalConst { item, .. } => *item, + Variable::Named { item, .. } => *item, Variable::Slice { item, .. } => *item, Variable::ConstantScalar(_, e) => Item::scalar(*e), Variable::GlobalScalar(_, e, _) => Item::scalar(*e), @@ -157,7 +158,7 @@ impl Component for Variable { } fn is_const(&self) -> bool { - matches!(self, Variable::ConstLocal { .. }) + matches!(self, Variable::LocalConst { .. }) } } @@ -170,16 +171,20 @@ pub enum Variable { GlobalScalar(u16, Elem, gpu::Elem), ConstantArray(u16, Item, u32), ConstantScalar(ConstantScalarValue, Elem), - Local { + LocalMut { id: u16, item: Item, depth: u8, }, - ConstLocal { + LocalConst { id: u16, item: Item, depth: u8, }, + Named { + name: &'static str, + item: Item, + }, Slice { id: u16, item: Item, @@ -222,8 +227,9 @@ impl Display for Variable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")), - Variable::Local { id, depth, .. } => f.write_fmt(format_args!("l_{depth}_{id}")), - Variable::ConstLocal { id, depth, .. } => f.write_fmt(format_args!("ssa_{depth}_{id}")), + Variable::LocalMut { id, depth, .. } => f.write_fmt(format_args!("l_mut_{depth}_{id}")), + Variable::LocalConst { id, depth, .. } => f.write_fmt(format_args!("l_{depth}_{id}")), + Variable::Named { name, .. } => f.write_fmt(format_args!("{name}")), Variable::Slice { id, item: _, depth } => { write!(f, "slice_{depth}_{id}") } @@ -362,12 +368,12 @@ impl Variable { Variable::GlobalOutputArray(id, item) => { Variable::GlobalOutputArray(*id, item.optimized()) } - Variable::Local { id, item, depth } => Variable::Local { + Variable::LocalMut { id, item, depth } => Variable::LocalMut { id: *id, item: item.optimized(), depth: *depth, }, - Variable::ConstLocal { id, item, depth } => Variable::ConstLocal { + Variable::LocalConst { id, item, depth } => Variable::LocalConst { id: *id, item: item.optimized(), depth: *depth, @@ -410,8 +416,9 @@ impl Variable { Variable::GlobalOutputArray(_, _) => false, Variable::SharedMemory(_, _, _) => false, Variable::ConstantArray(_, _, _) => false, - Variable::Local { .. } => false, - Variable::ConstLocal { .. } => false, + Variable::LocalMut { .. } => false, + Variable::LocalConst { .. } => false, + Variable::Named { .. } => false, Variable::Slice { .. } => false, Variable::BlockIdxX => true, Variable::BlockIdxY => true, @@ -443,6 +450,14 @@ impl Variable { optimized: self.is_optimized(), } } + + pub fn const_qualifier(&self) -> &str { + if self.is_const() { + " const" + } else { + "" + } + } } pub trait FmtLeft: Display { @@ -452,7 +467,7 @@ pub trait FmtLeft: Display { impl FmtLeft for Variable { fn fmt_left(&self) -> String { match self { - Self::ConstLocal { item, .. } => format!("const {item} {self}"), + Self::LocalConst { item, .. } => format!("const {item} {self}"), Variable::Tmp { item, .. } => format!("{item} {self}"), var => format!("{var}"), } @@ -462,7 +477,7 @@ impl FmtLeft for Variable { impl FmtLeft for IndexedVariable { fn fmt_left(&self) -> String { match self.var { - Variable::ConstLocal { item, .. } => format!("const {item} {self}"), + Variable::LocalConst { item, .. } => format!("const {item} {self}"), Variable::Tmp { item, .. } => format!("{item} {self}"), _ => format!("{self}"), } @@ -485,7 +500,7 @@ pub struct IndexedVariable { impl Display for IndexedVariable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let var = &self.var; - let ref_ = matches!(var, Variable::ConstLocal { .. }) + let ref_ = matches!(var, Variable::LocalConst { .. }) .then_some("const&") .unwrap_or("&"); diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index c5c79f1b7..2c58b2b15 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -729,11 +729,12 @@ impl Remainder { write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?; + let qualifier = out.const_qualifier(); let out = out.fmt_left(); writeln!( f, - "{out} = reinterpret_cast<{item_out_original}&>({out_tmp});\n" + "{out} = reinterpret_cast<{item_out_original}{qualifier}&>({out_tmp});\n" )?; Ok(()) diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index 94fd8083e..3849a9388 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -62,10 +62,11 @@ pub trait Unary { let out_tmp = Variable::tmp(item_out_optimized); write_op(index, elem, &input, &out_tmp)?; - + let qualifier = out.const_qualifier(); + let out_fmt = out.fmt_left(); writeln!( f, - "{out} = reinterpret_cast<{item_out_original}&>({out_tmp});\n" + "{out_fmt} = reinterpret_cast<{item_out_original}{qualifier}&>({out_tmp});\n" ) } else { write_op(index, elem, &input, &out_optimized) diff --git a/crates/cubecl-cpp/src/shared/warp.rs b/crates/cubecl-cpp/src/shared/warp.rs index 4023d3e5e..3043f5e1e 100644 --- a/crates/cubecl-cpp/src/shared/warp.rs +++ b/crates/cubecl-cpp/src/shared/warp.rs @@ -1,8 +1,8 @@ use std::fmt::Display; -use crate::shared::{Component, Elem}; +use crate::shared::{Component, Elem, FmtLeft}; -use super::{Dialect, IndexedVariable, Variable}; +use super::{Dialect, Item, Variable}; #[derive(Clone, Debug)] pub enum WarpInstruction { @@ -47,6 +47,9 @@ impl Display for WarpInstruction { WarpInstruction::ReduceProd { input, out } => reduce_operator(f, input, out, "*="), WarpInstruction::ReduceMax { input, out } => reduce_comparison(f, input, out, "max"), WarpInstruction::ReduceMin { input, out } => reduce_comparison(f, input, out, "min"), + WarpInstruction::All { input, out } => reduce_quantifier(f, input, out, D::warp_all), + WarpInstruction::Any { input, out } => reduce_quantifier(f, input, out, D::warp_any), + WarpInstruction::Broadcast { input, id, out } => reduce_broadcast(f, input, out, id), WarpInstruction::Elect { out } => write!( f, " @@ -55,27 +58,6 @@ unsigned int leader = __ffs(mask) - 1; {out} = threadIdx.x % warpSize == leader; " ), - WarpInstruction::All { input, out } => reduce_quantifier(f, input, out, D::warp_all), - WarpInstruction::Any { input, out } => reduce_quantifier(f, input, out, D::warp_any), - WarpInstruction::Broadcast { input, id, out } => { - let input_optimized = input.optimized(); - let out_optimized = out.optimized(); - for k in 0..out_optimized.item().vectorization { - let __shfl = D::warp_shuffle(&input_optimized.index(k), id); - let indexed = out_optimized.index(k); - write!( - f, - " - {{ - for (int offset = 1; offset < warpSizeChecked; offset *=2 ) {{ - {indexed} = {__shfl}; - }} - }} - " - )?; - } - Ok(()) - } } } } @@ -86,30 +68,14 @@ fn reduce_operator( out: &Variable, op: &str, ) -> core::fmt::Result { - write!( - f, - " - {out} = {input}; - " - )?; + let in_optimized = input.optimized(); + let acc_item = in_optimized.item(); - let optimized = out.optimized(); - - for k in 0..optimized.item().vectorization { - let indexed = optimized.index(k); - let __shfl_xor = D::warp_shuffle_xor(&indexed); - write!( - f, - " - {{ - for (int offset = 1; offset < warpSizeChecked; offset *=2 ) {{ - {indexed} {op} {__shfl_xor}; - }} - }} - " - )?; - } - Ok(()) + reduce_with_loop(f, input, out, acc_item, |acc, index| { + let acc_indexed = maybe_index(acc, index); + let shfl_xor = D::warp_shuffle_xor(&acc_indexed, "offset"); + format!("{acc_indexed} {op} {shfl_xor};") + }) } fn reduce_comparison( @@ -118,64 +84,89 @@ fn reduce_comparison( out: &Variable, cmp: &str, ) -> core::fmt::Result { - write!( - f, - " - {out} = {input}; - " - )?; - - let optimized = out.optimized(); - - let instruction = match optimized.elem() { + let in_optimized = input.optimized(); + let acc_item = in_optimized.item(); + let instruction = match in_optimized.elem() { Elem::F16 | Elem::BF16 => format!("__h{cmp}"), Elem::F162 | Elem::BF162 => format!("__h{cmp}2"), _ => cmp.to_string(), }; - for k in 0..optimized.item().vectorization { - let indexed = optimized.index(k); - let __shfl_xor = D::warp_shuffle_xor(&indexed); - write!( - f, - " - {{ - for (int offset = warpSizeChecked / 2; offset > 0; offset /= 2) {{ - {indexed} = {instruction}({indexed}, {__shfl_xor}); - }} - }} - " - )?; - } - Ok(()) + reduce_with_loop(f, input, out, acc_item, |acc, index| { + let acc_indexed = maybe_index(acc, index); + let shfl_xor = D::warp_shuffle_xor(&acc_indexed, "offset"); + format!("{acc_indexed} = {instruction}({acc_indexed}, {shfl_xor});") + }) } -fn reduce_quantifier) -> String>( +fn reduce_broadcast( f: &mut core::fmt::Formatter<'_>, input: &Variable, out: &Variable, - quantifier: Q, + id: &Variable, ) -> core::fmt::Result { - write!( + let rhs = (0..input.item().vectorization) + .map(|k| D::warp_shuffle(&format!("{}", input.index(k)), &format!("{id}"))) + .collect::>() + .join(","); + let out_fmt = out.fmt_left(); + writeln!(f, "{out_fmt} = {{ {rhs} }};") +} + +fn reduce_with_loop, usize) -> String>( + f: &mut core::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + acc_item: Item, + instruction: I, +) -> core::fmt::Result { + let acc = Variable::Named { + name: "acc", + item: acc_item, + }; + + writeln!(f, "auto plane_{out} = [&]() -> {} {{", out.item())?; + writeln!(f, " {} {} = {};", acc_item, acc, cast(input, acc_item))?; + writeln!( f, - " - {out} = {input}; - " + " for (int offset = 1; offset < warpSizeChecked; offset *=2 ) {{" )?; - let optimized = out.optimized(); - for k in 0..optimized.item().vectorization { - let indexed = optimized.index(k); - let __all = quantifier(&indexed); - write!( - f, - " - {{ - for (int offset = 1; offset < warpSizeChecked; offset *=2 ) {{ - {indexed} = {__all}; - }} - }} - " - )?; + for k in 0..acc_item.vectorization { + writeln!(f, " {}", instruction(&acc, k))?; + } + writeln!(f, " }};")?; + writeln!(f, " return {};", cast(&acc, out.item()))?; + writeln!(f, "}};")?; + writeln!(f, "{} = plane_{}();", out.fmt_left(), out) +} + +fn reduce_quantifier String>( + f: &mut core::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + quantifier: Q, +) -> core::fmt::Result { + let rhs = (0..input.item().vectorization) + .map(|k| quantifier(&format!("{}", input.index(k)))) + .collect::>() + .join(","); + let out_fmt = out.fmt_left(); + writeln!(f, "{out_fmt} = {{ {rhs} }};") +} + +fn cast(input: &Variable, target: Item) -> String { + if target != input.item() { + let qualifier = input.const_qualifier(); + format!("reinterpret_cast<{}{}&>({})", target, qualifier, input) + } else { + format!("{}", input) + } +} + +fn maybe_index(var: &Variable, k: usize) -> String { + if var.item().vectorization > 1 { + format!("{var}.i_{k}") + } else { + format!("{var}") } - Ok(()) } diff --git a/crates/cubecl-opt/src/control_flow.rs b/crates/cubecl-opt/src/control_flow.rs index e1f9523ce..48ec1248a 100644 --- a/crates/cubecl-opt/src/control_flow.rs +++ b/crates/cubecl-opt/src/control_flow.rs @@ -253,7 +253,7 @@ impl Optimizer { let step = range_loop.step.unwrap_or(1.into()); let i_id = match range_loop.i.kind { - VariableKind::Local { id, depth, .. } => (id, depth), + VariableKind::LocalMut { id, depth, .. } => (id, depth), _ => unreachable!(), }; let i = range_loop.i; diff --git a/crates/cubecl-opt/src/gvn/convert.rs b/crates/cubecl-opt/src/gvn/convert.rs index 8bfc84a2d..ee4981692 100644 --- a/crates/cubecl-opt/src/gvn/convert.rs +++ b/crates/cubecl-opt/src/gvn/convert.rs @@ -221,7 +221,7 @@ impl Value { version: 0, item, }) => Variable::new( - VariableKind::LocalBinding { + VariableKind::LocalConst { id: *id, depth: *depth, }, @@ -276,7 +276,7 @@ pub fn value_of_var(var: &Variable) -> Option { version, item, }), - VariableKind::LocalBinding { id, depth } => Value::Local(Local { + VariableKind::LocalConst { id, depth } => Value::Local(Local { id, depth, version: 0, @@ -284,7 +284,7 @@ pub fn value_of_var(var: &Variable) -> Option { }), VariableKind::ConstantScalar(val) => Value::Constant(val.into()), VariableKind::ConstantArray { id, length } => Value::ConstArray(id, item, length), - VariableKind::Local { .. } + VariableKind::LocalMut { .. } | VariableKind::SharedMemory { .. } | VariableKind::LocalArray { .. } | VariableKind::Matrix { .. } => None?, diff --git a/crates/cubecl-opt/src/lib.rs b/crates/cubecl-opt/src/lib.rs index 51815a8ea..088cdff94 100644 --- a/crates/cubecl-opt/src/lib.rs +++ b/crates/cubecl-opt/src/lib.rs @@ -320,7 +320,7 @@ impl Optimizer { let ops = self.program[node].ops.clone(); for op in ops.borrow().values() { if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation { - if let VariableKind::Local { id, depth } = &op.out().kind { + if let VariableKind::LocalMut { id, depth } = &op.out().kind { self.program.variables.remove(&(*id, *depth)); } } @@ -382,7 +382,7 @@ impl Optimizer { let processed = scope.process(); for var in processed.variables { - if let VariableKind::Local { id, depth } = var.kind { + if let VariableKind::LocalMut { id, depth } = var.kind { self.program.variables.insert((id, depth), var.item); } } @@ -437,7 +437,7 @@ impl Optimizer { /// Gets the `id` and `depth` of the variable if it's a `Local` and not atomic, `None` otherwise. pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<(u16, u8)> { match variable.kind { - core::VariableKind::Local { id, depth } if !variable.item.elem.is_atomic() => { + core::VariableKind::LocalMut { id, depth } if !variable.item.elem.is_atomic() => { Some((id, depth)) } _ => None, @@ -449,7 +449,7 @@ impl Optimizer { pub fn create_temporary(&self, item: Item) -> Variable { let next_id = self.program.temp_id.inc() as u16; Variable::new( - VariableKind::LocalBinding { + VariableKind::LocalConst { id: u16::MAX - next_id, depth: u8::MAX, }, @@ -481,7 +481,7 @@ mod test { use cubecl::prelude::*; use cubecl_core::{ self as cubecl, - ir::{Elem, HybridAllocator, Item, UIntKind, Variable, VariableKind}, + ir::{Allocator, Elem, Item, UIntKind, Variable, VariableKind}, prelude::{Array, CubeContext, ExpandElement}, }; use cubecl_core::{cube, CubeDim, ExecutionMode}; @@ -504,7 +504,7 @@ mod test { #[test] #[ignore = "no good way to assert opt is applied"] fn test_pre() { - let mut ctx = CubeContext::root(HybridAllocator::default()); + let mut ctx = CubeContext::root(Allocator::new()); let x = ExpandElement::Plain(Variable::new( VariableKind::GlobalScalar(0), Item::new(Elem::UInt(UIntKind::U32)), diff --git a/crates/cubecl-opt/src/passes/array_copy_propagate.rs b/crates/cubecl-opt/src/passes/array_copy_propagate.rs index a1e98ed55..b913bdfe4 100644 --- a/crates/cubecl-opt/src/passes/array_copy_propagate.rs +++ b/crates/cubecl-opt/src/passes/array_copy_propagate.rs @@ -34,7 +34,7 @@ impl OptimizerPass for CopyPropagateArray { for Array { id, length, item } in arrays { let arr_id = id; let vars = (0..length) - .map(|_| opt.root_scope.create_local_undeclared(item)) + .map(|_| opt.root_scope.create_local_restricted(item)) .collect::>(); for var in &vars { let local_id = opt.local_variable_id(var).unwrap(); diff --git a/crates/cubecl-opt/src/passes/composite.rs b/crates/cubecl-opt/src/passes/composite.rs index c6ff4fe79..dc56fb11b 100644 --- a/crates/cubecl-opt/src/passes/composite.rs +++ b/crates/cubecl-opt/src/passes/composite.rs @@ -46,7 +46,7 @@ impl OptimizerPass for CompositeMerge { let op = { ops.borrow()[idx].clone() }; if let ( Operation::Operator(Operator::IndexAssign(BinaryOperator { lhs, rhs })), - Some(VariableKind::Local { id, depth }), + Some(VariableKind::LocalMut { id, depth }), ) = (op.operation, op.out.map(|it| it.kind)) { let item = op.out.unwrap().item; @@ -70,7 +70,7 @@ impl OptimizerPass for CompositeMerge { assert_eq!(index, 0, "Can't index into scalar"); opt.program[block].ops.borrow_mut()[idx] = Instruction::new( Operation::Copy(rhs), - Variable::new(VariableKind::Local { id, depth }, item), + Variable::new(VariableKind::LocalMut { id, depth }, item), ) } } @@ -93,7 +93,7 @@ fn merge_assigns( let last = assigns.iter().map(|it| it.0).max().unwrap(); assigns.sort_by_key(|it| it.1); let inputs = assigns.iter().map(|it| it.2).collect::>(); - let out = Variable::new(VariableKind::Local { id, depth }, item); + let out = Variable::new(VariableKind::LocalMut { id, depth }, item); ops.insert( last, Instruction::new(Operator::InitLine(LineInitOperator { inputs }), out), diff --git a/crates/cubecl-opt/src/passes/index_merge.rs b/crates/cubecl-opt/src/passes/index_merge.rs index 4b3e474ef..5053f5a0c 100644 --- a/crates/cubecl-opt/src/passes/index_merge.rs +++ b/crates/cubecl-opt/src/passes/index_merge.rs @@ -62,7 +62,7 @@ impl OptimizerPass for CopyTransform { fn as_versioned(var: &Variable) -> Option<(u16, u8, u16)> { match var.kind { - VariableKind::LocalBinding { id, depth } => Some((id, depth, 0)), + VariableKind::LocalConst { id, depth } => Some((id, depth, 0)), VariableKind::Versioned { id, depth, version } => Some((id, depth, version)), _ => None, } diff --git a/crates/cubecl-opt/src/passes/integer_range_analysis.rs b/crates/cubecl-opt/src/passes/integer_range_analysis.rs index 712f4998d..2eced266e 100644 --- a/crates/cubecl-opt/src/passes/integer_range_analysis.rs +++ b/crates/cubecl-opt/src/passes/integer_range_analysis.rs @@ -126,9 +126,7 @@ pub(crate) fn range_of(opt: &Optimizer, var: &Variable) -> Range { .get(&(id, depth, version)) .copied() .unwrap_or_default(), - VariableKind::LocalBinding { id, depth } - if var.item.elem() == Elem::UInt(UIntKind::U32) => - { + VariableKind::LocalConst { id, depth } if var.item.elem() == Elem::UInt(UIntKind::U32) => { opt.program .int_ranges .get(&(id, depth, 0)) @@ -138,7 +136,7 @@ pub(crate) fn range_of(opt: &Optimizer, var: &Variable) -> Range { upper_bound: None, }) } - VariableKind::LocalBinding { id, depth } => opt + VariableKind::LocalConst { id, depth } => opt .program .int_ranges .get(&(id, depth, 0)) @@ -166,7 +164,7 @@ pub(crate) fn range_of(opt: &Optimizer, var: &Variable) -> Range { pub(crate) fn var_id(var: &Variable) -> Option<(u16, u8, u16)> { match var.kind { VariableKind::Versioned { id, depth, version } => Some((id, depth, version)), - VariableKind::LocalBinding { id, depth } => Some((id, depth, 0)), + VariableKind::LocalConst { id, depth } => Some((id, depth, 0)), _ => None, } } diff --git a/crates/cubecl-opt/src/version.rs b/crates/cubecl-opt/src/version.rs index 749dfcd88..7c9636eda 100644 --- a/crates/cubecl-opt/src/version.rs +++ b/crates/cubecl-opt/src/version.rs @@ -175,7 +175,7 @@ impl Optimizer { fn version_writes(&mut self, op: &mut Instruction, state: &mut SsaState<'_>) { self.visit_out(&mut op.out, |_, var| match var.kind { - VariableKind::Local { id, depth } | VariableKind::Versioned { id, depth, .. } => { + VariableKind::LocalMut { id, depth } | VariableKind::Versioned { id, depth, .. } => { if let Some(version) = state.versions.get_mut(&(id, depth)) { let max_version = state.max_versions.get_mut(&(id, depth)).unwrap(); *max_version += 1; @@ -196,7 +196,7 @@ impl Optimizer { fn version_read(&self, var: &mut Variable, state: &mut SsaState<'_>) { match var.kind { - VariableKind::Local { id, depth } | VariableKind::Versioned { id, depth, .. } => { + VariableKind::LocalMut { id, depth } | VariableKind::Versioned { id, depth, .. } => { if self.program.variables.contains_key(&(id, depth)) { if let Some(version) = state.versions.get(&(id, depth)) { *var = Variable::new( diff --git a/crates/cubecl-spirv/src/compiler.rs b/crates/cubecl-spirv/src/compiler.rs index 32b1a308b..30e6de357 100644 --- a/crates/cubecl-spirv/src/compiler.rs +++ b/crates/cubecl-spirv/src/compiler.rs @@ -9,7 +9,7 @@ use std::{ }; use cubecl_core::{ - ir::{HybridAllocator, KernelDefinition, LocalAllocator}, + ir::{Allocator, KernelDefinition}, Compiler, ExecutionMode, }; use rspirv::{ @@ -159,8 +159,8 @@ impl Compiler for SpirvCompiler { elem.size() } - fn local_allocator() -> impl LocalAllocator { - HybridAllocator::default() + fn local_allocator() -> Allocator { + Allocator::new() } fn max_shared_memory_size() -> usize { diff --git a/crates/cubecl-spirv/src/variable.rs b/crates/cubecl-spirv/src/variable.rs index 6d382fbbc..3e12b224f 100644 --- a/crates/cubecl-spirv/src/variable.rs +++ b/crates/cubecl-spirv/src/variable.rs @@ -378,7 +378,7 @@ impl SpirvCompiler { Variable::GlobalOutputArray(id, self.compile_item(item), pos) } core::VariableKind::GlobalScalar(id) => self.global_scalar(id, item.elem), - core::VariableKind::Local { id, depth } => { + core::VariableKind::LocalMut { id, depth } => { let item = self.compile_item(item); let var = self.get_local((id, depth), &item); Variable::Local { id: var, item } @@ -388,7 +388,7 @@ impl SpirvCompiler { let id = (id, depth, version); Variable::Versioned { id, item } } - core::VariableKind::LocalBinding { id, depth } => { + core::VariableKind::LocalConst { id, depth } => { let item = self.compile_item(item); let id = (id, depth); Variable::LocalBinding { id, item } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/base.rs b/crates/cubecl-wgpu/src/compiler/wgsl/base.rs index fc91c7787..67e2a3817 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/base.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/base.rs @@ -7,14 +7,15 @@ pub enum Variable { GlobalOutputArray(u16, Item), GlobalScalar(u16, Elem, cube::Elem), ConstantScalar(ConstantScalarValue, Elem), - Local { + LocalMut { id: u16, item: Item, depth: u8, }, - LocalBinding { + LocalConst { id: u16, item: Item, + depth: u8, }, Named { name: String, @@ -98,8 +99,8 @@ impl Variable { Variable::SharedMemory(_, _, _) => false, Variable::ConstantArray(_, _, _) => false, Variable::LocalArray(_, _, _, _) => false, - Variable::Local { .. } => false, - Variable::LocalBinding { .. } => false, + Variable::LocalMut { .. } => false, + Variable::LocalConst { .. } => false, Variable::Named { .. } => false, Variable::Slice { .. } => false, Variable::WorkgroupIdX => true, @@ -132,7 +133,7 @@ impl Variable { Variable::GlobalInputArray(_, item) => item.elem().is_atomic(), Variable::GlobalOutputArray(_, item) => item.elem().is_atomic(), Variable::GlobalScalar(_, elem, _) => elem.is_atomic(), - Variable::Local { item, .. } => item.elem().is_atomic(), + Variable::LocalMut { item, .. } => item.elem().is_atomic(), Variable::Named { item, .. } => item.elem().is_atomic(), Variable::Slice { item, .. } => item.elem().is_atomic(), Variable::LocalScalar { elem, .. } => elem.is_atomic(), @@ -149,8 +150,8 @@ impl Variable { Self::SharedMemory(_, e, _) => *e, Self::ConstantArray(_, e, _) => *e, Self::LocalArray(_, e, _, _) => *e, - Self::Local { item, .. } => *item, - Self::LocalBinding { item, .. } => *item, + Self::LocalMut { item, .. } => *item, + Self::LocalConst { item, .. } => *item, Self::Slice { item, .. } => *item, Self::Named { item, .. } => *item, Self::ConstantScalar(_, e) => Item::Scalar(*e), @@ -279,12 +280,12 @@ impl Display for Variable { depth: scope_depth, .. } => write!(f, "s_{scope_depth}_{index}"), - Variable::Local { + Variable::LocalMut { id: index, depth: scope_depth, .. - } => write!(f, "l_{scope_depth}_{index}"), - Variable::LocalBinding { id: index, .. } => write!(f, "_{index}"), + } => write!(f, "l_mut_{scope_depth}_{index}"), + Variable::LocalConst { id, depth, .. } => write!(f, "l_{depth}_{id}"), Variable::Named { name, .. } => f.write_str(name), Variable::Slice { id: index, @@ -365,8 +366,8 @@ impl Display for IndexedVariable { impl Variable { pub fn fmt_left(&self) -> String { match self { - Variable::LocalBinding { id, .. } => { - format!("let _{id}") + Variable::LocalConst { .. } => { + format!("let {self}") } var => format!("{}", var), } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 58412e952..970bd19d1 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -10,7 +10,7 @@ use crate::{ use cubecl_core::ir::expand_checked_index_assign; use cubecl_core::{ - ir::{self as cube, HybridAllocator, UIntKind}, + ir::{self as cube, Allocator, UIntKind}, prelude::CompiledKernel, server::ComputeServer, Feature, Metadata, @@ -81,8 +81,8 @@ impl cubecl_core::Compiler for WgslCompiler { 32768 } - fn local_allocator() -> impl cube::LocalAllocator { - HybridAllocator::default() + fn local_allocator() -> Allocator { + Allocator::new() } } @@ -357,15 +357,16 @@ impl WgslCompiler { cube::VariableKind::GlobalScalar(id) => { wgsl::Variable::GlobalScalar(id, Self::compile_elem(item.elem), item.elem) } - cube::VariableKind::Local { id, depth } - | cube::VariableKind::Versioned { id, depth, .. } => wgsl::Variable::Local { + cube::VariableKind::LocalMut { id, depth } + | cube::VariableKind::Versioned { id, depth, .. } => wgsl::Variable::LocalMut { id, item: Self::compile_item(item), depth, }, - cube::VariableKind::LocalBinding { id, .. } => wgsl::Variable::LocalBinding { + cube::VariableKind::LocalConst { id, depth } => wgsl::Variable::LocalConst { id, item: Self::compile_item(item), + depth, }, cube::VariableKind::Slice { id, depth } => wgsl::Variable::Slice { id, @@ -998,9 +999,8 @@ impl WgslCompiler { cube::Operator::Slice(op) => { if matches!(self.strategy, ExecutionMode::Checked) && op.input.has_length() { let input = op.input; - let input_len = - scope.create_local(cube::Item::new(cube::Elem::UInt(cube::UIntKind::U32))); - + let input_len = scope + .create_local_mut(cube::Item::new(cube::Elem::UInt(cube::UIntKind::U32))); instructions.extend(self.compile_scope(scope)); let length = match input.has_buffer_length() { diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index ccae8c944..9f3727f61 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -1082,8 +1082,8 @@ fn index( len: Option<&Variable>, ) -> core::fmt::Result { let is_scalar = match lhs { - Variable::Local { item, .. } => item.vectorization_factor() == 1, - Variable::LocalBinding { item, .. } => item.vectorization_factor() == 1, + Variable::LocalMut { item, .. } => item.vectorization_factor() == 1, + Variable::LocalConst { item, .. } => item.vectorization_factor() == 1, _ => false, }; diff --git a/crates/cubecl/benches/matmul.rs b/crates/cubecl/benches/matmul.rs index e6eea3e7c..283f33570 100644 --- a/crates/cubecl/benches/matmul.rs +++ b/crates/cubecl/benches/matmul.rs @@ -85,14 +85,14 @@ fn run(device: R::Device, strategy: matmul::Strategy) { } fn main() { - // #[cfg(feature = "wgpu")] - // { - // run::( - // Default::default(), - // matmul::Strategy::Tiling2D(Default::default()), - // ); - // run::(Default::default(), matmul::Strategy::PlaneMma); - // } + #[cfg(feature = "wgpu")] + { + run::( + Default::default(), + matmul::Strategy::Tiling2D(Default::default()), + ); + run::(Default::default(), matmul::Strategy::PlaneMma); + } #[cfg(feature = "wgpu-spirv")] {