Skip to content

Commit

Permalink
Merge branch 'main' into chore/optimizer-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Jan 7, 2025
2 parents 8decc52 + 169ac37 commit 900f07f
Show file tree
Hide file tree
Showing 217 changed files with 4,821 additions and 11,574 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,5 @@ embassy-futures = { version = "0.1.1" } # for no-std
futures-lite = { version = "2.3.0", default-features = false }

[profile.dev]
opt-level = 2
opt-level = 0
debug = 0 # Speed up compilation time and not necessary.
1 change: 1 addition & 0 deletions crates/cubecl-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ log = { workspace = true }
num-traits = { workspace = true }
paste = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

[dev-dependencies]
pretty_assertions = { workspace = true }
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/codegen/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::ir::{Elem, KernelDefinition, LocalAllocator};
use crate::ir::{Allocator, Elem, KernelDefinition};
use cubecl_runtime::ExecutionMode;
use std::fmt::Display;

Expand All @@ -22,7 +22,7 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
) -> Self::Representation;
/// The size of the given element in bytes.
fn elem_size(elem: Elem) -> usize;
fn local_allocator() -> impl LocalAllocator;
fn local_allocator() -> Allocator;
/// The maximal size of a shared memory, in bytes
fn max_shared_memory_size() -> usize;
}
13 changes: 6 additions & 7 deletions crates/cubecl-core/src/codegen/integrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ use std::num::NonZero;
use super::Compiler;
use crate::{
ir::{
Binding, CubeDim, Elem, Item, KernelDefinition, Location, ReadingStrategy, Scope, Variable,
VariableKind, Vectorization, Visibility,
Binding, CubeDim, Elem, Item, KernelDefinition, Location, ReadingStrategy, Scope, UIntKind,
Variable, VariableKind, Vectorization, Visibility,
},
prelude::CubePrimitive,
Runtime,
};

