Skip to content

Commit

Permalink
Simplify allocator (#388)
Browse files Browse the repository at this point in the history
* simplify the allocator implementation

* fix line assign operator

* uniformize function names

* fix reference kernel tests

* add documentation

* fix missing fmt_left when optimizing unary op

* use wgpu matmul bench

* run cargo fmt

* remove frontend test

* fix cpp plane operation so they can use an const local

* fix optimized plane operations for const

* run cargo fmt

* remove useless mut local var

* add method for const qualifier
  • Loading branch information
maxtremblay authored Jan 6, 2025
1 parent 4d6f50f commit 169ac37
Show file tree
Hide file tree
Showing 49 changed files with 489 additions and 478 deletions.
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/codegen/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::ir::{Elem, KernelDefinition, LocalAllocator};
use crate::ir::{Allocator, Elem, KernelDefinition};
use cubecl_runtime::ExecutionMode;
use std::fmt::Display;

Expand All @@ -22,7 +22,7 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
) -> Self::Representation;
/// The size of the given element in bytes.
fn elem_size(elem: Elem) -> usize;
fn local_allocator() -> impl LocalAllocator;
fn local_allocator() -> Allocator;
/// The maximal size of a shared memory, in bytes
fn max_shared_memory_size() -> usize;
}
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/codegen/integrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ impl KernelIntegrator {
});
self.expansion.scope.write_global(
Variable::new(
VariableKind::Local {
VariableKind::LocalMut {
id: local,

depth: self.expansion.scope.depth,
Expand All @@ -432,7 +432,7 @@ impl KernelIntegrator {
} => {
self.expansion.scope.write_global(
Variable::new(
VariableKind::Local {
VariableKind::LocalMut {
id: local,
depth: self.expansion.scope.depth,
},
Expand Down
6 changes: 3 additions & 3 deletions crates/cubecl-core/src/compute/builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::ir::{Elem, Item, LocalAllocator, ReusingAllocator, Visibility};
use crate::ir::{Allocator, Elem, Item, Visibility};
use crate::prelude::KernelDefinition;
use crate::KernelSettings;
use crate::{
Expand Down Expand Up @@ -117,7 +117,7 @@ impl KernelBuilder {
.integrate(settings)
}

pub fn with_local_allocator(allocator: impl LocalAllocator + 'static) -> Self {
pub fn with_local_allocator(allocator: Allocator) -> Self {
Self {
context: CubeContext::root(allocator),
inputs: Vec::new(),
Expand All @@ -131,6 +131,6 @@ impl KernelBuilder {

impl Default for KernelBuilder {
fn default() -> Self {
Self::with_local_allocator(ReusingAllocator::default())
Self::with_local_allocator(Allocator::new())
}
}
8 changes: 4 additions & 4 deletions crates/cubecl-core/src/frontend/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl<I: Int> Iterable<I> for RangeExpand<I> {
) {
let mut child = context.child();
let index_ty = Item::new(I::as_elem(context));
let i = child.create_local_undeclared(index_ty);
let i = child.create_local_restricted(index_ty);

body(&mut child, i.clone().into());

Expand Down Expand Up @@ -131,7 +131,7 @@ impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
) {
let mut child = context.child();
let index_ty = Item::new(I::as_elem(context));
let i = child.create_local_undeclared(index_ty);
let i = child.create_local_restricted(index_ty);

body(&mut child, i.clone().into());

Expand Down Expand Up @@ -396,7 +396,7 @@ pub fn if_else_expr_expand<C: CubePrimitive>(
None => {
let mut then_child = context.child();
let ret = then_block(&mut then_child);
let out: ExpandElementTyped<C> = context.create_local_variable(ret.expand.item).into();
let out: ExpandElementTyped<C> = context.create_local_mut(ret.expand.item).into();
assign::expand(&mut then_child, ret, out.clone());

IfElseExprExpand::Runtime {
Expand Down Expand Up @@ -501,7 +501,7 @@ pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
) -> SwitchExpandExpr<I, C> {
let mut default_child = context.child();
let default = default_block(&mut default_child);
let out: ExpandElementTyped<C> = context.create_local_variable(default.expand.item).into();
let out: ExpandElementTyped<C> = context.create_local_mut(default.expand.item).into();
assign::expand(&mut default_child, default, out.clone());

SwitchExpandExpr {
Expand Down
10 changes: 5 additions & 5 deletions crates/cubecl-core/src/frontend/container/array/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ mod vectorization {
let item = Item::vectorized(var.item.elem(), NonZero::new(factor as u8));

let new_var = if factor == 1 {
let new_var = context.create_local_binding(item);
let new_var = context.create_local(item);
let element = index::expand(
context,
self.clone(),
Expand All @@ -189,7 +189,7 @@ mod vectorization {
assign::expand(context, element, new_var.clone().into());
new_var
} else {
let new_var = context.create_local_variable(item);
let new_var = context.create_local_mut(item);
for i in 0..factor {
let expand: Self = self.expand.clone().into();
let element =
Expand Down Expand Up @@ -230,7 +230,7 @@ mod metadata {
impl<T: CubeType> ExpandElementTyped<Array<T>> {
// Expand method of [len](Array::len).
pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
let out = context.create_local_binding(Item::new(u32::as_elem(context)));
let out = context.create_local(Item::new(u32::as_elem(context)));
context.register(Instruction::new(
Metadata::Length {
var: self.expand.into(),
Expand All @@ -245,7 +245,7 @@ mod metadata {
self,
context: &mut CubeContext,
) -> ExpandElementTyped<u32> {
let out = context.create_local_binding(Item::new(u32::as_elem(context)));
let out = context.create_local(Item::new(u32::as_elem(context)));
context.register(Instruction::new(
Metadata::BufferLength {
var: self.expand.into(),
Expand Down Expand Up @@ -298,7 +298,7 @@ mod indexation {
context: &mut CubeContext,
i: ExpandElementTyped<u32>,
) -> ExpandElementTyped<E> {
let out = context.create_local_binding(self.expand.item);
let out = context.create_local(self.expand.item);
context.register(Instruction::new(
Operator::UncheckedIndex(BinaryOperator {
lhs: *self.expand,
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/frontend/container/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl<T: SizedContainer> Iterable<T::Item> for ExpandElementTyped<T> {
let len: ExpandElement = T::len(&self.expand, context);

let mut child = context.child();
let i = child.create_local_undeclared(index_ty);
let i = child.create_local_restricted(index_ty);

let item = index::expand(&mut child, self, i.clone().into());
body(&mut child, item);
Expand Down
7 changes: 3 additions & 4 deletions crates/cubecl-core/src/frontend/container/line/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ mod fill {
value: ExpandElementTyped<P>,
) -> Self {
let length = self.expand.item.vectorization;
let output =
context.create_local_binding(Item::vectorized(P::as_elem(context), length));
let output = context.create_local(Item::vectorized(P::as_elem(context), length));

cast::expand::<P>(context, value, output.clone().into());

Expand Down Expand Up @@ -128,7 +127,7 @@ mod empty {
None => None,
};
context
.create_local_variable(Item::vectorized(Self::as_elem(context), length))
.create_local_mut(Item::vectorized(Self::as_elem(context), length))
.into()
}
}
Expand Down Expand Up @@ -216,7 +215,7 @@ macro_rules! impl_line_comparison {
let lhs = self.expand.into();
let rhs = rhs.expand.into();

let output = context.create_local_binding(Item::vectorized(bool::as_elem(context), size));
let output = context.create_local_mut(Item::vectorized(bool::as_elem(context), size));

context.register(Instruction::new(
Operator::$operator(BinaryOperator { lhs, rhs }),
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/frontend/container/shared_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ mod indexation {
context: &mut CubeContext,
i: ExpandElementTyped<u32>,
) -> ExpandElementTyped<E> {
let out = context.create_local_binding(self.expand.item);
let out = context.create_local(self.expand.item);
context.register(Instruction::new(
Operator::UncheckedIndex(BinaryOperator {
lhs: *self.expand,
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/frontend/container/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ mod indexation {
context: &mut CubeContext,
i: ExpandElementTyped<u32>,
) -> ExpandElementTyped<E> {
let out = context.create_local_binding(self.expand.item);
let out = context.create_local(self.expand.item);
context.register(Instruction::new(
Operator::UncheckedIndex(BinaryOperator {
lhs: *self.expand,
Expand All @@ -195,7 +195,7 @@ mod indexation {
context: &mut CubeContext,
i: ExpandElementTyped<u32>,
) -> ExpandElementTyped<E> {
let out = context.create_local_binding(self.expand.item);
let out = context.create_local(self.expand.item);
context.register(Instruction::new(
Operator::UncheckedIndex(BinaryOperator {
lhs: *self.expand,
Expand Down
12 changes: 6 additions & 6 deletions crates/cubecl-core/src/frontend/container/tensor/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ mod metadata {
dim: ExpandElementTyped<u32>,
) -> ExpandElementTyped<u32> {
let dim: ExpandElement = dim.into();
let out = context.create_local_binding(Item::new(u32::as_elem(context)));
let out = context.create_local(Item::new(u32::as_elem(context)));
context.register(Instruction::new(
Metadata::Stride {
dim: *dim,
Expand All @@ -148,7 +148,7 @@ mod metadata {
dim: ExpandElementTyped<u32>,
) -> ExpandElementTyped<u32> {
let dim: ExpandElement = dim.into();
let out = context.create_local_binding(Item::new(u32::as_elem(context)));
let out = context.create_local(Item::new(u32::as_elem(context)));
context.register(Instruction::new(
Metadata::Shape {
dim: *dim,
Expand All @@ -171,7 +171,7 @@ mod metadata {
let shape = self.clone().__expand_shape_method(context, dim.clone());

// Compute `num_strides = index / stride`.
let num_strides = context.create_local_binding(Item::new(u32::as_elem(context)));
let num_strides = context.create_local(Item::new(u32::as_elem(context)));
context.register(Instruction::new(
Operator::Div(BinaryOperator {
lhs: *index,
Expand All @@ -181,7 +181,7 @@ mod metadata {
));

// Compute `coordinate = num_strides % shape `.
let coordinate = context.create_local_binding(Item::new(u32::as_elem(context)));
let coordinate = context.create_local(Item::new(u32::as_elem(context)));
context.register(Instruction::new(
Operator::Modulo(BinaryOperator {
lhs: *num_strides,
Expand Down Expand Up @@ -210,7 +210,7 @@ mod metadata {

// Expand method of [rank](Tensor::rank).
pub fn __expand_rank_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
let out = context.create_local_binding(Item::new(u32::as_elem(context)));
let out = context.create_local(Item::new(u32::as_elem(context)));
context.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
out.into()
}
Expand Down Expand Up @@ -258,7 +258,7 @@ mod indexation {
context: &mut CubeContext,
i: ExpandElementTyped<u32>,
) -> ExpandElementTyped<E> {
let out = context.create_local_binding(self.expand.item);
let out = context.create_local(self.expand.item);
context.register(Instruction::new(
Operator::UncheckedIndex(BinaryOperator {
lhs: *self.expand,
Expand Down
36 changes: 18 additions & 18 deletions crates/cubecl-core/src/frontend/context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::ir::{self, Elem, Instruction, Item, ReusingAllocator, Scope, Variable, VariableKind};
use crate::{frontend::ExpandElement, ir::LocalAllocator};
use crate::frontend::ExpandElement;
use crate::ir::{self, Allocator, Elem, Instruction, Item, Scope, Variable, VariableKind};
use alloc::rc::Rc;
use core::cell::RefCell;
use cubecl_runtime::debug::DebugLogger;
Expand All @@ -9,14 +9,14 @@ use std::collections::HashMap;
pub struct CubeContext {
pub root: Rc<RefCell<Scope>>,
pub scope: Rc<RefCell<Scope>>,
pub local_allocator: Rc<dyn LocalAllocator>,
pub allocator: Allocator,
pub debug_enabled: bool,
pub typemap: Rc<RefCell<HashMap<TypeId, Elem>>>,
}

impl Default for CubeContext {
fn default() -> Self {
Self::root(ReusingAllocator::default())
Self::root(Allocator::new())
}
}

Expand All @@ -25,13 +25,13 @@ impl CubeContext {
/// A root scope is at the root of a compute shader
/// Therefore there is one cube context per shader
/// The allocator will define the strategy for creating local intermediates and mutable variables
pub fn root(allocator: impl LocalAllocator + 'static) -> CubeContext {
pub fn root(allocator: Allocator) -> CubeContext {
let root = Rc::new(RefCell::new(Scope::root()));
let typemap = Rc::new(RefCell::new(HashMap::new()));
let scope = root.clone();

Self {
local_allocator: Rc::new(allocator),
allocator,
scope,
root,
debug_enabled: DebugLogger::default().is_activated(),
Expand Down Expand Up @@ -64,7 +64,7 @@ impl CubeContext {
Self {
scope: Rc::new(RefCell::new(scope)),
root: self.root.clone(),
local_allocator: self.local_allocator.clone(),
allocator: self.allocator.clone(),
debug_enabled: self.debug_enabled,
typemap: self.typemap.clone(),
}
Expand All @@ -78,23 +78,23 @@ impl CubeContext {
.into_inner()
}

/// Create a new mutable local variable
pub fn create_local_variable(&mut self, item: Item) -> ExpandElement {
self.local_allocator
.create_local_variable(self.root.clone(), self.scope.clone(), item)
/// Create a new mutable local variable.
pub fn create_local_mut(&mut self, item: Item) -> ExpandElement {
self.allocator
.create_local_mut(&mut self.root.borrow_mut(), item)
}

/// Create a new immutable local binding
pub fn create_local_binding(&mut self, item: Item) -> ExpandElement {
self.local_allocator
.create_local_binding(self.root.clone(), self.scope.clone(), item)
/// Create a new immutable local variable.
pub fn create_local(&mut self, item: Item) -> ExpandElement {
self.allocator
.create_local(&mut self.scope.borrow_mut(), item)
}

/// Create a new immutable local binding that must never be a reused variable, regardless of
/// allocator
pub fn create_local_undeclared(&mut self, item: Item) -> ExpandElement {
self.local_allocator
.create_local_undeclared(self.root.clone(), self.scope.clone(), item)
pub fn create_local_restricted(&mut self, item: Item) -> ExpandElement {
self.allocator
.create_local_restricted(&mut self.scope.borrow_mut(), item)
}

/// Create a new matrix element.
Expand Down
Loading

0 comments on commit 169ac37

Please sign in to comment.