Skip to content

Commit

Permalink
Extract variable allocation strategy to trait to allow backends to us…
Browse files Browse the repository at this point in the history
…e more optimal strategies (#139)
  • Loading branch information
wingertge authored Sep 23, 2024
1 parent 5371cfc commit 0f77be9
Show file tree
Hide file tree
Showing 61 changed files with 1,093 additions and 896 deletions.
3 changes: 2 additions & 1 deletion crates/cubecl-core/src/codegen/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::ir::{Elem, KernelDefinition};
use crate::ir::{Elem, KernelDefinition, LocalAllocator};
use cubecl_runtime::ExecutionMode;
use std::fmt::Display;

Expand All @@ -17,6 +17,7 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
fn compile(kernel: KernelDefinition, mode: ExecutionMode) -> Self::Representation;
/// The size of the given element in bytes.
fn elem_size(elem: Elem) -> usize;
fn local_allocator() -> impl LocalAllocator;
/// The maximal size of a shared memory, in bytes
fn max_shared_memory_size() -> usize;
}
14 changes: 9 additions & 5 deletions crates/cubecl-core/src/compute/builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::ir::{Elem, Item, Visibility};
use crate::ir::{Elem, Item, LocalAllocator, ReusingAllocator, Visibility};
use crate::prelude::KernelDefinition;
use crate::KernelSettings;
use crate::{
Expand Down Expand Up @@ -104,12 +104,10 @@ impl KernelBuilder {
})
.integrate(settings)
}
}