Expand Down Expand Up @@ -321,7 +320,7 @@ impl KernelIntegrator {
named.push((
"info".to_string(),
Binding {
item: Item::new(u32::as_elem()),
item: Item::new(Elem::UInt(UIntKind::U32)),
visibility: Visibility::Read,
location: Location::Storage,
has_extended_meta: false,
Expand Down Expand Up @@ -413,7 +412,7 @@ impl KernelIntegrator {
});
self.expansion.scope.write_global(
Variable::new(
VariableKind::Local {
VariableKind::LocalMut {
id: local,

depth: self.expansion.scope.depth,
Expand All @@ -433,7 +432,7 @@ impl KernelIntegrator {
} => {
self.expansion.scope.write_global(
Variable::new(
VariableKind::Local {
VariableKind::LocalMut {
id: local,
depth: self.expansion.scope.depth,
},
Expand Down Expand Up @@ -531,7 +530,7 @@ fn bool_item(ty: Item) -> Item {
pub fn bool_elem(elem: Elem) -> Elem {
match elem {
// U32 are used for bool tensors
Elem::Bool => u32::as_elem(),
Elem::Bool => Elem::UInt(UIntKind::U32),
_ => elem,
}
}
6 changes: 3 additions & 3 deletions crates/cubecl-core/src/compute/builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::ir::{Elem, Item, LocalAllocator, ReusingAllocator, Visibility};
use crate::ir::{Allocator, Elem, Item, Visibility};
use crate::prelude::KernelDefinition;
use crate::KernelSettings;
use crate::{
Expand Down Expand Up @@ -117,7 +117,7 @@ impl KernelBuilder {
.integrate(settings)
}

pub fn with_local_allocator(allocator: impl LocalAllocator + 'static) -> Self {
pub fn with_local_allocator(allocator: Allocator) -> Self {
Self {
context: CubeContext::root(allocator),
inputs: Vec::new(),
Expand All @@ -131,6 +131,6 @@ impl KernelBuilder {

impl Default for KernelBuilder {
fn default() -> Self {
Self::with_local_allocator(ReusingAllocator::default())
Self::with_local_allocator(Allocator::new())
}
}
12 changes: 6 additions & 6 deletions crates/cubecl-core/src/frontend/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ impl<I: Int> Iterable<I> for RangeExpand<I> {
mut body: impl FnMut(&mut CubeContext, <I as CubeType>::ExpandType),
) {
let mut child = context.child();
let index_ty = Item::new(I::as_elem());
let i = child.create_local_undeclared(index_ty);
let index_ty = Item::new(I::as_elem(context));
let i = child.create_local_restricted(index_ty);

body(&mut child, i.clone().into());

Expand Down Expand Up @@ -130,8 +130,8 @@ impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
mut body: impl FnMut(&mut CubeContext, <I as CubeType>::ExpandType),
) {
let mut child = context.child();
let index_ty = Item::new(I::as_elem());
let i = child.create_local_undeclared(index_ty);
let index_ty = Item::new(I::as_elem(context));
let i = child.create_local_restricted(index_ty);

body(&mut child, i.clone().into());

Expand Down Expand Up @@ -396,7 +396,7 @@ pub fn if_else_expr_expand<C: CubePrimitive>(
None => {
let mut then_child = context.child();
let ret = then_block(&mut then_child);
let out: ExpandElementTyped<C> = context.create_local_variable(ret.expand.item).into();
let out: ExpandElementTyped<C> = context.create_local_mut(ret.expand.item).into();
assign::expand(&mut then_child, ret, out.clone());

IfElseExprExpand::Runtime {
Expand Down Expand Up @@ -501,7 +501,7 @@ pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
) -> SwitchExpandExpr<I, C> {
let mut default_child = context.child();
let default = default_block(&mut default_child);
let out: ExpandElementTyped<C> = context.create_local_variable(default.expand.item).into();
let out: ExpandElementTyped<C> = context.create_local_mut(default.expand.item).into();
assign::expand(&mut default_child, default, out.clone());

SwitchExpandExpr {
Expand Down
6 changes: 4 additions & 2 deletions crates/cubecl-core/src/frontend/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,13 @@ impl<C: CubePrimitive> Matrix<C> {
k: ExpandElementTyped<u32>,
layout: MatrixLayout,
) -> MatrixExpand<C> {
let elem = C::as_elem(context);
let elem = context.create_matrix(ir::Matrix {
ident,
m: m.constant().unwrap().as_u32() as u8,
n: n.constant().unwrap().as_u32() as u8,
k: k.constant().unwrap().as_u32() as u8,
elem: C::as_elem(),
elem,
layout,
});
MatrixExpand {
Expand Down Expand Up @@ -436,12 +437,13 @@ pub mod cast {
_ => unreachable!(),
};

let elem = O::as_elem(context);
let elem = context.create_matrix(ir::Matrix {
ident,
m: input_mat.m,
n: input_mat.n,
k: input_mat.k,
elem: O::as_elem(),
elem,
layout: MatrixLayout::Undefined,
});

Expand Down
34 changes: 20 additions & 14 deletions crates/cubecl-core/src/frontend/container/array/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,16 @@ mod new {
.constant()
.expect("Array need constant initialization value")
.as_u32();
context
.create_local_array(Item::new(T::as_elem()), size)
.into()
let elem = T::as_elem(context);
context.create_local_array(Item::new(elem), size).into()
}

/// Expand function of [from_data](Array::from_data).
pub fn __expand_from_data<C: CubePrimitive>(
context: &mut CubeContext,
data: ArrayData<C>,
) -> <Self as CubeType>::ExpandType {
let var = context.create_const_array(Item::new(T::as_elem()), data.values);
let var = context.create_const_array(Item::new(T::as_elem(context)), data.values);
ExpandElementTyped::new(var)
}
}
Expand Down Expand Up @@ -157,7 +156,10 @@ mod vectorization {
};
context
.create_local_array(
Item::vectorized(T::as_elem(), NonZero::new(vectorization_factor as u8)),
Item::vectorized(
T::as_elem(context),
NonZero::new(vectorization_factor as u8),
),
size,
)
.into()
Expand All @@ -178,20 +180,24 @@ mod vectorization {
let item = Item::vectorized(var.item.elem(), NonZero::new(factor as u8));

let new_var = if factor == 1 {
let new_var = context.create_local_binding(item);
let element =
index::expand(context, self.clone(), ExpandElementTyped::from_lit(0u32));
let new_var = context.create_local(item);
let element = index::expand(
context,
self.clone(),
ExpandElementTyped::from_lit(context, 0u32),
);
assign::expand(context, element, new_var.clone().into());
new_var
} else {
let new_var = context.create_local_variable(item);
let new_var = context.create_local_mut(item);
for i in 0..factor {
let expand: Self = self.expand.clone().into();
let element = index::expand(context, expand, ExpandElementTyped::from_lit(i));
let element =
index::expand(context, expand, ExpandElementTyped::from_lit(context, i));
index_assign::expand::<Array<C>>(
context,
new_var.clone().into(),
ExpandElementTyped::from_lit(i),
ExpandElementTyped::from_lit(context, i),
element,
);
}
Expand Down Expand Up @@ -224,7 +230,7 @@ mod metadata {
impl<T: CubeType> ExpandElementTyped<Array<T>> {
// Expand method of [len](Array::len).
pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
let out = context.create_local_binding(Item::new(u32::as_elem()));
let out = context.create_local(Item::new(u32::as_elem(context)));
context.register(Instruction::new(
Metadata::Length {
var: self.expand.into(),
Expand All @@ -239,7 +245,7 @@ mod metadata {
self,
context: &mut CubeContext,
) -> ExpandElementTyped<u32> {
let out = context.create_local_binding(Item::new(u32::as_elem()));
let out = context.create_local(Item::new(u32::as_elem(context)));
context.register(Instruction::new(
Metadata::BufferLength {
var: self.expand.into(),
Expand Down Expand Up @@ -292,7 +298,7 @@ mod indexation {
context: &mut CubeContext,
i: ExpandElementTyped<u32>,
) -> ExpandElementTyped<E> {
let out = context.create_local_binding(self.expand.item);
let out = context.create_local(self.expand.item);
context.register(Instruction::new(
Operator::UncheckedIndex(BinaryOperator {
lhs: *self.expand,
Expand Down
25 changes: 20 additions & 5 deletions crates/cubecl-core/src/frontend/container/array/launch.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
use std::{marker::PhantomData, num::NonZero};

use serde::{Deserialize, Serialize};

use crate::{
compute::{KernelBuilder, KernelLauncher},
ir::{Item, Vectorization},
prelude::{
ArgSettings, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand, TensorHandleRef,
ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand,
TensorHandleRef,
},
Runtime,
};

use super::Array;

#[derive(Clone, PartialEq, Eq, Hash, Debug)]
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
pub struct ArrayCompilationArg {
pub inplace: Option<u16>,
pub vectorisation: Vectorization,
}

impl CompilationArg for ArrayCompilationArg {}

/// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle).
pub struct ArrayHandleRef<'a, R: Runtime> {
pub handle: &'a cubecl_runtime::server::Handle,
Expand All @@ -33,7 +38,10 @@ impl<C: CubePrimitive> LaunchArgExpand for Array<C> {
builder: &mut KernelBuilder,
) -> ExpandElementTyped<Array<C>> {
builder
.input_array(Item::vectorized(C::as_elem(), arg.vectorisation))
.input_array(Item::vectorized(
C::as_elem(&builder.context),
arg.vectorisation,
))
.into()
}
fn expand_output(
Expand All @@ -43,7 +51,10 @@ impl<C: CubePrimitive> LaunchArgExpand for Array<C> {
match arg.inplace {
Some(id) => builder.inplace_output(id).into(),
None => builder
.output_array(Item::vectorized(C::as_elem(), arg.vectorisation))
.output_array(Item::vectorized(
C::as_elem(&builder.context),
arg.vectorisation,
))
.into(),
}
}
Expand Down Expand Up @@ -82,7 +93,11 @@ impl<'a, R: Runtime> ArrayArg<'a, R> {
vectorization_factor: u8,
) -> Self {
ArrayArg::Handle {
handle: ArrayHandleRef::from_raw_parts(handle, length, E::as_elem().size()),
handle: ArrayHandleRef::from_raw_parts(
handle,
length,
E::size().expect("Element should have a size"),
),
vectorization_factor,
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/frontend/container/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ impl<T: SizedContainer> Iterable<T::Item> for ExpandElementTyped<T> {
context: &mut CubeContext,
mut body: impl FnMut(&mut CubeContext, <T::Item as CubeType>::ExpandType),
) {
let index_ty = Item::new(u32::as_elem());
let index_ty = Item::new(u32::as_elem(context));
let len: ExpandElement = T::len(&self.expand, context);

let mut child = context.child();
let i = child.create_local_undeclared(index_ty);
let i = child.create_local_restricted(index_ty);

let item = index::expand(&mut child, self, i.clone().into());
body(&mut child, item);
Expand Down
Loading

0 comments on commit 900f07f

Please sign in to comment.