impl Default for KernelBuilder {
fn default() -> Self {
pub fn with_local_allocator(allocator: impl LocalAllocator + 'static) -> Self {
Self {
context: CubeContext::root(),
context: CubeContext::root(allocator),
inputs: Vec::new(),
outputs: Vec::new(),
indices: HashMap::new(),
Expand All @@ -118,3 +116,9 @@ impl Default for KernelBuilder {
}
}
}

impl Default for KernelBuilder {
fn default() -> Self {
Self::with_local_allocator(ReusingAllocator::default())
}
}
3 changes: 2 additions & 1 deletion crates/cubecl-core/src/frontend/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,8 @@ pub fn if_else_expr_expand<C: CubePrimitive>(
None => {
let mut then_child = context.child();
let ret = then_block(&mut then_child);
let out: ExpandElementTyped<C> = context.create_local(ret.expand.item()).into();
let out: ExpandElementTyped<C> =
context.create_local_variable(ret.expand.item()).into();
assign::expand(&mut then_child, ret, out.clone());

IfElseExprExpand::Runtime {
Expand Down
90 changes: 22 additions & 68 deletions crates/cubecl-core/src/frontend/context.rs
Original file line number Diff line number Diff line change
@@ -1,68 +1,31 @@
use crate::frontend::ExpandElement;
use crate::ir::{self, Elem, Item, Operation, Scope};
use crate::ir::{self, Elem, Item, Operation, ReusingAllocator, Scope};
use crate::{frontend::ExpandElement, ir::LocalAllocator};
use alloc::rc::Rc;
use core::cell::RefCell;
use std::collections::HashMap;

#[derive(Default, Clone)]
pub struct VariablePool {
map: Rc<RefCell<HashMap<Item, Vec<ExpandElement>>>>,
}

impl VariablePool {
/// Returns an old, not used anymore variable, if there exists one.
pub fn reuse(&self, item: Item) -> Option<ExpandElement> {
let map = self.map.borrow();

// Filter for candidate variables of the same Item
let variables = map.get(&item)?;

// Among the candidates, take a variable if it's only referenced by the map
// Arbitrarily takes the first it finds in reverse order.
for variable in variables.iter().rev() {
match variable {
ExpandElement::Managed(var) => {
if Rc::strong_count(var) == 1 {
return Some(variable.clone());
}
}
ExpandElement::Plain(_) => (),
}
}

// If no candidate was found, a new var will be needed
None
}

/// Insert a new variable in the map, which is classified by Item
pub fn insert(&mut self, var: ExpandElement) {
let mut map = self.map.borrow_mut();
let item = var.item();

if let Some(variables) = map.get_mut(&item) {
variables.push(var.clone());
} else {
map.insert(var.item(), vec![var.clone()]);
}
}
}

pub struct CubeContext {
pub root: Rc<RefCell<Scope>>,
pub scope: Rc<RefCell<Scope>>,
pub pool: VariablePool,
pub local_allocator: Rc<dyn LocalAllocator>,
}

impl Default for CubeContext {
fn default() -> Self {
Self::root(ReusingAllocator::default())
}
}

impl CubeContext {
/// Create a new cube context, with a root scope
/// A root scope is at the root of a compute shader
/// Therefore there is one cube context per shader
pub fn root() -> CubeContext {
/// The allocator will define the strategy for creating local intermediates and mutable variables
pub fn root(allocator: impl LocalAllocator + 'static) -> CubeContext {
let root = Rc::new(RefCell::new(Scope::root()));
let scope = root.clone();

Self {
pool: Default::default(),
local_allocator: Rc::new(allocator),
scope,
root,
}
Expand All @@ -78,7 +41,7 @@ impl CubeContext {
Self {
scope: Rc::new(RefCell::new(scope)),
root: self.root.clone(),
pool: self.pool.clone(),
local_allocator: self.local_allocator.clone(),
}
}

Expand All @@ -90,25 +53,16 @@ impl CubeContext {
.into_inner()
}

/// When a new variable is required, we check if we can reuse an old one
/// Otherwise we create a new one.
pub fn create_local(&mut self, item: Item) -> ExpandElement {
if item.elem.is_atomic() {
let new = self.scope.borrow_mut().create_local_undeclared(item);
return ExpandElement::Plain(new);
}

// Reuse an old variable if possible
if let Some(var) = self.pool.reuse(item) {
return var;
}

// Create a new variable at the root scope
// Insert it in the variable pool for potential reuse
let new = ExpandElement::Managed(Rc::new(self.root.borrow_mut().create_local(item)));
self.pool.insert(new.clone());
/// Create a new mutable local variable
pub fn create_local_variable(&mut self, item: Item) -> ExpandElement {
self.local_allocator
.create_local_variable(self.root.clone(), self.scope.clone(), item)
}

new
/// Create a new immutable local binding
pub fn create_local_binding(&mut self, item: Item) -> ExpandElement {
self.local_allocator
.create_local_binding(self.root.clone(), self.scope.clone(), item)
}

/// Create a new matrix element.
Expand Down
13 changes: 7 additions & 6 deletions crates/cubecl-core/src/frontend/element/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ impl<C: CubePrimitive> ExpandElementTyped<Array<C>> {
.expect("Vectorization must be comptime")
.as_u32();
let var = self.expand.clone();
let new_var = context.create_local(Item::vectorized(
var.item().elem(),
NonZero::new(factor as u8),
));
let item = Item::vectorized(var.item().elem(), NonZero::new(factor as u8));

if factor == 1 {
let new_var = if factor == 1 {
let new_var = context.create_local_binding(item);
let element = index::expand(context, self.clone(), ExpandElementTyped::from_lit(0u32));
assign::expand(context, element, new_var.clone().into());
new_var
} else {
let new_var = context.create_local_variable(item);
for i in 0..factor {
let expand: Self = self.expand.clone().into();
let element = index::expand(context, expand, ExpandElementTyped::from_lit(i));
Expand All @@ -107,7 +107,8 @@ impl<C: CubePrimitive> ExpandElementTyped<Array<C>> {
element,
);
}
}
new_var
};
new_var.into()
}
}
Expand Down
20 changes: 10 additions & 10 deletions crates/cubecl-core/src/frontend/element/atomic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ where
pointer: <Self as CubeType>::ExpandType,
) -> <Self::Primitive as CubeType>::ExpandType {
let pointer: ExpandElement = pointer.into();
let new_var = context.create_local(Item::new(Self::Primitive::as_elem()));
let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem()));
context.register(Operator::AtomicLoad(UnaryOperator {
input: *pointer,
out: *new_var,
Expand Down Expand Up @@ -127,7 +127,7 @@ where
) -> <Self::Primitive as CubeType>::ExpandType {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
let new_var = context.create_local(Item::new(Self::Primitive::as_elem()));
let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem()));
context.register(Operator::AtomicSwap(BinaryOperator {
lhs: *ptr,
rhs: *value,
Expand All @@ -145,7 +145,7 @@ where
let pointer: ExpandElement = pointer.into();
let cmp: ExpandElement = cmp.into();
let value: ExpandElement = value.into();
let new_var = context.create_local(Item::new(Self::Primitive::as_elem()));
let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem()));
context.register(Operator::AtomicCompareAndSwap(CompareAndSwapOperator {
out: *new_var,
input: *pointer,
Expand All @@ -162,7 +162,7 @@ where
) -> <Self::Primitive as CubeType>::ExpandType {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
let new_var = context.create_local(Item::new(Self::Primitive::as_elem()));
let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem()));
context.register(Operator::AtomicAdd(BinaryOperator {
lhs: *ptr,
rhs: *value,
Expand All @@ -178,7 +178,7 @@ where
) -> <Self::Primitive as CubeType>::ExpandType {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
let new_var = context.create_local(Item::new(Self::Primitive::as_elem()));
let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem()));
context.register(Operator::AtomicSub(BinaryOperator {
lhs: *ptr,
rhs: *value,
Expand All @@ -194,7 +194,7 @@ where
) -> <Self::Primitive as CubeType>::ExpandType {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
let new_var = context.create_local(Item::new(Self::Primitive::as_elem()));
let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem()));
context.register(Operator::AtomicMax(BinaryOperator {
lhs: *ptr,
rhs: *value,
Expand All @@ -210,7 +210,7 @@ where
) -> <Self::Primitive as CubeType>::ExpandType {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
let new_var = context.create_local(Item::new(Self::Primitive::as_elem()));
let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem()));
context.register(Operator::AtomicMin(BinaryOperator {
lhs: *ptr,
rhs: *value,
Expand All @@ -226,7 +226,7 @@ where
) -> <Self::Primitive as CubeType>::ExpandType {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
let new_var = context.create_local(Item::new(Self::Primitive::as_elem()));
let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem()));
context.register(Operator::AtomicAnd(BinaryOperator {
lhs: *ptr,
rhs: *value,
Expand All @@ -242,7 +242,7 @@ where
) -> <Self::Primitive as CubeType>::ExpandType {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
let new_var = context.create_local(Item::new(Self::Primitive::as_elem()));
let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem()));
context.register(Operator::AtomicOr(BinaryOperator {
lhs: *ptr,
rhs: *value,
Expand All @@ -258,7 +258,7 @@ where
) -> <Self::Primitive as CubeType>::ExpandType {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
let new_var = context.create_local(Item::new(Self::Primitive::as_elem()));
let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem()));
context.register(Operator::AtomicXor(BinaryOperator {
lhs: *ptr,
rhs: *value,
Expand Down
29 changes: 12 additions & 17 deletions crates/cubecl-core/src/frontend/element/base.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{CubePrimitive, Numeric, Vectorized};
use crate::{
ir::{ConstantScalarValue, Elem, FloatKind, Item, Operator, Variable},
prelude::{index_assign, init_expand, CubeContext, CubeIndex, KernelBuilder, KernelLauncher},
prelude::{assign, init_expand, CubeContext, CubeIndex, KernelBuilder, KernelLauncher},
Runtime,
};
use alloc::rc::Rc;
Expand Down Expand Up @@ -329,6 +329,11 @@ impl ExpandElement {
ExpandElement::Plain(_) => false,
}
}

/// Explicitly consume the element, freeing it for reuse if no other copies exist.
pub fn consume(self) -> Variable {
*self
}
}

impl core::ops::Deref for ExpandElement {
Expand Down Expand Up @@ -369,6 +374,7 @@ pub(crate) fn init_expand_element<E: Into<ExpandElement>>(
Variable::LocalScalar { .. } => init(elem),
Variable::ConstantScalar { .. } => init(elem),
Variable::Local { .. } => init(elem),
Variable::LocalBinding { .. } => init(elem),
// Constant should be initialized since the new variable can be mutated afterward.
// And it is assumed those values are cloned.
Variable::Rank
Expand Down Expand Up @@ -445,25 +451,14 @@ pub(crate) fn __expand_vectorized<C: Numeric + CubeIndex<u32>, Out: Numeric>(
vectorization: u32,
elem: Elem,
) -> ExpandElementTyped<Out> {
let new_var = context.create_local(Item::vectorized(elem, NonZero::new(vectorization as u8)));
let new_var =
context.create_local_binding(Item::vectorized(elem, NonZero::new(vectorization as u8)));
let val = Out::from(val).unwrap();
let val: ExpandElementTyped<Out> = val.into();

// Allow setting explicit vectorization of 1 without trying to index assign it
if vectorization == 1 {
return val;
}

for (i, element) in vec![val; vectorization as usize].iter().enumerate() {
let element = elem.from_constant(*element.expand);

index_assign::expand::<C>(
context,
new_var.clone().into(),
ExpandElementTyped::from_lit(i),
ExpandElement::Plain(element).into(),
);
}
// Explanation for removing all this code: Assignments are already being unrolled and broadcast
// in the backend, so this was just duplicating code and it interfered with the SSA allocator
assign::expand(context, val, new_var.clone().into());

new_var.into()
}
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/frontend/element/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub trait Cast: CubePrimitive {
context: &mut CubeContext,
value: ExpandElementTyped<From>,
) -> <Self as CubeType>::ExpandType {
let new_var = context.create_local(Item::vectorized(
let new_var = context.create_local_binding(Item::vectorized(
<Self as CubePrimitive>::as_elem(),
value.expand.item().vectorization,
));
Expand Down Expand Up @@ -45,7 +45,7 @@ pub trait BitCast: CubePrimitive {
) -> <Self as CubeType>::ExpandType {
let value: ExpandElement = value.into();
let var: Variable = *value;
let new_var = context.create_local(Item::vectorized(
let new_var = context.create_local_binding(Item::vectorized(
<Self as CubePrimitive>::as_elem(),
var.item().vectorization,
));
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/frontend/element/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ macro_rules! impl_float {
vectorization: u32,
) -> <Self as CubeType>::ExpandType {
context
.create_local(Item::vectorized(
.create_local_variable(Item::vectorized(
Self::as_elem(),
NonZero::new(vectorization as u8),
))
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/frontend/element/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub trait Numeric:
context: &mut CubeContext,
vec: [u32; D],
) -> <Self as CubeType>::ExpandType {
let new_var = context.create_local(Item::vectorized(
let new_var = context.create_local_binding(Item::vectorized(
Self::as_elem(),
NonZero::new(vec.len() as u8),
));
Expand Down
Loading

0 comments on commit 0f77be9

Please sign in to comment.