diff --git a/Cargo.toml b/Cargo.toml index 46107fcb7..ce5a6ae53 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,5 +75,5 @@ embassy-futures = { version = "0.1.1" } # for no-std futures-lite = { version = "2.3.0", default-features = false } [profile.dev] -opt-level = 2 +opt-level = 0 debug = 0 # Speed up compilation time and not necessary. diff --git a/crates/cubecl-core/Cargo.toml b/crates/cubecl-core/Cargo.toml index 81fa0445f..8069becfb 100644 --- a/crates/cubecl-core/Cargo.toml +++ b/crates/cubecl-core/Cargo.toml @@ -39,6 +39,7 @@ log = { workspace = true } num-traits = { workspace = true } paste = { workspace = true } serde = { workspace = true } +serde_json = { workspace = true } [dev-dependencies] pretty_assertions = { workspace = true } diff --git a/crates/cubecl-core/src/codegen/compiler.rs b/crates/cubecl-core/src/codegen/compiler.rs index 756c8ac69..be26dad60 100644 --- a/crates/cubecl-core/src/codegen/compiler.rs +++ b/crates/cubecl-core/src/codegen/compiler.rs @@ -1,4 +1,4 @@ -use crate::ir::{Elem, KernelDefinition, LocalAllocator}; +use crate::ir::{Allocator, Elem, KernelDefinition}; use cubecl_runtime::ExecutionMode; use std::fmt::Display; @@ -22,7 +22,7 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug { ) -> Self::Representation; /// The size of the given element in bytes. fn elem_size(elem: Elem) -> usize; - fn local_allocator() -> impl LocalAllocator; + fn local_allocator() -> Allocator; /// The maximal size of a shared memory, in bytes fn max_shared_memory_size() -> usize; } diff --git a/crates/cubecl-core/src/codegen/integrator.rs b/crates/cubecl-core/src/codegen/integrator.rs index e3cbb999f..866e1d3d9 100644 --- a/crates/cubecl-core/src/codegen/integrator.rs +++ b/crates/cubecl-core/src/codegen/integrator.rs @@ -3,10 +3,9 @@ use std::num::NonZero; use super::Compiler; use crate::{ ir::{ - Binding, CubeDim, Elem, Item, KernelDefinition, Location, ReadingStrategy, Scope, Variable, - VariableKind, Vectorization, Visibility, + Binding, CubeDim, Elem, Item, KernelDefinition, Location, ReadingStrategy, Scope, UIntKind, + Variable, VariableKind, Vectorization, Visibility, }, - prelude::CubePrimitive, Runtime, }; @@ -321,7 +320,7 @@ impl KernelIntegrator { named.push(( "info".to_string(), Binding { - item: Item::new(u32::as_elem()), + item: Item::new(Elem::UInt(UIntKind::U32)), visibility: Visibility::Read, location: Location::Storage, has_extended_meta: false, @@ -413,7 +412,7 @@ impl KernelIntegrator { }); self.expansion.scope.write_global( Variable::new( - VariableKind::Local { + VariableKind::LocalMut { id: local, depth: self.expansion.scope.depth, @@ -433,7 +432,7 @@ impl KernelIntegrator { } => { self.expansion.scope.write_global( Variable::new( - VariableKind::Local { + VariableKind::LocalMut { id: local, depth: self.expansion.scope.depth, }, @@ -531,7 +530,7 @@ fn bool_item(ty: Item) -> Item { pub fn bool_elem(elem: Elem) -> Elem { match elem { // U32 are used for bool tensors - Elem::Bool => u32::as_elem(), + Elem::Bool => Elem::UInt(UIntKind::U32), _ => elem, } } diff --git a/crates/cubecl-core/src/compute/builder.rs b/crates/cubecl-core/src/compute/builder.rs index 6cb0ed7f2..156c4af8d 100644 --- a/crates/cubecl-core/src/compute/builder.rs +++ b/crates/cubecl-core/src/compute/builder.rs @@ -1,4 +1,4 @@ -use crate::ir::{Elem, Item, LocalAllocator, ReusingAllocator, Visibility}; +use crate::ir::{Allocator, Elem, Item, Visibility}; use crate::prelude::KernelDefinition; use crate::KernelSettings; use crate::{ @@ -117,7 +117,7 @@ impl 
KernelBuilder { .integrate(settings) } - pub fn with_local_allocator(allocator: impl LocalAllocator + 'static) -> Self { + pub fn with_local_allocator(allocator: Allocator) -> Self { Self { context: CubeContext::root(allocator), inputs: Vec::new(), @@ -131,6 +131,6 @@ impl KernelBuilder { impl Default for KernelBuilder { fn default() -> Self { - Self::with_local_allocator(ReusingAllocator::default()) + Self::with_local_allocator(Allocator::new()) } } diff --git a/crates/cubecl-core/src/frontend/branch.rs b/crates/cubecl-core/src/frontend/branch.rs index c317fb36b..8c1095766 100644 --- a/crates/cubecl-core/src/frontend/branch.rs +++ b/crates/cubecl-core/src/frontend/branch.rs @@ -100,8 +100,8 @@ impl Iterable for RangeExpand { mut body: impl FnMut(&mut CubeContext, ::ExpandType), ) { let mut child = context.child(); - let index_ty = Item::new(I::as_elem()); - let i = child.create_local_undeclared(index_ty); + let index_ty = Item::new(I::as_elem(context)); + let i = child.create_local_restricted(index_ty); body(&mut child, i.clone().into()); @@ -130,8 +130,8 @@ impl> Iterable for SteppedRangeExpand { mut body: impl FnMut(&mut CubeContext, ::ExpandType), ) { let mut child = context.child(); - let index_ty = Item::new(I::as_elem()); - let i = child.create_local_undeclared(index_ty); + let index_ty = Item::new(I::as_elem(context)); + let i = child.create_local_restricted(index_ty); body(&mut child, i.clone().into()); @@ -396,7 +396,7 @@ pub fn if_else_expr_expand( None => { let mut then_child = context.child(); let ret = then_block(&mut then_child); - let out: ExpandElementTyped = context.create_local_variable(ret.expand.item).into(); + let out: ExpandElementTyped = context.create_local_mut(ret.expand.item).into(); assign::expand(&mut then_child, ret, out.clone()); IfElseExprExpand::Runtime { @@ -501,7 +501,7 @@ pub fn switch_expand_expr( ) -> SwitchExpandExpr { let mut default_child = context.child(); let default = default_block(&mut default_child); - let out: ExpandElementTyped = context.create_local_variable(default.expand.item).into(); + let out: ExpandElementTyped = context.create_local_mut(default.expand.item).into(); assign::expand(&mut default_child, default, out.clone()); SwitchExpandExpr { diff --git a/crates/cubecl-core/src/frontend/cmma.rs b/crates/cubecl-core/src/frontend/cmma.rs index a8e44d889..7d95a475d 100644 --- a/crates/cubecl-core/src/frontend/cmma.rs +++ b/crates/cubecl-core/src/frontend/cmma.rs @@ -190,12 +190,13 @@ impl Matrix { k: ExpandElementTyped, layout: MatrixLayout, ) -> MatrixExpand { + let elem = C::as_elem(context); let elem = context.create_matrix(ir::Matrix { ident, m: m.constant().unwrap().as_u32() as u8, n: n.constant().unwrap().as_u32() as u8, k: k.constant().unwrap().as_u32() as u8, - elem: C::as_elem(), + elem, layout, }); MatrixExpand { @@ -436,12 +437,13 @@ pub mod cast { _ => unreachable!(), }; + let elem = O::as_elem(context); let elem = context.create_matrix(ir::Matrix { ident, m: input_mat.m, n: input_mat.n, k: input_mat.k, - elem: O::as_elem(), + elem, layout: MatrixLayout::Undefined, }); diff --git a/crates/cubecl-core/src/frontend/container/array/base.rs b/crates/cubecl-core/src/frontend/container/array/base.rs index ba19f1922..7a075723b 100644 --- a/crates/cubecl-core/src/frontend/container/array/base.rs +++ b/crates/cubecl-core/src/frontend/container/array/base.rs @@ -45,9 +45,8 @@ mod new { .constant() .expect("Array need constant initialization value") .as_u32(); - context - .create_local_array(Item::new(T::as_elem()), size) - .into() + let 
elem = T::as_elem(context); + context.create_local_array(Item::new(elem), size).into() } /// Expand function of [from_data](Array::from_data). @@ -55,7 +54,7 @@ mod new { context: &mut CubeContext, data: ArrayData, ) -> ::ExpandType { - let var = context.create_const_array(Item::new(T::as_elem()), data.values); + let var = context.create_const_array(Item::new(T::as_elem(context)), data.values); ExpandElementTyped::new(var) } } @@ -157,7 +156,10 @@ mod vectorization { }; context .create_local_array( - Item::vectorized(T::as_elem(), NonZero::new(vectorization_factor as u8)), + Item::vectorized( + T::as_elem(context), + NonZero::new(vectorization_factor as u8), + ), size, ) .into() @@ -178,20 +180,24 @@ mod vectorization { let item = Item::vectorized(var.item.elem(), NonZero::new(factor as u8)); let new_var = if factor == 1 { - let new_var = context.create_local_binding(item); - let element = - index::expand(context, self.clone(), ExpandElementTyped::from_lit(0u32)); + let new_var = context.create_local(item); + let element = index::expand( + context, + self.clone(), + ExpandElementTyped::from_lit(context, 0u32), + ); assign::expand(context, element, new_var.clone().into()); new_var } else { - let new_var = context.create_local_variable(item); + let new_var = context.create_local_mut(item); for i in 0..factor { let expand: Self = self.expand.clone().into(); - let element = index::expand(context, expand, ExpandElementTyped::from_lit(i)); + let element = + index::expand(context, expand, ExpandElementTyped::from_lit(context, i)); index_assign::expand::>( context, new_var.clone().into(), - ExpandElementTyped::from_lit(i), + ExpandElementTyped::from_lit(context, i), element, ); } @@ -224,7 +230,7 @@ mod metadata { impl ExpandElementTyped> { // Expand method of [len](Array::len). 
pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped { - let out = context.create_local_binding(Item::new(u32::as_elem())); + let out = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new( Metadata::Length { var: self.expand.into(), @@ -239,7 +245,7 @@ mod metadata { self, context: &mut CubeContext, ) -> ExpandElementTyped { - let out = context.create_local_binding(Item::new(u32::as_elem())); + let out = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new( Metadata::BufferLength { var: self.expand.into(), @@ -292,7 +298,7 @@ mod indexation { context: &mut CubeContext, i: ExpandElementTyped, ) -> ExpandElementTyped { - let out = context.create_local_binding(self.expand.item); + let out = context.create_local(self.expand.item); context.register(Instruction::new( Operator::UncheckedIndex(BinaryOperator { lhs: *self.expand, diff --git a/crates/cubecl-core/src/frontend/container/array/launch.rs b/crates/cubecl-core/src/frontend/container/array/launch.rs index 7e95d1ba8..589e5670c 100644 --- a/crates/cubecl-core/src/frontend/container/array/launch.rs +++ b/crates/cubecl-core/src/frontend/container/array/launch.rs @@ -1,22 +1,27 @@ use std::{marker::PhantomData, num::NonZero}; +use serde::{Deserialize, Serialize}; + use crate::{ compute::{KernelBuilder, KernelLauncher}, ir::{Item, Vectorization}, prelude::{ - ArgSettings, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand, TensorHandleRef, + ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand, + TensorHandleRef, }, Runtime, }; use super::Array; -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)] pub struct ArrayCompilationArg { pub inplace: Option, pub vectorisation: Vectorization, } +impl CompilationArg for ArrayCompilationArg {} + /// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle). 
pub struct ArrayHandleRef<'a, R: Runtime> { pub handle: &'a cubecl_runtime::server::Handle, @@ -33,7 +38,10 @@ impl LaunchArgExpand for Array { builder: &mut KernelBuilder, ) -> ExpandElementTyped> { builder - .input_array(Item::vectorized(C::as_elem(), arg.vectorisation)) + .input_array(Item::vectorized( + C::as_elem(&builder.context), + arg.vectorisation, + )) .into() } fn expand_output( @@ -43,7 +51,10 @@ impl LaunchArgExpand for Array { match arg.inplace { Some(id) => builder.inplace_output(id).into(), None => builder - .output_array(Item::vectorized(C::as_elem(), arg.vectorisation)) + .output_array(Item::vectorized( + C::as_elem(&builder.context), + arg.vectorisation, + )) .into(), } } @@ -82,7 +93,11 @@ impl<'a, R: Runtime> ArrayArg<'a, R> { vectorization_factor: u8, ) -> Self { ArrayArg::Handle { - handle: ArrayHandleRef::from_raw_parts(handle, length, E::as_elem().size()), + handle: ArrayHandleRef::from_raw_parts( + handle, + length, + E::size().expect("Element should have a size"), + ), vectorization_factor, } } diff --git a/crates/cubecl-core/src/frontend/container/iter.rs b/crates/cubecl-core/src/frontend/container/iter.rs index 1e603cbf4..f53d2074a 100644 --- a/crates/cubecl-core/src/frontend/container/iter.rs +++ b/crates/cubecl-core/src/frontend/container/iter.rs @@ -27,11 +27,11 @@ impl Iterable for ExpandElementTyped { context: &mut CubeContext, mut body: impl FnMut(&mut CubeContext, ::ExpandType), ) { - let index_ty = Item::new(u32::as_elem()); + let index_ty = Item::new(u32::as_elem(context)); let len: ExpandElement = T::len(&self.expand, context); let mut child = context.child(); - let i = child.create_local_undeclared(index_ty); + let i = child.create_local_restricted(index_ty); let item = index::expand(&mut child, self, i.clone().into()); body(&mut child, item); diff --git a/crates/cubecl-core/src/frontend/container/line/base.rs b/crates/cubecl-core/src/frontend/container/line/base.rs index 01710ad71..f2fd51675 100644 --- a/crates/cubecl-core/src/frontend/container/line/base.rs +++ b/crates/cubecl-core/src/frontend/container/line/base.rs @@ -1,7 +1,7 @@ use std::num::NonZero; use crate::{ - ir::{BinaryOperator, ConstantScalarValue, Instruction, Item, Operator}, + ir::{BinaryOperator, ConstantScalarValue, Elem, Instruction, Item, Operator}, prelude::{binary_expand_fixed_output, CubeContext, Dot, ExpandElement, Numeric}, unexpanded, }; @@ -11,12 +11,19 @@ use crate::frontend::{ }; /// A contiguous list of elements that supports auto-vectorized operations. -#[derive(Clone, Copy, Eq)] -pub struct Line { +pub struct Line

<P> { // Comptime lines only support 1 element. pub(crate) val: P, } +impl<P: CubePrimitive> Clone for Line<P>

{ + fn clone(&self) -> Self { + *self + } +} +impl<P: CubePrimitive> Eq for Line<P>

{} +impl<P: CubePrimitive> Copy for Line<P>

{} + /// Module that contains the implementation details of the new function. mod new { use super::*; @@ -79,7 +86,7 @@ mod fill { value: ExpandElementTyped<P>

, ) -> Self { let length = self.expand.item.vectorization; - let output = context.create_local_binding(Item::vectorized(P::as_elem(), length)); + let output = context.create_local(Item::vectorized(P::as_elem(context), length)); cast::expand::<P>

(context, value, output.clone().into()); @@ -120,7 +127,7 @@ mod empty { None => None, }; context - .create_local_variable(Item::vectorized(Self::as_elem(), length)) + .create_local_mut(Item::vectorized(Self::as_elem(context), length)) .into() } } @@ -208,7 +215,7 @@ macro_rules! impl_line_comparison { let lhs = self.expand.into(); let rhs = rhs.expand.into(); - let output = context.create_local_binding(Item::vectorized(bool::as_elem(), size)); + let output = context.create_local_mut(Item::vectorized(bool::as_elem(context), size)); context.register(Instruction::new( Operator::$operator(BinaryOperator { lhs, rhs }), @@ -251,8 +258,16 @@ impl<P: CubePrimitive> IntoRuntime for Line<P>

{ } impl<P: CubePrimitive> CubePrimitive for Line<P>

{ - fn as_elem() -> crate::ir::Elem { - P::as_elem() + fn as_elem(context: &CubeContext) -> Elem { + P::as_elem(context) + } + + fn as_elem_native() -> Option<Elem> { + P::as_elem_native() + } + + fn size() -> Option<usize> { + P::size() + } } diff --git a/crates/cubecl-core/src/frontend/container/line/ops.rs index 2fe03aa76..603346abe 100644 --- a/crates/cubecl-core/src/frontend/container/line/ops.rs +++ b/crates/cubecl-core/src/frontend/container/line/ops.rs @@ -6,6 +6,7 @@ use crate::{ ExpandElementTyped, Floor, Log, Log1p, Max, Min, Powf, Recip, Remainder, Round, Sin, Sqrt, Tanh, }, + prelude::{CountOnes, ReverseBits}, unexpanded, }; @@ -259,6 +260,8 @@ impl<P: CubePrimitive> Remainder for Line<P>

{} impl<P: CubePrimitive> Round for Line<P>

{} impl<P: CubePrimitive> Floor for Line<P>

{} impl<P: CubePrimitive> Ceil for Line<P>

{} +impl<P: CubePrimitive> CountOnes for Line<P>

{} +impl<P: CubePrimitive> ReverseBits for Line<P>

{} impl<P: Numeric> NumCast for Line<P>

{ fn from<T: ToPrimitive>(n: T) -> Option<Self> { diff --git a/crates/cubecl-core/src/frontend/container/sequence/launch.rs index 6f8cb42c7..d83748262 100644 --- a/crates/cubecl-core/src/frontend/container/sequence/launch.rs +++ b/crates/cubecl-core/src/frontend/container/sequence/launch.rs @@ -1,8 +1,10 @@ use std::{cell::RefCell, rc::Rc}; +use serde::{Deserialize, Serialize}; + use crate::{ compute::KernelBuilder, - prelude::{ArgSettings, LaunchArg, LaunchArgExpand}, + prelude::{ArgSettings, CompilationArg, LaunchArg, LaunchArgExpand}, Runtime, }; @@ -27,10 +29,13 @@ impl<'a, R: Runtime, T: LaunchArg> SequenceArg<'a, R, T> { } } +#[derive(Serialize, Deserialize)] pub struct SequenceCompilationArg<C: LaunchArg> { pub values: Vec<C::CompilationArg>, } +impl<C: LaunchArg> CompilationArg for SequenceCompilationArg<C> {} + impl<C: LaunchArg> Clone for SequenceCompilationArg<C> { fn clone(&self) -> Self { Self { diff --git a/crates/cubecl-core/src/frontend/container/shared_memory.rs index 32b4395d8..07372eec5 100644 --- a/crates/cubecl-core/src/frontend/container/shared_memory.rs +++ b/crates/cubecl-core/src/frontend/container/shared_memory.rs @@ -49,7 +49,10 @@ impl<T: CubePrimitive> SharedMemory<T> { .expect("Shared memory need constant initialization value") .as_u32(); let var = context.create_shared( - Item::vectorized(T::as_elem(), NonZero::new(vectorization_factor as u8)), + Item::vectorized( + T::as_elem(context), + NonZero::new(vectorization_factor as u8), + ), size, ); ExpandElementTyped::new(var) @@ -68,7 +71,10 @@ impl<T: CubePrimitive> SharedMemory<T> { .expect("Shared memory need constant initialization value") .as_u32(); let var = context.create_shared( - Item::vectorized(T::as_elem(), NonZero::new(vectorization_factor as u8)), + Item::vectorized( + T::as_elem(context), + NonZero::new(vectorization_factor as u8), + ), size, ); ExpandElementTyped::new(var) @@ -82,7 +88,7 @@ impl<T: CubePrimitive> SharedMemory<T> { .constant() .expect("Shared memory need constant initialization value") .as_u32(); - let var = context.create_shared(Item::new(T::as_elem()), size); + let var = context.create_shared(Item::new(T::as_elem(context)), size); ExpandElementTyped::new(var) } } @@ -129,7 +135,7 @@ mod indexation { context: &mut CubeContext, i: ExpandElementTyped<u32>, ) -> ExpandElementTyped<T> { - let out = context.create_local_binding(self.expand.item); + let out = context.create_local(self.expand.item); context.register(Instruction::new( Operator::UncheckedIndex(BinaryOperator { lhs: *self.expand, diff --git a/crates/cubecl-core/src/frontend/container/slice.rs index 7badaad4d..33ac040c4 100644 --- a/crates/cubecl-core/src/frontend/container/slice.rs +++ b/crates/cubecl-core/src/frontend/container/slice.rs @@ -177,7 +177,7 @@ mod indexation { context: &mut CubeContext, i: ExpandElementTyped<u32>, ) -> ExpandElementTyped<E> { - let out = context.create_local_binding(self.expand.item); + let out = context.create_local(self.expand.item); context.register(Instruction::new( Operator::UncheckedIndex(BinaryOperator { lhs: *self.expand, @@ -195,7 +195,7 @@ mod indexation { context: &mut CubeContext, i: ExpandElementTyped<u32>, ) -> ExpandElementTyped<E> { - let out = context.create_local_binding(self.expand.item); + let out = context.create_local(self.expand.item); context.register(Instruction::new( Operator::UncheckedIndex(BinaryOperator { lhs: *self.expand, diff --git a/crates/cubecl-core/src/frontend/container/tensor/base.rs index
1417b13e9..409883dd0 100644 --- a/crates/cubecl-core/src/frontend/container/tensor/base.rs +++ b/crates/cubecl-core/src/frontend/container/tensor/base.rs @@ -130,7 +130,7 @@ mod metadata { dim: ExpandElementTyped, ) -> ExpandElementTyped { let dim: ExpandElement = dim.into(); - let out = context.create_local_binding(Item::new(u32::as_elem())); + let out = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new( Metadata::Stride { dim: *dim, @@ -148,7 +148,7 @@ mod metadata { dim: ExpandElementTyped, ) -> ExpandElementTyped { let dim: ExpandElement = dim.into(); - let out = context.create_local_binding(Item::new(u32::as_elem())); + let out = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new( Metadata::Shape { dim: *dim, @@ -171,7 +171,7 @@ mod metadata { let shape = self.clone().__expand_shape_method(context, dim.clone()); // Compute `num_strides = index / stride`. - let num_strides = context.create_local_binding(Item::new(u32::as_elem())); + let num_strides = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new( Operator::Div(BinaryOperator { lhs: *index, @@ -181,7 +181,7 @@ mod metadata { )); // Compute `coordinate = num_strides % shape `. - let coordinate = context.create_local_binding(Item::new(u32::as_elem())); + let coordinate = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new( Operator::Modulo(BinaryOperator { lhs: *num_strides, @@ -210,7 +210,7 @@ mod metadata { // Expand method of [rank](Tensor::rank). pub fn __expand_rank_method(self, context: &mut CubeContext) -> ExpandElementTyped { - let out = context.create_local_binding(Item::new(u32::as_elem())); + let out = context.create_local(Item::new(u32::as_elem(context))); context.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out)); out.into() } @@ -258,7 +258,7 @@ mod indexation { context: &mut CubeContext, i: ExpandElementTyped, ) -> ExpandElementTyped { - let out = context.create_local_binding(self.expand.item); + let out = context.create_local(self.expand.item); context.register(Instruction::new( Operator::UncheckedIndex(BinaryOperator { lhs: *self.expand, diff --git a/crates/cubecl-core/src/frontend/container/tensor/launch.rs b/crates/cubecl-core/src/frontend/container/tensor/launch.rs index 272dd8286..4474e1f52 100644 --- a/crates/cubecl-core/src/frontend/container/tensor/launch.rs +++ b/crates/cubecl-core/src/frontend/container/tensor/launch.rs @@ -1,9 +1,13 @@ use std::{marker::PhantomData, num::NonZero}; +use serde::{Deserialize, Serialize}; + use crate::{ compute::{KernelBuilder, KernelLauncher}, ir::{Item, Vectorization}, - prelude::{ArgSettings, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand}, + prelude::{ + ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand, + }, Runtime, }; @@ -53,12 +57,14 @@ impl core::fmt::Debug for TensorHandleRef<'_, R> { } /// Compilation argument for a [tensor](Tensor). 
-#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)] pub struct TensorCompilationArg { pub inplace: Option, pub vectorisation: Vectorization, } +impl CompilationArg for TensorCompilationArg {} + impl LaunchArgExpand for Tensor { type CompilationArg = TensorCompilationArg; @@ -67,7 +73,10 @@ impl LaunchArgExpand for Tensor { builder: &mut KernelBuilder, ) -> ExpandElementTyped> { builder - .input_tensor(Item::vectorized(C::as_elem(), arg.vectorisation)) + .input_tensor(Item::vectorized( + C::as_elem(&builder.context), + arg.vectorisation, + )) .into() } fn expand_output( @@ -77,7 +86,10 @@ impl LaunchArgExpand for Tensor { match arg.inplace { Some(id) => builder.inplace_output(id).into(), None => builder - .output_tensor(Item::vectorized(C::as_elem(), arg.vectorisation)) + .output_tensor(Item::vectorized( + C::as_elem(&builder.context), + arg.vectorisation, + )) .into(), } } @@ -122,7 +134,7 @@ impl<'a, R: Runtime> TensorArg<'a, R> { handle, strides, shape, - E::as_elem().size(), + E::size().expect("Element should have a size"), ), vectorization_factor: factor, } diff --git a/crates/cubecl-core/src/frontend/context.rs b/crates/cubecl-core/src/frontend/context.rs index 76c99d18b..0abb6b031 100644 --- a/crates/cubecl-core/src/frontend/context.rs +++ b/crates/cubecl-core/src/frontend/context.rs @@ -1,19 +1,22 @@ -use crate::ir::{self, Elem, Instruction, Item, ReusingAllocator, Scope, Variable, VariableKind}; -use crate::{frontend::ExpandElement, ir::LocalAllocator}; +use crate::frontend::ExpandElement; +use crate::ir::{self, Allocator, Elem, Instruction, Item, Scope, Variable, VariableKind}; use alloc::rc::Rc; use core::cell::RefCell; use cubecl_runtime::debug::DebugLogger; +use std::any::TypeId; +use std::collections::HashMap; pub struct CubeContext { pub root: Rc>, pub scope: Rc>, - pub local_allocator: Rc, + pub allocator: Allocator, pub debug_enabled: bool, + pub typemap: Rc>>, } impl Default for CubeContext { fn default() -> Self { - Self::root(ReusingAllocator::default()) + Self::root(Allocator::new()) } } @@ -22,15 +25,17 @@ impl CubeContext { /// A root scope is at the root of a compute shader /// Therefore there is one cube context per shader /// The allocator will define the strategy for creating local intermediates and mutable variables - pub fn root(allocator: impl LocalAllocator + 'static) -> CubeContext { + pub fn root(allocator: Allocator) -> CubeContext { let root = Rc::new(RefCell::new(Scope::root())); + let typemap = Rc::new(RefCell::new(HashMap::new())); let scope = root.clone(); Self { - local_allocator: Rc::new(allocator), + allocator, scope, root, debug_enabled: DebugLogger::default().is_activated(), + typemap, } } @@ -38,14 +43,30 @@ impl CubeContext { self.scope.borrow_mut().register(op) } + /// Resolve the element type of the given generic type. + pub fn resolve_elem(&self) -> Option { + let map = self.typemap.borrow(); + let result = map.get(&TypeId::of::()); + + result.cloned() + } + + /// Register the element type for the given generic type. 
+ pub fn register_elem(&mut self, elem: Elem) { + let mut map = self.typemap.borrow_mut(); + + map.insert(TypeId::of::(), elem); + } + pub fn child(&mut self) -> CubeContext { let scope = self.scope.borrow_mut().child(); Self { scope: Rc::new(RefCell::new(scope)), root: self.root.clone(), - local_allocator: self.local_allocator.clone(), + allocator: self.allocator.clone(), debug_enabled: self.debug_enabled, + typemap: self.typemap.clone(), } } @@ -57,23 +78,23 @@ impl CubeContext { .into_inner() } - /// Create a new mutable local variable - pub fn create_local_variable(&mut self, item: Item) -> ExpandElement { - self.local_allocator - .create_local_variable(self.root.clone(), self.scope.clone(), item) + /// Create a new mutable local variable. + pub fn create_local_mut(&mut self, item: Item) -> ExpandElement { + self.allocator + .create_local_mut(&mut self.root.borrow_mut(), item) } - /// Create a new immutable local binding - pub fn create_local_binding(&mut self, item: Item) -> ExpandElement { - self.local_allocator - .create_local_binding(self.root.clone(), self.scope.clone(), item) + /// Create a new immutable local variable. + pub fn create_local(&mut self, item: Item) -> ExpandElement { + self.allocator + .create_local(&mut self.scope.borrow_mut(), item) } /// Create a new immutable local binding that must never be a reused variable, regardless of /// allocator - pub fn create_local_undeclared(&mut self, item: Item) -> ExpandElement { - self.local_allocator - .create_local_undeclared(self.root.clone(), self.scope.clone(), item) + pub fn create_local_restricted(&mut self, item: Item) -> ExpandElement { + self.allocator + .create_local_restricted(&mut self.scope.borrow_mut(), item) } /// Create a new matrix element. diff --git a/crates/cubecl-core/src/frontend/element/atomic.rs b/crates/cubecl-core/src/frontend/element/atomic.rs index 9edab4408..8fafa2a58 100644 --- a/crates/cubecl-core/src/frontend/element/atomic.rs +++ b/crates/cubecl-core/src/frontend/element/atomic.rs @@ -143,7 +143,7 @@ where pointer: ::ExpandType, ) -> ::ExpandType { let pointer: ExpandElement = pointer.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem())); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Load(UnaryOperator { input: *pointer }), *new_var, @@ -171,7 +171,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem())); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Swap(BinaryOperator { lhs: *ptr, @@ -191,7 +191,7 @@ where let pointer: ExpandElement = pointer.into(); let cmp: ExpandElement = cmp.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem())); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::CompareAndSwap(CompareAndSwapOperator { input: *pointer, @@ -210,7 +210,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem())); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Add(BinaryOperator { lhs: 
*ptr, @@ -228,7 +228,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem())); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Sub(BinaryOperator { lhs: *ptr, @@ -246,7 +246,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem())); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Max(BinaryOperator { lhs: *ptr, @@ -264,7 +264,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem())); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Min(BinaryOperator { lhs: *ptr, @@ -282,7 +282,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem())); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::And(BinaryOperator { lhs: *ptr, @@ -300,7 +300,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem())); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Or(BinaryOperator { lhs: *ptr, @@ -318,7 +318,7 @@ where ) -> ::ExpandType { let ptr: ExpandElement = pointer.into(); let value: ExpandElement = value.into(); - let new_var = context.create_local_binding(Item::new(Self::Primitive::as_elem())); + let new_var = context.create_local(Item::new(Self::Primitive::as_elem(context))); context.register(Instruction::new( AtomicOp::Xor(BinaryOperator { lhs: *ptr, @@ -353,8 +353,8 @@ macro_rules! impl_atomic_int { } impl CubePrimitive for $type { - fn as_elem() -> Elem { - Elem::AtomicInt(IntKind::$inner_type) + fn as_elem_native() -> Option { + Some(Elem::AtomicInt(IntKind::$inner_type)) } } @@ -399,8 +399,8 @@ impl CubeType for AtomicU32 { } impl CubePrimitive for AtomicU32 { - fn as_elem() -> Elem { - Elem::AtomicUInt(UIntKind::U32) + fn as_elem_native() -> Option { + Some(Elem::AtomicUInt(UIntKind::U32)) } } diff --git a/crates/cubecl-core/src/frontend/element/base.rs b/crates/cubecl-core/src/frontend/element/base.rs index a6a1cc4d4..b7fbc8584 100644 --- a/crates/cubecl-core/src/frontend/element/base.rs +++ b/crates/cubecl-core/src/frontend/element/base.rs @@ -1,5 +1,4 @@ -use super::{flex32, CubePrimitive, Numeric}; -use crate::tf32; +use super::{flex32, tf32, CubePrimitive, Numeric}; use crate::{ ir::{ConstantScalarValue, Operation, Variable, VariableKind}, prelude::{init_expand, CubeContext, KernelBuilder, KernelLauncher}, @@ -50,6 +49,35 @@ pub trait Init: Sized { fn init(self, context: &mut CubeContext) -> Self; } +/// Argument used during the compilation of kernels. 
+pub trait CompilationArg: + serde::Serialize + + serde::de::DeserializeOwned + + Clone + + PartialEq + + Eq + + core::hash::Hash + + core::fmt::Debug + + Send + + Sync + + 'static +{ + /// Compilation args should be the same even with different element types. However, it isn't + /// possible to enforce it with the type system. So, we make the compilation args serializable + /// and dynamically cast them. + /// + /// Without this, the compilation time is unreasonable. The performance drop isn't a concern + /// since this is only done once when compiling a kernel for the first time. + fn dynamic_cast(&self) -> Arg { + let val = serde_json::to_string(self).unwrap(); + + serde_json::from_str(&val) + .expect("Compilation argument should be the same even with different element types") + } +} + +impl CompilationArg for () {} + /// Defines how a [launch argument](LaunchArg) can be expanded. /// /// Normally this type should be implemented two times for an argument. @@ -59,14 +87,7 @@ pub trait Init: Sized { #[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")] pub trait LaunchArgExpand: CubeType { /// Compilation argument. - type CompilationArg: Clone - + PartialEq - + Eq - + core::hash::Hash - + core::fmt::Debug - + Send - + Sync - + 'static; + type CompilationArg: CompilationArg; /// Register an input variable during compilation that fill the [KernelBuilder]. fn expand( @@ -285,9 +306,9 @@ impl From> for ExpandElement { impl ExpandElementTyped { /// Create an [ExpandElementTyped] from a value that is normally a literal. - pub fn from_lit>(lit: L) -> Self { + pub fn from_lit>(context: &CubeContext, lit: L) -> Self { let variable: Variable = lit.into(); - let variable = T::as_elem().from_constant(variable); + let variable = T::as_elem(context).from_constant(variable); ExpandElementTyped::new(ExpandElement::Plain(variable)) } @@ -306,7 +327,7 @@ impl ExpandElement { pub fn can_mut(&self) -> bool { match self { ExpandElement::Managed(var) => { - if let VariableKind::Local { .. } = var.as_ref().kind { + if let VariableKind::LocalMut { .. } = var.as_ref().kind { Rc::strong_count(var) <= 2 } else { false @@ -358,9 +379,9 @@ pub(crate) fn init_expand_element>( match elem.kind { VariableKind::GlobalScalar { .. } => init(elem), VariableKind::ConstantScalar { .. } => init(elem), - VariableKind::Local { .. } => init(elem), + VariableKind::LocalMut { .. } => init(elem), VariableKind::Versioned { .. } => init(elem), - VariableKind::LocalBinding { .. } => init(elem), + VariableKind::LocalConst { .. } => init(elem), // Constant should be initialized since the new variable can be mutated afterward. // And it is assumed those values are cloned. VariableKind::Builtin(_) => init(elem), @@ -403,9 +424,9 @@ impl Init for Vec { /// Create a constant element of the correct type during expansion. 
pub(crate) fn __expand_new( - _context: &mut CubeContext, + context: &mut CubeContext, val: C, ) -> ExpandElementTyped { - let val = Out::from(val).unwrap(); - val.into() + let input: ExpandElementTyped = val.into(); + ::__expand_cast_from(context, input) } diff --git a/crates/cubecl-core/src/frontend/element/bool.rs b/crates/cubecl-core/src/frontend/element/bool.rs index 56ef95688..45407cfc2 100644 --- a/crates/cubecl-core/src/frontend/element/bool.rs +++ b/crates/cubecl-core/src/frontend/element/bool.rs @@ -28,8 +28,8 @@ impl CubeType for bool { } impl CubePrimitive for bool { - fn as_elem() -> Elem { - Elem::Bool + fn as_elem_native() -> Option { + Some(Elem::Bool) } } diff --git a/crates/cubecl-core/src/frontend/element/cast.rs b/crates/cubecl-core/src/frontend/element/cast.rs index f8278ce1d..2c7366276 100644 --- a/crates/cubecl-core/src/frontend/element/cast.rs +++ b/crates/cubecl-core/src/frontend/element/cast.rs @@ -18,9 +18,8 @@ pub trait Cast: CubePrimitive { if core::any::TypeId::of::() == core::any::TypeId::of::() { return value.expand.into(); } - - let new_var = context.create_local_binding(Item::vectorized( - ::as_elem(), + let new_var = context.create_local(Item::vectorized( + ::as_elem(context), value.expand.item.vectorization, )); cast::expand(context, value, new_var.clone().into()); @@ -49,8 +48,8 @@ pub trait BitCast: CubePrimitive { ) -> ::ExpandType { let value: ExpandElement = value.into(); let var: Variable = *value; - let new_var = context.create_local_binding(Item::vectorized( - ::as_elem(), + let new_var = context.create_local(Item::vectorized( + ::as_elem(context), var.item.vectorization, )); context.register(Instruction::new( diff --git a/crates/cubecl-core/src/frontend/element/cube_elem.rs b/crates/cubecl-core/src/frontend/element/cube_elem.rs index 75a3c13dd..5faa4ccf4 100644 --- a/crates/cubecl-core/src/frontend/element/cube_elem.rs +++ b/crates/cubecl-core/src/frontend/element/cube_elem.rs @@ -2,6 +2,7 @@ use half::{bf16, f16}; use crate::frontend::{CubeType, ExpandElement}; use crate::ir::{Elem, Variable}; +use crate::prelude::CubeContext; use super::{flex32, tf32, ExpandElementBaseInit, ExpandElementTyped, IntoRuntime}; @@ -18,8 +19,25 @@ pub trait CubePrimitive: + Clone + Copy { - /// Return the element type to use on GPU - fn as_elem() -> Elem; + /// Return the element type to use on GPU. + fn as_elem(_context: &CubeContext) -> Elem { + Self::as_elem_native().expect("To be overriden if not native") + } + + /// Native or static element type. + fn as_elem_native() -> Option { + None + } + + /// Native or static element type. + fn as_elem_native_unchecked() -> Elem { + Self::as_elem_native().expect("To be a native type") + } + + /// Only native element types have a size. + fn size() -> Option { + Self::as_elem_native().map(|t| t.size()) + } fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType { ExpandElementTyped::new(elem) diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs index 2f302fef8..99a2e8231 100644 --- a/crates/cubecl-core/src/frontend/element/float.rs +++ b/crates/cubecl-core/src/frontend/element/float.rs @@ -9,9 +9,11 @@ use super::Numeric; mod relaxed; mod tensor_float; +mod typemap; pub use relaxed::*; pub use tensor_float::*; +pub use typemap::*; /// Floating point numbers. 
Used as input in float kernels pub trait Float: @@ -33,6 +35,7 @@ pub trait Float: + Normalize + Dot + Into + + core::ops::Neg + core::ops::Add + core::ops::Sub + core::ops::Mul @@ -77,8 +80,8 @@ macro_rules! impl_float { impl CubePrimitive for $primitive { /// Return the element type to use on GPU - fn as_elem() -> Elem { - Elem::Float(FloatKind::$kind) + fn as_elem_native() -> Option { + Some(Elem::Float(FloatKind::$kind)) } } @@ -93,8 +96,12 @@ macro_rules! impl_float { } impl Numeric for $primitive { - const MAX: Self = $primitive::MAX; - const MIN: Self = $primitive::MIN; + fn min_value() -> Self { + ::min_value() + } + fn max_value() -> Self { + ::max_value() + } } impl ExpandElementBaseInit for $primitive { @@ -129,7 +136,7 @@ macro_rules! impl_float { _: &Self::CompilationArg, builder: &mut KernelBuilder, ) -> ExpandElementTyped { - builder.scalar($primitive::as_elem()).into() + builder.scalar($primitive::as_elem(&builder.context)).into() } } }; diff --git a/crates/cubecl-core/src/frontend/element/float/relaxed.rs b/crates/cubecl-core/src/frontend/element/float/relaxed.rs index bf5e9dd84..54f9b6bde 100644 --- a/crates/cubecl-core/src/frontend/element/float/relaxed.rs +++ b/crates/cubecl-core/src/frontend/element/float/relaxed.rs @@ -163,8 +163,8 @@ impl CubeType for flex32 { impl CubePrimitive for flex32 { /// Return the element type to use on GPU - fn as_elem() -> Elem { - Elem::Float(FloatKind::Flex32) + fn as_elem_native() -> Option { + Some(Elem::Float(FloatKind::Flex32)) } } @@ -176,8 +176,12 @@ impl IntoRuntime for flex32 { } impl Numeric for flex32 { - const MAX: Self = flex32::from_f32(f32::MAX); - const MIN: Self = flex32::from_f32(f32::MIN); + fn min_value() -> Self { + ::min_value() + } + fn max_value() -> Self { + ::max_value() + } } impl ExpandElementBaseInit for flex32 { @@ -222,7 +226,7 @@ impl LaunchArgExpand for flex32 { type CompilationArg = (); fn expand(_: &Self::CompilationArg, builder: &mut KernelBuilder) -> ExpandElementTyped { - builder.scalar(flex32::as_elem()).into() + builder.scalar(flex32::as_elem(&builder.context)).into() } } @@ -250,7 +254,7 @@ impl num_traits::Float for flex32 { } fn min_value() -> Self { - flex32(f32::min_value()) + flex32(::min_value()) } fn min_positive_value() -> Self { @@ -258,7 +262,7 @@ impl num_traits::Float for flex32 { } fn max_value() -> Self { - flex32(f32::max_value()) + flex32(::max_value()) } fn is_nan(self) -> bool { diff --git a/crates/cubecl-core/src/frontend/element/float/tensor_float.rs b/crates/cubecl-core/src/frontend/element/float/tensor_float.rs index cdc0cd14c..0b14d3edc 100644 --- a/crates/cubecl-core/src/frontend/element/float/tensor_float.rs +++ b/crates/cubecl-core/src/frontend/element/float/tensor_float.rs @@ -2,7 +2,7 @@ #![allow(clippy::transmute_float_to_int)] // prev=1.83. 
use bytemuck::{Pod, Zeroable}; -use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; +use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use half::f16; use num_traits::{NumCast, ToPrimitive}; use serde::Serialize; @@ -88,6 +88,14 @@ impl tf32 { } } +impl Neg for tf32 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self::from_f32(self.to_f32().neg()) + } +} + impl Mul for tf32 { type Output = Self; @@ -174,8 +182,8 @@ impl CubeType for tf32 { impl CubePrimitive for tf32 { /// Return the element type to use on GPU - fn as_elem() -> Elem { - Elem::Float(FloatKind::TF32) + fn as_elem_native() -> Option { + Some(Elem::Float(FloatKind::TF32)) } } @@ -187,8 +195,12 @@ impl IntoRuntime for tf32 { } impl Numeric for tf32 { - const MAX: Self = tf32::from_f32(f32::MAX); - const MIN: Self = tf32::from_f32(f32::MIN); + fn min_value() -> Self { + Self(f32::MIN) + } + fn max_value() -> Self { + Self(f32::MAX) + } } impl ExpandElementBaseInit for tf32 { @@ -240,6 +252,6 @@ impl LaunchArgExpand for tf32 { type CompilationArg = (); fn expand(_: &Self::CompilationArg, builder: &mut KernelBuilder) -> ExpandElementTyped { - builder.scalar(tf32::as_elem()).into() + builder.scalar(tf32::as_elem(&builder.context)).into() } } diff --git a/crates/cubecl-core/src/frontend/element/float/typemap.rs b/crates/cubecl-core/src/frontend/element/float/typemap.rs new file mode 100644 index 000000000..5ab39e1e7 --- /dev/null +++ b/crates/cubecl-core/src/frontend/element/float/typemap.rs @@ -0,0 +1,545 @@ +//! This module contains a configurable [element type](FloatExpand) for floats to be used during +//! kernel expansion to speed up Rust compilation. +//! +//! Expand functions don't need to be generated for different element types even if they are generic +//! over one, since the only use of numeric element types is to map to the [elem IR enum](Elem). +//! +//! This can be done dynamically using the context instead, reducing the binary size and the +//! compilation time of kernels significantly. +//! +//! You can still have multiple element types in a single kernel, since [FloatExpand] uses const +//! generics to differentiate between float kinds. 
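The new-file module doc above is the heart of this PR, so a short illustration may help. A minimal sketch of the flow, using the `register_elem`/`resolve_elem` methods this diff adds to `CubeContext`; the `0` const-generic slot, the `cubecl_core` paths, and the prelude re-exports are assumptions, not part of the diff:

```rust
use cubecl_core::ir::{Elem, FloatKind};
use cubecl_core::prelude::*;

// Expansion code is compiled once for the placeholder `FloatExpand<0>`; the
// concrete element type is injected at expansion time through the context's
// typemap instead of via monomorphization.
fn bind_placeholder(context: &mut CubeContext) {
    // Bind placeholder slot 0 to a concrete runtime element.
    context.register_elem::<FloatExpand<0>>(Elem::Float(FloatKind::F16));

    // Anywhere inside the expansion, `as_elem` now resolves via the typemap.
    assert_eq!(
        FloatExpand::<0>::as_elem(context),
        Elem::Float(FloatKind::F16)
    );
}
```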
+ +use core::f32; +use std::{ + cmp::Ordering, + ops::{Div, DivAssign, Mul, MulAssign, Rem, RemAssign}, +}; + +use bytemuck::{Pod, Zeroable}; +use derive_more::derive::{ + Add, AddAssign, Display, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, +}; +use num_traits::{Num, NumCast, One, ToPrimitive, Zero}; +use serde::Serialize; + +use crate::{ + ir::{Elem, FloatKind, Variable}, + prelude::Numeric, +}; + +use super::{ + init_expand_element, Abs, Ceil, Clamp, Cos, CubeContext, CubeIndex, CubeIndexMut, + CubePrimitive, CubeType, Dot, Erf, Exp, ExpandElement, ExpandElementBaseInit, + ExpandElementTyped, Float, Floor, Index, Init, IntoRuntime, KernelBuilder, KernelLauncher, + LaunchArgExpand, Log, Log1p, Magnitude, Max, Min, Normalize, Powf, Recip, Remainder, Round, + Runtime, ScalarArgSettings, Sin, Sqrt, Tanh, +}; + +#[allow(non_camel_case_types)] +#[repr(transparent)] +#[derive( + Clone, + Copy, + Default, + Serialize, + Zeroable, + Pod, + PartialEq, + PartialOrd, + Neg, + Add, + Sub, + Mul, + Div, + Rem, + AddAssign, + SubAssign, + MulAssign, + DivAssign, + RemAssign, + Debug, + Display, +)] +pub struct FloatExpand(f32); +pub type NumericExpand = FloatExpand; +pub type IntExpand = FloatExpand; + +impl FloatExpand { + pub const MIN_POSITIVE: Self = Self(half::f16::MIN_POSITIVE.to_f32_const()); + + pub const fn from_f32(val: f32) -> Self { + FloatExpand(val) + } + + pub const fn from_f64(val: f64) -> Self { + FloatExpand(val as f32) + } + + pub const fn to_f32(self) -> f32 { + self.0 + } + + pub const fn to_f64(self) -> f64 { + self.0 as f64 + } + + pub fn total_cmp(&self, other: &Self) -> Ordering { + self.0.total_cmp(&other.0) + } + + pub fn is_nan(&self) -> bool { + self.0.is_nan() + } +} + +impl Mul for FloatExpand { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + FloatExpand(self.0 * rhs.0) + } +} + +impl Div for FloatExpand { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + FloatExpand(self.0 / rhs.0) + } +} + +impl Rem for FloatExpand { + type Output = Self; + + fn rem(self, rhs: Self) -> Self::Output { + FloatExpand(self.0 % rhs.0) + } +} + +impl MulAssign for FloatExpand { + fn mul_assign(&mut self, rhs: Self) { + self.0 *= rhs.0; + } +} + +impl DivAssign for FloatExpand { + fn div_assign(&mut self, rhs: Self) { + self.0 /= rhs.0; + } +} + +impl RemAssign for FloatExpand { + fn rem_assign(&mut self, rhs: Self) { + self.0 %= rhs.0; + } +} + +impl From for FloatExpand { + fn from(value: f32) -> Self { + Self::from_f32(value) + } +} + +impl From> for f32 { + fn from(val: FloatExpand) -> Self { + val.to_f32() + } +} + +impl ToPrimitive for FloatExpand { + fn to_i64(&self) -> Option { + Some((*self).to_f32() as i64) + } + + fn to_u64(&self) -> Option { + Some((*self).to_f32() as u64) + } + + fn to_f32(&self) -> Option { + Some((*self).to_f32()) + } + + fn to_f64(&self) -> Option { + Some((*self).to_f32() as f64) + } +} + +impl NumCast for FloatExpand { + fn from(n: T) -> Option { + Some(FloatExpand::from_f32(n.to_f32()?)) + } +} + +impl CubeType for FloatExpand { + type ExpandType = ExpandElementTyped>; +} + +impl CubePrimitive for FloatExpand { + /// Return the element type to use on GPU + fn as_elem(context: &CubeContext) -> Elem { + context + .resolve_elem::() + .expect("Type to be registered") + } +} + +impl From> for Variable { + fn from(val: FloatExpand) -> Self { + // TODO: Fix how we create literal. 
+ Variable::new( + crate::ir::VariableKind::ConstantScalar(crate::ir::ConstantScalarValue::Float( + val.0 as f64, + FloatKind::F32, + )), + crate::ir::Item::new(Elem::Float(FloatKind::F32)), + ) + } +} + +impl From> for ExpandElementTyped> { + fn from(value: FloatExpand) -> Self { + let var: Variable = value.into(); + ExpandElementTyped::new(ExpandElement::Plain(var)) + } +} + +impl IntoRuntime for FloatExpand { + fn __expand_runtime_method(self, context: &mut CubeContext) -> ExpandElementTyped { + let expand: ExpandElementTyped = ExpandElementTyped::from_lit(context, self); + Init::init(expand, context) + } +} + +impl Numeric for FloatExpand { + fn min_value() -> Self { + panic!("Can't use min value in comptime with dynamic element type"); + } + fn max_value() -> Self { + panic!("Can't use max value in comptime with dynamic element type"); + } +} + +impl ExpandElementBaseInit for FloatExpand { + fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement { + init_expand_element(context, elem) + } +} + +impl Normalize for FloatExpand {} +impl Dot for FloatExpand {} +impl Magnitude for FloatExpand {} +impl Recip for FloatExpand {} +impl Erf for FloatExpand {} +impl Exp for FloatExpand {} +impl Remainder for FloatExpand {} +impl Abs for FloatExpand {} +impl Max for FloatExpand {} +impl Min for FloatExpand {} +impl Clamp for FloatExpand {} +impl Log for FloatExpand {} +impl Log1p for FloatExpand {} +impl Cos for FloatExpand {} +impl Sin for FloatExpand {} +impl Tanh for FloatExpand {} +impl Powf for FloatExpand {} +impl Sqrt for FloatExpand {} +impl Round for FloatExpand {} +impl Floor for FloatExpand {} +impl Ceil for FloatExpand {} + +impl CubeIndex for FloatExpand { + type Output = Self; +} +impl CubeIndexMut for FloatExpand {} + +impl Float for FloatExpand { + const DIGITS: u32 = 32; + + const EPSILON: Self = FloatExpand::from_f32(half::f16::EPSILON.to_f32_const()); + + const INFINITY: Self = FloatExpand::from_f32(f32::INFINITY); + + const MANTISSA_DIGITS: u32 = f32::MANTISSA_DIGITS; + + /// Maximum possible [`tf32`] power of 10 exponent + const MAX_10_EXP: i32 = f32::MAX_10_EXP; + /// Maximum possible [`tf32`] power of 2 exponent + const MAX_EXP: i32 = f32::MAX_EXP; + + /// Minimum possible normal [`tf32`] power of 10 exponent + const MIN_10_EXP: i32 = f32::MIN_10_EXP; + /// One greater than the minimum possible normal [`v`] power of 2 exponent + const MIN_EXP: i32 = f32::MIN_EXP; + + const MIN_POSITIVE: Self = FloatExpand(f32::MIN_POSITIVE); + + const NAN: Self = FloatExpand::from_f32(f32::NAN); + + const NEG_INFINITY: Self = FloatExpand::from_f32(f32::NEG_INFINITY); + + const RADIX: u32 = 2; + + fn new(val: f32) -> Self { + FloatExpand::from_f32(val) + } +} + +impl LaunchArgExpand for FloatExpand { + type CompilationArg = (); + + fn expand(_: &Self::CompilationArg, builder: &mut KernelBuilder) -> ExpandElementTyped { + builder + .scalar(FloatExpand::::as_elem(&builder.context)) + .into() + } +} + +impl ScalarArgSettings for FloatExpand { + fn register(&self, settings: &mut KernelLauncher) { + settings.register_f32(self.0); + } +} + +impl num_traits::Float for FloatExpand { + fn nan() -> Self { + FloatExpand(f32::nan()) + } + + fn infinity() -> Self { + FloatExpand(f32::infinity()) + } + + fn neg_infinity() -> Self { + FloatExpand(f32::neg_infinity()) + } + + fn neg_zero() -> Self { + FloatExpand(f32::neg_zero()) + } + + fn min_value() -> Self { + FloatExpand(::min_value()) + } + + fn min_positive_value() -> Self { + FloatExpand(f32::min_positive_value()) + } + + fn 
max_value() -> Self { + FloatExpand(::max_value()) + } + + fn is_nan(self) -> bool { + self.0.is_nan() + } + + fn is_infinite(self) -> bool { + self.0.is_infinite() + } + + fn is_finite(self) -> bool { + self.0.is_finite() + } + + fn is_normal(self) -> bool { + self.0.is_normal() + } + + fn classify(self) -> std::num::FpCategory { + self.0.classify() + } + + fn floor(self) -> Self { + FloatExpand(self.0.floor()) + } + + fn ceil(self) -> Self { + FloatExpand(self.0.ceil()) + } + + fn round(self) -> Self { + FloatExpand(self.0.round()) + } + + fn trunc(self) -> Self { + FloatExpand(self.0.trunc()) + } + + fn fract(self) -> Self { + FloatExpand(self.0.fract()) + } + + fn abs(self) -> Self { + FloatExpand(self.0.abs()) + } + + fn signum(self) -> Self { + FloatExpand(self.0.signum()) + } + + fn is_sign_positive(self) -> bool { + self.0.is_sign_positive() + } + + fn is_sign_negative(self) -> bool { + self.0.is_sign_negative() + } + + fn mul_add(self, a: Self, b: Self) -> Self { + FloatExpand(self.0.mul_add(a.0, b.0)) + } + + fn recip(self) -> Self { + FloatExpand(self.0.recip()) + } + + fn powi(self, n: i32) -> Self { + FloatExpand(self.0.powi(n)) + } + + fn powf(self, n: Self) -> Self { + FloatExpand(self.0.powf(n.0)) + } + + fn sqrt(self) -> Self { + FloatExpand(self.0.sqrt()) + } + + fn exp(self) -> Self { + FloatExpand(self.0.exp()) + } + + fn exp2(self) -> Self { + FloatExpand(self.0.exp2()) + } + + fn ln(self) -> Self { + FloatExpand(self.0.ln()) + } + + fn log(self, base: Self) -> Self { + FloatExpand(self.0.log(base.0)) + } + + fn log2(self) -> Self { + FloatExpand(self.0.log2()) + } + + fn log10(self) -> Self { + FloatExpand(self.0.log10()) + } + + fn max(self, other: Self) -> Self { + FloatExpand(self.0.max(other.0)) + } + + fn min(self, other: Self) -> Self { + FloatExpand(self.0.min(other.0)) + } + + fn abs_sub(self, other: Self) -> Self { + FloatExpand((self.0 - other.0).abs()) + } + + fn cbrt(self) -> Self { + FloatExpand(self.0.cbrt()) + } + + fn hypot(self, other: Self) -> Self { + FloatExpand(self.0.hypot(other.0)) + } + + fn sin(self) -> Self { + FloatExpand(self.0.sin()) + } + + fn cos(self) -> Self { + FloatExpand(self.0.cos()) + } + + fn tan(self) -> Self { + FloatExpand(self.0.tan()) + } + + fn asin(self) -> Self { + FloatExpand(self.0.asin()) + } + + fn acos(self) -> Self { + FloatExpand(self.0.acos()) + } + + fn atan(self) -> Self { + FloatExpand(self.0.atan()) + } + + fn atan2(self, other: Self) -> Self { + FloatExpand(self.0.atan2(other.0)) + } + + fn sin_cos(self) -> (Self, Self) { + let (a, b) = self.0.sin_cos(); + (FloatExpand(a), FloatExpand(b)) + } + + fn exp_m1(self) -> Self { + FloatExpand(self.0.exp_m1()) + } + + fn ln_1p(self) -> Self { + FloatExpand(self.0.ln_1p()) + } + + fn sinh(self) -> Self { + FloatExpand(self.0.sinh()) + } + + fn cosh(self) -> Self { + FloatExpand(self.0.cosh()) + } + + fn tanh(self) -> Self { + FloatExpand(self.0.tanh()) + } + + fn asinh(self) -> Self { + FloatExpand(self.0.asinh()) + } + + fn acosh(self) -> Self { + FloatExpand(self.0.acosh()) + } + + fn atanh(self) -> Self { + FloatExpand(self.0.atanh()) + } + + fn integer_decode(self) -> (u64, i16, i8) { + self.0.integer_decode() + } +} + +impl Num for FloatExpand { + type FromStrRadixErr = ::FromStrRadixErr; + + fn from_str_radix(str: &str, radix: u32) -> Result { + Ok(FloatExpand(f32::from_str_radix(str, radix)?)) + } +} + +impl One for FloatExpand { + fn one() -> Self { + FloatExpand(1.0) + } +} + +impl Zero for FloatExpand { + fn zero() -> Self { + FloatExpand(0.0) + } + + fn 
is_zero(&self) -> bool { + self.0 == 0.0 + } +} diff --git a/crates/cubecl-core/src/frontend/element/int.rs b/crates/cubecl-core/src/frontend/element/int.rs index ff2811c41..bab4ed621 100644 --- a/crates/cubecl-core/src/frontend/element/int.rs +++ b/crates/cubecl-core/src/frontend/element/int.rs @@ -1,10 +1,13 @@ -use crate::compute::{KernelBuilder, KernelLauncher}; use crate::frontend::{ CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit, ExpandElementTyped, Numeric, }; use crate::ir::{Elem, IntKind}; use crate::Runtime; +use crate::{ + compute::{KernelBuilder, KernelLauncher}, + prelude::{CountOnes, ReverseBits}, +}; use super::{ init_expand_element, Init, IntoRuntime, LaunchArgExpand, ScalarArgSettings, __expand_new, @@ -13,6 +16,8 @@ use super::{ /// Signed or unsigned integer. Used as input in int kernels pub trait Int: Numeric + + CountOnes + + ReverseBits + std::ops::Rem + core::ops::Add + core::ops::Sub @@ -51,8 +56,8 @@ macro_rules! impl_int { } impl CubePrimitive for $type { - fn as_elem() -> Elem { - Elem::Int(IntKind::$kind) + fn as_elem_native() -> Option { + Some(Elem::Int(IntKind::$kind)) } } @@ -67,8 +72,12 @@ macro_rules! impl_int { } impl Numeric for $type { - const MAX: Self = $type::MAX; - const MIN: Self = $type::MIN; + fn min_value() -> Self { + $type::MIN + } + fn max_value() -> Self { + $type::MAX + } } impl ExpandElementBaseInit for $type { @@ -92,7 +101,7 @@ macro_rules! impl_int { _: &Self::CompilationArg, builder: &mut KernelBuilder, ) -> ExpandElementTyped { - builder.scalar($type::as_elem()).into() + builder.scalar($type::as_elem(&builder.context)).into() } } }; diff --git a/crates/cubecl-core/src/frontend/element/numeric.rs b/crates/cubecl-core/src/frontend/element/numeric.rs index dc8841764..6ef446542 100644 --- a/crates/cubecl-core/src/frontend/element/numeric.rs +++ b/crates/cubecl-core/src/frontend/element/numeric.rs @@ -48,8 +48,22 @@ pub trait Numeric: + std::cmp::PartialOrd + std::cmp::PartialEq { - const MAX: Self; - const MIN: Self; + fn min_value() -> Self; + fn max_value() -> Self; + + fn __expand_min_value(context: &mut CubeContext) -> ::ExpandType { + let elem = Self::as_elem(context); + let var = elem.min_variable(); + let expand = ExpandElement::Plain(var); + expand.into() + } + + fn __expand_max_value(context: &mut CubeContext) -> ::ExpandType { + let elem = Self::as_elem(context); + let var = elem.max_variable(); + let expand = ExpandElement::Plain(var); + expand.into() + } /// Create a new constant numeric. 
/// @@ -68,10 +82,10 @@ pub trait Numeric: } fn __expand_from_int( - _context: &mut CubeContext, + context: &mut CubeContext, val: ExpandElementTyped, ) -> ::ExpandType { - let elem = Self::as_elem(); + let elem = Self::as_elem(context); let var: Variable = elem.constant_from_i64(val.constant().unwrap().as_i64()); ExpandElement::Plain(var).into() @@ -81,11 +95,11 @@ pub trait Numeric: context: &mut CubeContext, vec: [u32; D], ) -> ::ExpandType { - let new_var = context.create_local_binding(Item::vectorized( - Self::as_elem(), + let new_var = context.create_local(Item::vectorized( + Self::as_elem(context), NonZero::new(vec.len() as u8), )); - let elem = Self::as_elem(); + let elem = Self::as_elem(context); for (i, element) in vec.iter().enumerate() { let var: Variable = elem.constant_from_i64(*element as i64); @@ -94,7 +108,7 @@ pub trait Numeric: index_assign::expand::( context, new_var.clone().into(), - ExpandElementTyped::from_lit(i), + ExpandElementTyped::from_lit(context, i), expand.into(), ); } diff --git a/crates/cubecl-core/src/frontend/element/uint.rs b/crates/cubecl-core/src/frontend/element/uint.rs index c089f04c9..f5f669410 100644 --- a/crates/cubecl-core/src/frontend/element/uint.rs +++ b/crates/cubecl-core/src/frontend/element/uint.rs @@ -24,8 +24,8 @@ macro_rules! declare_uint { } impl CubePrimitive for $primitive { - fn as_elem() -> Elem { - Elem::UInt(UIntKind::$kind) + fn as_elem_native() -> Option { + Some(Elem::UInt(UIntKind::$kind)) } } @@ -46,13 +46,17 @@ macro_rules! declare_uint { _: &Self::CompilationArg, builder: &mut KernelBuilder, ) -> ExpandElementTyped { - builder.scalar($primitive::as_elem()).into() + builder.scalar($primitive::as_elem(&builder.context)).into() } } impl Numeric for $primitive { - const MAX: Self = $primitive::MAX; - const MIN: Self = $primitive::MIN; + fn min_value() -> Self { + $primitive::MIN + } + fn max_value() -> Self { + $primitive::MAX + } } impl Int for $primitive { diff --git a/crates/cubecl-core/src/frontend/operation/assignation.rs b/crates/cubecl-core/src/frontend/operation/assignation.rs index 8ca4b3a16..51b273a80 100644 --- a/crates/cubecl-core/src/frontend/operation/assignation.rs +++ b/crates/cubecl-core/src/frontend/operation/assignation.rs @@ -143,7 +143,7 @@ pub mod index { let array: ExpandElement = array.into(); let var: Variable = *array; let var = match var.kind { - VariableKind::Local { .. } | VariableKind::LocalBinding { .. } => { + VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. 
} => { binary_expand_no_vec(context, array, index, Operator::Index) } _ => binary_expand(context, array, index, Operator::Index), diff --git a/crates/cubecl-core/src/frontend/operation/base.rs b/crates/cubecl-core/src/frontend/operation/base.rs index 59ca59772..2c1cf308b 100644 --- a/crates/cubecl-core/src/frontend/operation/base.rs +++ b/crates/cubecl-core/src/frontend/operation/base.rs @@ -2,7 +2,7 @@ use std::num::NonZeroU8; use crate::ir::{ BinaryOperator, Elem, Instruction, Item, Operation, Operator, UnaryOperator, Variable, - Vectorization, + VariableKind, Vectorization, }; use crate::prelude::{CubeType, ExpandElementTyped}; use crate::{ @@ -29,7 +29,7 @@ where let item = Item::vectorized(item_lhs.elem, vectorization); - let output = context.create_local_binding(item); + let output = context.create_local(item); let out = *output; let op = func(BinaryOperator { lhs, rhs }); @@ -52,7 +52,7 @@ where let lhs_var = lhs.consume(); let rhs_var = rhs.consume(); - let out = context.create_local_binding(out_item); + let out = context.create_local(out_item); let out_var = *out; @@ -82,7 +82,7 @@ where let item = Item::new(item_lhs.elem); - let output = context.create_local_binding(item); + let output = context.create_local(item); let out = *output; let op = func(BinaryOperator { lhs, rhs }); @@ -112,7 +112,7 @@ where vectorization: item.vectorization, }; - let out = context.create_local_binding(out_item); + let out = context.create_local(out_item); let out_var = *out; let op = func(BinaryOperator { lhs, rhs }); @@ -150,7 +150,7 @@ where let input = input.consume(); let item = input.item; - let out = context.create_local_binding(item); + let out = context.create_local(item); let out_var = *out; let op = func(UnaryOperator { input }); @@ -170,7 +170,7 @@ where F: Fn(UnaryOperator) -> Operator, { let input = input.consume(); - let output = context.create_local_binding(out_item); + let output = context.create_local(out_item); let out = *output; let op = func(UnaryOperator { input }); @@ -191,7 +191,7 @@ where let input_var: Variable = *input; let item = input.item; - let out = context.create_local_variable(item); + let out = context.create_local_mut(item); let out_var = *out; let op = func(input_var); @@ -239,7 +239,12 @@ pub fn array_assign_binary_op_expand< let index: ExpandElement = index.into(); let value: ExpandElement = value.into(); - let array_value = context.create_local_binding(array.item); + let array_item = match array.kind { + // In that case, the array is a line. + VariableKind::LocalMut { .. 
} => array.item.vectorize(None), + _ => array.item, + }; + let array_value = context.create_local(array_item); let read = Instruction::new( Operator::Index(BinaryOperator { @@ -249,7 +254,7 @@ pub fn array_assign_binary_op_expand< *array_value, ); let array_value = array_value.consume(); - let op_out = context.create_local_binding(array.item); + let op_out = context.create_local(array.item); let calculate = Instruction::new( func(BinaryOperator { lhs: array_value, @@ -262,7 +267,6 @@ pub fn array_assign_binary_op_expand< lhs: *index, rhs: op_out.consume(), }); - context.register(read); context.register(calculate); context.register(Instruction::new(write, *array)); diff --git a/crates/cubecl-core/src/frontend/operation/branch.rs b/crates/cubecl-core/src/frontend/operation/branch.rs index a0c836a35..00a110010 100644 --- a/crates/cubecl-core/src/frontend/operation/branch.rs +++ b/crates/cubecl-core/src/frontend/operation/branch.rs @@ -50,7 +50,7 @@ pub mod select { let vf = Ord::max(vf, then.vectorization_factor()); let vf = Ord::max(vf, or_else.vectorization_factor()); - let output = context.create_local_binding(then.item.vectorize(NonZero::new(vf))); + let output = context.create_local(then.item.vectorize(NonZero::new(vf))); let out = *output; let select = Operator::Select(Select { diff --git a/crates/cubecl-core/src/frontend/operation/fma.rs b/crates/cubecl-core/src/frontend/operation/fma.rs index d90f879e0..874070ea0 100644 --- a/crates/cubecl-core/src/frontend/operation/fma.rs +++ b/crates/cubecl-core/src/frontend/operation/fma.rs @@ -18,7 +18,7 @@ pub fn fma_expand( b: ExpandElement, c: ExpandElement, ) -> ExpandElement { - let output = context.create_local_binding(a.item); + let output = context.create_local(a.item); let out = *output; let a = *a; diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index 6412988c1..2d33fc0b4 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -71,6 +71,26 @@ macro_rules! impl_unary_func_fixed_out_vectorization { } } +macro_rules! impl_unary_func_fixed_out_ty { + ($trait_name:ident, $method_name:ident, $method_name_expand:ident, $out_ty: ty, $operator:expr, $($type:ty),*) => { + pub trait $trait_name: CubePrimitive + Sized { + #[allow(unused_variables)] + fn $method_name(x: Self) -> $out_ty { + unexpanded!() + } + + fn $method_name_expand(context: &mut CubeContext, x: Self::ExpandType) -> ExpandElementTyped<$out_ty> { + let expand_element: ExpandElement = x.into(); + let mut item = expand_element.item; + item.elem = <$out_ty as CubePrimitive>::as_elem(context); + unary_expand_fixed_output(context, expand_element, item, $operator).into() + } + } + + $(impl $trait_name for $type {})* + } +} + impl_unary_func!( Abs, abs, @@ -260,3 +280,32 @@ impl_unary_func!( f32, f64 ); +impl_unary_func_fixed_out_ty!( + CountOnes, + count_ones, + __expand_count_ones, + u32, + Operator::CountOnes, + u8, + i8, + u16, + i16, + u32, + i32, + u64, + i64 +); +impl_unary_func!( + ReverseBits, + reverse_bits, + __expand_reverse_bits, + Operator::ReverseBits, + u8, + i8, + u16, + i16, + u32, + i32, + u64, + i64 +); diff --git a/crates/cubecl-core/src/frontend/plane.rs b/crates/cubecl-core/src/frontend/plane.rs index 2ad489556..a01590889 100644 --- a/crates/cubecl-core/src/frontend/plane.rs +++ b/crates/cubecl-core/src/frontend/plane.rs @@ -17,7 +17,7 @@ pub mod plane_elect { /// Expand method of [plane_elect()]. 
    pub fn expand(context: &mut CubeContext) -> ExpandElementTyped {
-        let output = context.create_local_binding(Item::new(Elem::Bool));
+        let output = context.create_local(Item::new(Elem::Bool));
         let out = *output;

         context.register(Instruction::new(Plane::Elect, out));
@@ -44,7 +44,7 @@ pub mod plane_broadcast {
         value: ExpandElementTyped,
         id: ExpandElementTyped,
     ) -> ExpandElementTyped {
-        let output = context.create_local_binding(value.expand.item);
+        let output = context.create_local(value.expand.item);
         let out = *output;
         let lhs = *value.expand;
         let rhs = *id.expand;
@@ -74,7 +74,7 @@ pub mod plane_sum {
         elem: ExpandElementTyped,
     ) -> ExpandElementTyped {
         let elem: ExpandElement = elem.into();
-        let output = context.create_local_binding(elem.item);
+        let output = context.create_local(elem.item);
         let out = *output;
         let input = *elem;
@@ -100,7 +100,7 @@ pub mod plane_prod {
         elem: ExpandElementTyped,
     ) -> ExpandElementTyped {
         let elem: ExpandElement = elem.into();
-        let output = context.create_local_binding(elem.item);
+        let output = context.create_local(elem.item);
         let out = *output;
         let input = *elem;
@@ -126,7 +126,7 @@ pub mod plane_max {
         elem: ExpandElementTyped,
     ) -> ExpandElementTyped {
         let elem: ExpandElement = elem.into();
-        let output = context.create_local_binding(elem.item);
+        let output = context.create_local(elem.item);
         let out = *output;
         let input = *elem;
@@ -152,7 +152,7 @@ pub mod plane_min {
         elem: ExpandElementTyped,
     ) -> ExpandElementTyped {
         let elem: ExpandElement = elem.into();
-        let output = context.create_local_binding(elem.item);
+        let output = context.create_local(elem.item);
         let out = *output;
         let input = *elem;
@@ -179,7 +179,7 @@ pub mod plane_all {
         elem: ExpandElementTyped,
     ) -> ExpandElementTyped {
         let elem: ExpandElement = elem.into();
-        let output = context.create_local_binding(elem.item);
+        let output = context.create_local(elem.item);
         let out = *output;
         let input = *elem;
@@ -206,7 +206,7 @@ pub mod plane_any {
         elem: ExpandElementTyped,
     ) -> ExpandElementTyped {
         let elem: ExpandElement = elem.into();
-        let output = context.create_local_binding(elem.item);
+        let output = context.create_local(elem.item);
         let out = *output;
         let input = *elem;
diff --git a/crates/cubecl-core/src/ir/allocator.rs b/crates/cubecl-core/src/ir/allocator.rs
new file mode 100644
index 000000000..d6f4edfd4
--- /dev/null
+++ b/crates/cubecl-core/src/ir/allocator.rs
@@ -0,0 +1,89 @@
+use std::{collections::HashMap, rc::Rc};
+
+use crate::prelude::ExpandElement;
+
+use super::{Item, Scope, Variable};
+
+/// An allocator for local variables of a kernel.
+///
+/// A local variable is unique to a unit. That is, each unit has its own copy of a local variable.
+/// There are three types of local variables, based on their capabilities.
+/// - An immutable local variable is obtained by calling [Allocator::create_local].
+/// - A mutable local variable is obtained by calling [Allocator::create_local_mut]. The allocator will reuse
+///   previously defined mutable variables if possible.
+/// - A restricted mutable local variable is obtained by calling [Allocator::create_local_restricted]. This is a
+///   mutable variable that cannot be reused. It is mostly used for loop indices.
+///
+/// # Performance tips
+///
+/// In order, prefer immutable local variables, then mutable, then restricted.
+///
+/// To enable many compiler optimizations, it is preferred to use the [static single-assignment] strategy for immutable variables.
+/// That is, each variable must be declared and assigned exactly once.
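+///
+/// # Example
+///
+/// A minimal sketch of the intended usage. The setup is hypothetical (it
+/// assumes a `Scope::root()` constructor and a `u32` item), so the snippet is
+/// marked `ignore`:
+///
+/// ```ignore
+/// let mut allocator = Allocator::new();
+/// let mut scope = Scope::root();
+/// let item = Item::new(Elem::UInt(UIntKind::U32));
+/// // Immutable value: declared and assigned exactly once (SSA-friendly).
+/// let a = allocator.create_local(&mut scope, item);
+/// // Mutable value: may be transparently reused from the pool once unused.
+/// let b = allocator.create_local_mut(&mut scope, item);
+/// // Restricted mutable value: never reused (e.g. a loop index).
+/// let i = allocator.create_local_restricted(&mut scope, item);
+/// ```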
+///
+/// [static single-assignment]: https://en.wikipedia.org/wiki/Static_single-assignment_form
+#[derive(Clone)]
+pub struct Allocator {
+    local_mut_pool: HashMap<Item, Vec<ExpandElement>>,
+}
+
+impl Default for Allocator {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl Allocator {
+    /// Create a new allocator.
+    pub fn new() -> Self {
+        Self {
+            local_mut_pool: HashMap::new(),
+        }
+    }
+
+    /// Create a new immutable local variable of the type specified by `item` for the given `scope`.
+    pub fn create_local(&self, scope: &mut Scope, item: Item) -> ExpandElement {
+        ExpandElement::Plain(scope.create_local(item))
+    }
+
+    /// Create a new mutable local variable of the type specified by `item` for the given `scope`.
+    /// Try to reuse a previously defined but unused mutable variable in the current scope if possible.
+    /// Otherwise, this defines a new variable.
+    pub fn create_local_mut(&mut self, scope: &mut Scope, item: Item) -> ExpandElement {
+        if item.elem.is_atomic() {
+            ExpandElement::Plain(scope.create_local_restricted(item))
+        } else {
+            self.reuse_local_mut(item)
+                .unwrap_or_else(|| ExpandElement::Managed(self.add_local_mut(scope, item)))
+        }
+    }
+
+    /// Create a new mutable restricted local variable of the type specified by `item` for the given `scope`.
+    pub fn create_local_restricted(&self, scope: &mut Scope, item: Item) -> ExpandElement {
+        ExpandElement::Plain(scope.create_local_restricted(item))
+    }
+
+    // Try to return a reusable mutable variable for the given `item`, or `None` otherwise.
+    fn reuse_local_mut(&self, item: Item) -> Option<ExpandElement> {
+        // Among the candidates, take a variable if it's only referenced by the pool.
+        // Arbitrarily takes the first it finds in reversed order.
+        self.local_mut_pool.get(&item).and_then(|vars| {
+            vars.iter()
+                .rev()
+                .find(|var| matches!(var, ExpandElement::Managed(v) if Rc::strong_count(v) == 1))
+                .cloned()
+        })
+    }
+
+    /// Add a new variable to the pool with the type specified by `item` for the given `scope`.
+    fn add_local_mut(&mut self, scope: &mut Scope, item: Item) -> Rc<Variable> {
+        let var = Rc::new(scope.create_local_mut(item));
+        let expand = ExpandElement::Managed(var.clone());
+        if let Some(variables) = self.local_mut_pool.get_mut(&item) {
+            variables.push(expand);
+        } else {
+            self.local_mut_pool.insert(var.item, vec![expand]);
+        }
+        var
+    }
+}
diff --git a/crates/cubecl-core/src/ir/branch.rs b/crates/cubecl-core/src/ir/branch.rs
index 1470e2c87..85fd6d120 100644
--- a/crates/cubecl-core/src/ir/branch.rs
+++ b/crates/cubecl-core/src/ir/branch.rs
@@ -1,8 +1,6 @@
 use std::fmt::Display;

-use crate::prelude::CubePrimitive;
-
-use super::{Item, Scope, Variable};
+use super::{Elem, Item, Scope, UIntKind, Variable};

 use serde::{Deserialize, Serialize};

 /// All branching types.
@@ -151,8 +149,8 @@ impl RangeLoop { func: F, ) { let mut scope = parent_scope.child(); - let index_ty = Item::new(u32::as_elem()); - let i = scope.create_local_undeclared(index_ty); + let index_ty = Item::new(Elem::UInt(UIntKind::U32)); + let i = scope.create_local_restricted(index_ty); func(i, &mut scope); diff --git a/crates/cubecl-core/src/ir/kernel.rs b/crates/cubecl-core/src/ir/kernel.rs index f62c5d045..19d97029c 100644 --- a/crates/cubecl-core/src/ir/kernel.rs +++ b/crates/cubecl-core/src/ir/kernel.rs @@ -186,6 +186,94 @@ impl Elem { Elem::Int(_) | Elem::AtomicInt(_) | Elem::UInt(_) | Elem::AtomicUInt(_) ) } + + pub fn max_variable(&self) -> Variable { + let value = match self { + Elem::Float(kind) => match kind { + FloatKind::F16 => { + ConstantScalarValue::Float(half::f16::MAX.to_f64(), FloatKind::F16) + } + FloatKind::BF16 => { + ConstantScalarValue::Float(half::bf16::MAX.to_f64(), FloatKind::BF16) + } + FloatKind::Flex32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::Flex32), + FloatKind::F32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::F32), + FloatKind::TF32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::TF32), + FloatKind::F64 => ConstantScalarValue::Float(f64::MAX, FloatKind::F64), + }, + Elem::Int(kind) => match kind { + IntKind::I8 => ConstantScalarValue::Int(i8::MAX.into(), IntKind::I8), + IntKind::I16 => ConstantScalarValue::Int(i16::MAX.into(), IntKind::I16), + IntKind::I32 => ConstantScalarValue::Int(i32::MAX.into(), IntKind::I32), + IntKind::I64 => ConstantScalarValue::Int(i64::MAX, IntKind::I64), + }, + Elem::AtomicInt(kind) => match kind { + IntKind::I8 => ConstantScalarValue::Int(i8::MAX.into(), IntKind::I8), + IntKind::I16 => ConstantScalarValue::Int(i16::MAX.into(), IntKind::I16), + IntKind::I32 => ConstantScalarValue::Int(i32::MAX.into(), IntKind::I32), + IntKind::I64 => ConstantScalarValue::Int(i64::MAX, IntKind::I64), + }, + Elem::UInt(kind) => match kind { + UIntKind::U8 => ConstantScalarValue::UInt(u8::MAX.into(), UIntKind::U8), + UIntKind::U16 => ConstantScalarValue::UInt(u16::MAX.into(), UIntKind::U16), + UIntKind::U32 => ConstantScalarValue::UInt(u32::MAX.into(), UIntKind::U32), + UIntKind::U64 => ConstantScalarValue::UInt(u64::MAX, UIntKind::U64), + }, + Elem::AtomicUInt(kind) => match kind { + UIntKind::U8 => ConstantScalarValue::UInt(u8::MAX.into(), UIntKind::U8), + UIntKind::U16 => ConstantScalarValue::UInt(u16::MAX.into(), UIntKind::U16), + UIntKind::U32 => ConstantScalarValue::UInt(u32::MAX.into(), UIntKind::U32), + UIntKind::U64 => ConstantScalarValue::UInt(u64::MAX, UIntKind::U64), + }, + Elem::Bool => ConstantScalarValue::Bool(true), + }; + + Variable::new(VariableKind::ConstantScalar(value), Item::new(*self)) + } + + pub fn min_variable(&self) -> Variable { + let value = match self { + Elem::Float(kind) => match kind { + FloatKind::F16 => { + ConstantScalarValue::Float(half::f16::MIN.to_f64(), FloatKind::F16) + } + FloatKind::BF16 => { + ConstantScalarValue::Float(half::bf16::MIN.to_f64(), FloatKind::BF16) + } + FloatKind::Flex32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::Flex32), + FloatKind::F32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::F32), + FloatKind::TF32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::TF32), + FloatKind::F64 => ConstantScalarValue::Float(f64::MIN, FloatKind::F64), + }, + Elem::Int(kind) => match kind { + IntKind::I8 => ConstantScalarValue::Int(i8::MIN.into(), IntKind::I8), + IntKind::I16 => ConstantScalarValue::Int(i16::MIN.into(), 
IntKind::I16),
+            IntKind::I32 => ConstantScalarValue::Int(i32::MIN.into(), IntKind::I32),
+            IntKind::I64 => ConstantScalarValue::Int(i64::MIN, IntKind::I64),
+            },
+            Elem::AtomicInt(kind) => match kind {
+                IntKind::I8 => ConstantScalarValue::Int(i8::MIN.into(), IntKind::I8),
+                IntKind::I16 => ConstantScalarValue::Int(i16::MIN.into(), IntKind::I16),
+                IntKind::I32 => ConstantScalarValue::Int(i32::MIN.into(), IntKind::I32),
+                IntKind::I64 => ConstantScalarValue::Int(i64::MIN, IntKind::I64),
+            },
+            Elem::UInt(kind) => match kind {
+                UIntKind::U8 => ConstantScalarValue::UInt(u8::MIN.into(), UIntKind::U8),
+                UIntKind::U16 => ConstantScalarValue::UInt(u16::MIN.into(), UIntKind::U16),
+                UIntKind::U32 => ConstantScalarValue::UInt(u32::MIN.into(), UIntKind::U32),
+                UIntKind::U64 => ConstantScalarValue::UInt(u64::MIN, UIntKind::U64),
+            },
+            Elem::AtomicUInt(kind) => match kind {
+                UIntKind::U8 => ConstantScalarValue::UInt(u8::MIN.into(), UIntKind::U8),
+                UIntKind::U16 => ConstantScalarValue::UInt(u16::MIN.into(), UIntKind::U16),
+                UIntKind::U32 => ConstantScalarValue::UInt(u32::MIN.into(), UIntKind::U32),
+                UIntKind::U64 => ConstantScalarValue::UInt(u64::MIN, UIntKind::U64),
+            },
+            Elem::Bool => ConstantScalarValue::Bool(false),
+        };
+
+        Variable::new(VariableKind::ConstantScalar(value), Item::new(*self))
+    }
 }

 impl From<Elem> for Item {
diff --git a/crates/cubecl-core/src/ir/local_allocator.rs b/crates/cubecl-core/src/ir/local_allocator.rs
deleted file mode 100644
index 5e25febff..000000000
--- a/crates/cubecl-core/src/ir/local_allocator.rs
+++ /dev/null
@@ -1,143 +0,0 @@
-use std::{
-    cell::RefCell,
-    collections::HashMap,
-    rc::Rc,
-    sync::atomic::{AtomicU16, Ordering},
-};
-
-use crate::prelude::ExpandElement;
-
-use super::{Item, Scope, Variable, VariableKind};
-
-type ScopeRef = Rc<RefCell<Scope>>;
-
-/// Defines a local variable allocation strategy (i.e. reused mutable variables, SSA)
-pub trait LocalAllocator {
-    /// Creates a local variable that can be (re)assigned
-    fn create_local_variable(&self, root: ScopeRef, scope: ScopeRef, item: Item) -> ExpandElement;
-    /// Creates an immutable local binding for intermediates
-    fn create_local_binding(&self, root: ScopeRef, scope: ScopeRef, item: Item) -> ExpandElement;
-    /// Creates an undeclared local binding that must not be reused regardless of allocator
-    fn create_local_undeclared(&self, root: ScopeRef, scope: ScopeRef, item: Item)
-        -> ExpandElement;
-}
-
-#[derive(Default, Clone)]
-pub struct VariablePool {
-    map: Rc<RefCell<HashMap<Item, Vec<ExpandElement>>>>,
-}
-
-impl VariablePool {
-    /// Returns an old, not used anymore variable, if there exists one.
-    pub fn reuse(&self, item: Item) -> Option<ExpandElement> {
-        let map = self.map.borrow();
-
-        // Filter for candidate variables of the same Item
-        let variables = map.get(&item)?;
-
-        // Among the candidates, take a variable if it's only referenced by the map
-        // Arbitrarily takes the first it finds in reverse order.
- for variable in variables.iter().rev() { - match variable { - ExpandElement::Managed(var) => { - if Rc::strong_count(var) == 1 { - return Some(variable.clone()); - } - } - ExpandElement::Plain(_) => (), - } - } - - // If no candidate was found, a new var will be needed - None - } - - /// Insert a new variable in the map, which is classified by Item - pub fn insert(&self, var: ExpandElement) { - let mut map = self.map.borrow_mut(); - let item = var.item; - - if let Some(variables) = map.get_mut(&item) { - variables.push(var.clone()); - } else { - map.insert(var.item, vec![var.clone()]); - } - } -} - -/// Reusing allocator, assigns all intermediates to a set of mutable variables that get continuously -/// reused. -#[derive(Default)] -pub struct ReusingAllocator { - pool: VariablePool, -} - -impl LocalAllocator for ReusingAllocator { - fn create_local_variable(&self, root: ScopeRef, scope: ScopeRef, item: Item) -> ExpandElement { - if item.elem.is_atomic() { - let new = scope.borrow_mut().create_local_undeclared(item); - return ExpandElement::Plain(new); - } - - // Reuse an old variable if possible - if let Some(var) = self.pool.reuse(item) { - return var; - } - - // Create a new variable at the root scope - // Insert it in the variable pool for potential reuse - let new = ExpandElement::Managed(Rc::new(root.borrow_mut().create_local(item))); - self.pool.insert(new.clone()); - - new - } - - fn create_local_binding(&self, root: ScopeRef, scope: ScopeRef, item: Item) -> ExpandElement { - self.create_local_variable(root, scope, item) - } - - fn create_local_undeclared( - &self, - _root: ScopeRef, - scope: ScopeRef, - item: Item, - ) -> ExpandElement { - ExpandElement::Plain(scope.borrow_mut().create_local_undeclared(item)) - } -} - -/// Hybrid allocator. Creates immutable local bindings for intermediates, and falls back to -/// [`ReusingAllocator`] for mutable variables. 
-#[derive(Default)] -pub struct HybridAllocator { - variable_allocator: ReusingAllocator, - ssa_index: AtomicU16, -} - -impl LocalAllocator for HybridAllocator { - fn create_local_variable(&self, root: ScopeRef, scope: ScopeRef, item: Item) -> ExpandElement { - self.ssa_index.fetch_add(1, Ordering::AcqRel); - self.variable_allocator - .create_local_variable(root, scope, item) - } - - fn create_local_binding(&self, _root: ScopeRef, scope: ScopeRef, item: Item) -> ExpandElement { - let id = self.ssa_index.fetch_add(1, Ordering::AcqRel); - let depth = scope.borrow().depth; - ExpandElement::Plain(Variable::new( - VariableKind::LocalBinding { id, depth }, - item, - )) - } - - fn create_local_undeclared( - &self, - _root: ScopeRef, - scope: ScopeRef, - item: Item, - ) -> ExpandElement { - let id = self.ssa_index.fetch_add(1, Ordering::AcqRel); - let depth = scope.borrow().depth; - ExpandElement::Plain(Variable::new(VariableKind::Local { id, depth }, item)) - } -} diff --git a/crates/cubecl-core/src/ir/mod.rs b/crates/cubecl-core/src/ir/mod.rs index 820d80586..281488139 100644 --- a/crates/cubecl-core/src/ir/mod.rs +++ b/crates/cubecl-core/src/ir/mod.rs @@ -1,9 +1,9 @@ +mod allocator; mod branch; mod cmma; mod comment; mod debug; mod kernel; -mod local_allocator; mod macros; mod operation; mod plane; @@ -13,12 +13,12 @@ mod synchronization; mod variable; pub use super::frontend::AtomicOp; +pub use allocator::*; pub use branch::*; pub use cmma::*; pub use comment::*; pub use debug::*; pub use kernel::*; -pub use local_allocator::*; pub use operation::*; pub use plane::*; pub use scope::*; diff --git a/crates/cubecl-core/src/ir/operation.rs b/crates/cubecl-core/src/ir/operation.rs index 80878279a..04b886cee 100644 --- a/crates/cubecl-core/src/ir/operation.rs +++ b/crates/cubecl-core/src/ir/operation.rs @@ -169,6 +169,8 @@ pub enum Operator { BitwiseXor(BinaryOperator), ShiftLeft(BinaryOperator), ShiftRight(BinaryOperator), + CountOnes(UnaryOperator), + ReverseBits(UnaryOperator), Remainder(BinaryOperator), Bitcast(UnaryOperator), Magnitude(UnaryOperator), @@ -236,6 +238,8 @@ impl Display for Operator { Operator::BitwiseAnd(op) => write!(f, "{} & {}", op.lhs, op.rhs), Operator::BitwiseOr(op) => write!(f, "{} | {}", op.lhs, op.rhs), Operator::BitwiseXor(op) => write!(f, "{} ^ {}", op.lhs, op.rhs), + Operator::CountOnes(op) => write!(f, "{}.count_bits()", op.input), + Operator::ReverseBits(op) => write!(f, "{}.reverse_bits()", op.input), Operator::ShiftLeft(op) => write!(f, "{} << {}", op.lhs, op.rhs), Operator::ShiftRight(op) => write!(f, "{} >> {}", op.lhs, op.rhs), Operator::Remainder(op) => write!(f, "{} rem {}", op.lhs, op.rhs), @@ -369,25 +373,6 @@ pub struct FmaOperator { pub c: Variable, } -#[allow(missing_docs)] -pub fn expand_checked_index(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) { - let array_len = scope.create_local(Item::new(Elem::UInt(UIntKind::U32))); - let inside_bound = scope.create_local(Item::new(Elem::Bool)); - let item = scope.create_local(out.item); - let zero: Variable = 0u32.into(); - - if lhs.has_buffer_length() { - cpa!(scope, array_len = buffer_len(lhs)); - } else { - cpa!(scope, array_len = len(lhs)); - } - - cpa!(scope, inside_bound = rhs < array_len); - - cpa!(scope, item = unchecked(lhs[rhs])); - cpa!(scope, out = select(inside_bound, item, zero)); -} - #[allow(missing_docs)] pub fn expand_checked_index_assign(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) { let array_len = scope.create_local(Item::new(Elem::UInt(UIntKind::U32))); 
diff --git a/crates/cubecl-core/src/ir/processing.rs b/crates/cubecl-core/src/ir/processing.rs index 4a5f5bbaf..85e081f60 100644 --- a/crates/cubecl-core/src/ir/processing.rs +++ b/crates/cubecl-core/src/ir/processing.rs @@ -1,7 +1,8 @@ -use crate::prelude::{AtomicOp, CubePrimitive}; +use crate::prelude::AtomicOp; use super::{ - Branch, CoopMma, Elem, Instruction, Metadata, Operation, Operator, Variable, VariableKind, + Branch, CoopMma, Elem, Instruction, Metadata, Operation, Operator, UIntKind, Variable, + VariableKind, }; /// Information necessary when compiling a scope. @@ -133,23 +134,23 @@ impl ScopeProcessing { } Operator::Slice(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); - sanitize_constant_scalar_ref_elem(&mut op.start, u32::as_elem()); - sanitize_constant_scalar_ref_elem(&mut op.end, u32::as_elem()); + sanitize_constant_scalar_ref_elem(&mut op.start, Elem::UInt(UIntKind::U32)); + sanitize_constant_scalar_ref_elem(&mut op.end, Elem::UInt(UIntKind::U32)); } Operator::Index(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap()); - sanitize_constant_scalar_ref_elem(&mut op.rhs, u32::as_elem()); + sanitize_constant_scalar_ref_elem(&mut op.rhs, Elem::UInt(UIntKind::U32)); } Operator::UncheckedIndex(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap()); - sanitize_constant_scalar_ref_elem(&mut op.rhs, u32::as_elem()); + sanitize_constant_scalar_ref_elem(&mut op.rhs, Elem::UInt(UIntKind::U32)); } Operator::IndexAssign(op) => { - sanitize_constant_scalar_ref_elem(&mut op.lhs, u32::as_elem()); + sanitize_constant_scalar_ref_elem(&mut op.lhs, Elem::UInt(UIntKind::U32)); sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap()); } Operator::UncheckedIndexAssign(op) => { - sanitize_constant_scalar_ref_elem(&mut op.lhs, u32::as_elem()); + sanitize_constant_scalar_ref_elem(&mut op.lhs, Elem::UInt(UIntKind::U32)); sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap()); } Operator::And(op) => { @@ -186,6 +187,12 @@ impl ScopeProcessing { sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap()); sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap()); } + Operator::CountOnes(_) => { + // Nothing to do + } + Operator::ReverseBits(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } Operator::ShiftLeft(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap()); sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap()); @@ -214,13 +221,25 @@ impl ScopeProcessing { } Operator::CopyMemory(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); - sanitize_constant_scalar_ref_elem(&mut op.in_index, u32::as_elem()); - sanitize_constant_scalar_ref_elem(&mut op.out_index, u32::as_elem()); + sanitize_constant_scalar_ref_elem( + &mut op.in_index, + Elem::UInt(UIntKind::U32), + ); + sanitize_constant_scalar_ref_elem( + &mut op.out_index, + Elem::UInt(UIntKind::U32), + ); } Operator::CopyMemoryBulk(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); - sanitize_constant_scalar_ref_elem(&mut op.in_index, u32::as_elem()); - sanitize_constant_scalar_ref_elem(&mut op.out_index, u32::as_elem()); + sanitize_constant_scalar_ref_elem( + &mut op.in_index, + Elem::UInt(UIntKind::U32), + ); + sanitize_constant_scalar_ref_elem( + &mut op.out_index, + Elem::UInt(UIntKind::U32), + ); } Operator::Select(op) => { sanitize_constant_scalar_ref_elem(&mut op.cond, Elem::Bool); @@ -263,10 +282,10 @@ impl ScopeProcessing { 
            },
            Operation::Metadata(op) => match op {
                Metadata::Stride { dim, .. } => {
-                    sanitize_constant_scalar_ref_elem(dim, u32::as_elem());
+                    sanitize_constant_scalar_ref_elem(dim, Elem::UInt(UIntKind::U32));
                }
                Metadata::Shape { dim, .. } => {
-                    sanitize_constant_scalar_ref_elem(dim, u32::as_elem());
+                    sanitize_constant_scalar_ref_elem(dim, Elem::UInt(UIntKind::U32));
                }
                Metadata::Length { .. }
                | Metadata::BufferLength { .. }
@@ -285,7 +304,7 @@ impl ScopeProcessing {
                sanitize_constant_scalar_ref_var(&mut op.end, &op.start);
                sanitize_constant_scalar_ref_var(&mut op.i, &op.start);
                if let Some(step) = &mut op.step {
-                    sanitize_constant_scalar_ref_elem(step, u32::as_elem());
+                    sanitize_constant_scalar_ref_elem(step, Elem::UInt(UIntKind::U32));
                }
            }
            _ => {
@@ -304,13 +323,13 @@ impl ScopeProcessing {
                }
                CoopMma::Load { value, stride, .. } => {
                    sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
-                    sanitize_constant_scalar_ref_elem(stride, u32::as_elem());
+                    sanitize_constant_scalar_ref_elem(stride, Elem::UInt(UIntKind::U32));
                }
                CoopMma::Execute { .. } => {
                    // Nothing to do.
                }
                CoopMma::Store { stride, .. } => {
-                    sanitize_constant_scalar_ref_elem(stride, u32::as_elem());
+                    sanitize_constant_scalar_ref_elem(stride, Elem::UInt(UIntKind::U32));
                }
                CoopMma::Cast { .. } => {
                    // Nothing to do.
diff --git a/crates/cubecl-core/src/ir/scope.rs b/crates/cubecl-core/src/ir/scope.rs
index cf9b5d5ca..9ed3923bc 100644
--- a/crates/cubecl-core/src/ir/scope.rs
+++ b/crates/cubecl-core/src/ir/scope.rs
@@ -1,8 +1,8 @@
-use crate::{ir::ConstantScalarValue, prelude::CubePrimitive};
+use crate::ir::ConstantScalarValue;

 use super::{
-    cpa, processing::ScopeProcessing, Elem, Instruction, Item, Matrix, Operation, Variable,
-    VariableKind,
+    cpa, processing::ScopeProcessing, Elem, Instruction, Item, Matrix, Operation, UIntKind,
+    Variable, VariableKind,
 };
 use serde::{Deserialize, Serialize};

@@ -67,7 +67,7 @@ impl Scope {
     /// Create a variable initialized at zero.
     pub fn zero<I: Into<Item>>(&mut self, item: I) -> Variable {
-        let local = self.create_local(item);
+        let local = self.create_local(item.into());
         let zero: Variable = 0u32.into();
         cpa!(self, local = zero);
         local
@@ -123,12 +123,12 @@ impl Scope {
         variable
     }

-    /// Create a local variable of the given [item type](Item).
-    pub fn create_local<I: Into<Item>>(&mut self, item: I) -> Variable {
+    /// Create a mutable variable of the given [item type](Item).
+    pub fn create_local_mut<I: Into<Item>>(&mut self, item: I) -> Variable {
         let item = item.into();
         let index = self.new_local_index();
         let local = Variable::new(
-            VariableKind::Local {
+            VariableKind::LocalMut {
                 id: index,
                 depth: self.depth,
             },
@@ -138,12 +138,11 @@ impl Scope {
         local
     }

-    /// Create a new undeclared local, but doesn't perform the declaration.
-    ///
+    /// Create a new restricted variable. The variable cannot be reused.
     /// Useful for _for loops_ and other algorithms that require the control over initialization.
-    pub fn create_local_undeclared(&mut self, item: Item) -> Variable {
+    pub fn create_local_restricted(&mut self, item: Item) -> Variable {
         let index = self.new_local_index();
-        let local = VariableKind::Local {
+        let local = VariableKind::LocalMut {
             id: index,
             depth: self.depth,
         };
@@ -151,12 +150,10 @@ impl Scope {
         Variable::new(local, item)
     }

-    /// Create a new undeclared local binding, but doesn't perform the declaration.
-    ///
-    /// Useful for temporaries and other algorithms that require the control over initialization.
-    pub fn create_local_binding(&mut self, item: Item) -> Variable {
+    /// Create a new immutable variable.
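+    /// The variable is expected to be assigned exactly once, which makes it a
+    /// good fit for intermediate values and enables SSA-style optimizations
+    /// (see the preference order documented on the allocator).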
+ pub fn create_local(&mut self, item: Item) -> Variable { let index = self.new_local_index(); - let local = VariableKind::LocalBinding { + let local = VariableKind::LocalConst { id: index, depth: self.depth, }; @@ -181,7 +178,7 @@ impl Scope { /// The index refers to the scalar position for the same [element](Elem) type. pub fn read_scalar(&mut self, index: u16, elem: Elem) -> Variable { let local = Variable::new( - VariableKind::LocalBinding { + VariableKind::LocalConst { id: self.new_local_index(), depth: self.depth, }, @@ -336,7 +333,7 @@ impl Scope { ) -> Variable { let item_global = match item.elem() { Elem::Bool => Item { - elem: u32::as_elem(), + elem: Elem::UInt(UIntKind::U32), vectorization: item.vectorization, }, _ => item, @@ -344,7 +341,7 @@ impl Scope { let input = Variable::new(VariableKind::GlobalInputArray(index), item_global); let index = self.new_local_index(); let local = Variable::new( - VariableKind::Local { + VariableKind::LocalMut { id: index, depth: self.depth, }, diff --git a/crates/cubecl-core/src/ir/variable.rs b/crates/cubecl-core/src/ir/variable.rs index 17e0e7efc..41317acab 100644 --- a/crates/cubecl-core/src/ir/variable.rs +++ b/crates/cubecl-core/src/ir/variable.rs @@ -1,8 +1,6 @@ use std::fmt::Display; use std::num::NonZero; -use crate::prelude::CubePrimitive; - use super::{Elem, FloatKind, IntKind, Item, Matrix, UIntKind}; use serde::{Deserialize, Serialize}; @@ -19,7 +17,10 @@ impl Variable { } pub fn builtin(builtin: Builtin) -> Self { - Self::new(VariableKind::Builtin(builtin), Item::new(u32::as_elem())) + Self::new( + VariableKind::Builtin(builtin), + Item::new(Elem::UInt(UIntKind::U32)), + ) } pub fn constant(scalar: ConstantScalarValue) -> Self { @@ -40,13 +41,13 @@ pub enum VariableKind { GlobalInputArray(Id), GlobalOutputArray(Id), GlobalScalar(Id), - Local { id: Id, depth: u8 }, + LocalArray { id: Id, depth: u8, length: u32 }, + LocalMut { id: Id, depth: u8 }, + LocalConst { id: Id, depth: u8 }, Versioned { id: Id, depth: u8, version: u16 }, - LocalBinding { id: Id, depth: u8 }, ConstantScalar(ConstantScalarValue), ConstantArray { id: Id, length: u32 }, SharedMemory { id: Id, length: u32 }, - LocalArray { id: Id, depth: u8, length: u32 }, Matrix { id: Id, mat: Matrix, depth: u8 }, Slice { id: Id, depth: u8 }, Builtin(Builtin), @@ -84,7 +85,7 @@ impl Variable { pub fn is_immutable(&self) -> bool { match self.kind { VariableKind::GlobalOutputArray { .. } => false, - VariableKind::Local { .. } => false, + VariableKind::LocalMut { .. } => false, VariableKind::SharedMemory { .. } => false, VariableKind::Matrix { .. } => false, VariableKind::Slice { .. } => false, @@ -92,7 +93,7 @@ impl Variable { VariableKind::GlobalInputArray { .. } => false, VariableKind::GlobalScalar { .. } => true, VariableKind::Versioned { .. } => true, - VariableKind::LocalBinding { .. } => true, + VariableKind::LocalConst { .. } => true, VariableKind::ConstantScalar(_) => true, VariableKind::ConstantArray { .. } => true, VariableKind::Builtin(_) => true, @@ -368,21 +369,33 @@ impl Variable { pub fn vectorization_factor(&self) -> u8 { self.item.vectorization.map(NonZero::get).unwrap_or(1u8) } - pub fn index(&self) -> Option { + + pub fn index(&self) -> Option { match self.kind { - VariableKind::GlobalInputArray(id) => Some(id), - VariableKind::GlobalScalar(id) => Some(id), - VariableKind::Local { id, .. } => Some(id), - VariableKind::Versioned { id, .. } => Some(id), - VariableKind::LocalBinding { id, .. } => Some(id), - VariableKind::Slice { id, .. 
} => Some(id), - VariableKind::GlobalOutputArray(id) => Some(id), - VariableKind::ConstantScalar(_) => None, - VariableKind::ConstantArray { id, .. } => Some(id), - VariableKind::SharedMemory { id, .. } => Some(id), - VariableKind::LocalArray { id, .. } => Some(id), - VariableKind::Matrix { id, .. } => Some(id), - VariableKind::Builtin(_) => None, + VariableKind::GlobalInputArray(id) + | VariableKind::GlobalScalar(id) + | VariableKind::LocalMut { id, .. } + | VariableKind::Versioned { id, .. } + | VariableKind::LocalConst { id, .. } + | VariableKind::Slice { id, .. } + | VariableKind::GlobalOutputArray(id) + | VariableKind::ConstantArray { id, .. } + | VariableKind::SharedMemory { id, .. } + | VariableKind::LocalArray { id, .. } + | VariableKind::Matrix { id, .. } => Some(id), + _ => None, + } + } + + pub fn depth(&self) -> Option { + match self.kind { + VariableKind::LocalMut { depth, .. } + | VariableKind::Versioned { depth, .. } + | VariableKind::LocalConst { depth, .. } + | VariableKind::LocalArray { depth, .. } + | VariableKind::Matrix { depth, .. } + | VariableKind::Slice { depth, .. } => Some(depth), + _ => None, } } @@ -401,11 +414,11 @@ impl Display for Variable { VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"), VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"), VariableKind::ConstantScalar(constant) => write!(f, "{constant}"), - VariableKind::Local { id, depth } => write!(f, "local({id}, {depth})"), + VariableKind::LocalMut { id, depth } => write!(f, "local({id}, {depth})"), VariableKind::Versioned { id, depth, version } => { write!(f, "local({id}, {depth}).v{version}") } - VariableKind::LocalBinding { id, depth } => write!(f, "binding({id}, {depth})"), + VariableKind::LocalConst { id, depth } => write!(f, "binding({id}, {depth})"), VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"), VariableKind::SharedMemory { id, .. } => write!(f, "shared({id})"), VariableKind::LocalArray { id, .. } => write!(f, "array({id})"), diff --git a/crates/cubecl-core/src/pod.rs b/crates/cubecl-core/src/pod.rs index f7f3df4d8..ba3603107 100644 --- a/crates/cubecl-core/src/pod.rs +++ b/crates/cubecl-core/src/pod.rs @@ -1,7 +1,6 @@ use crate::{ flex32, ir::{Elem, FloatKind, IntKind, UIntKind}, - prelude::Numeric, }; /// The base element trait for the jit backend. 
@@ -294,9 +293,9 @@ impl CubeElement for flex32 { Elem::Float(FloatKind::Flex32) } fn maximum_value() -> Self { - flex32::MAX + ::max_value() } fn minimum_value() -> Self { - flex32::MIN + ::min_value() } } diff --git a/crates/cubecl-core/src/prelude.rs b/crates/cubecl-core/src/prelude.rs index a9918911f..2c846189e 100644 --- a/crates/cubecl-core/src/prelude.rs +++ b/crates/cubecl-core/src/prelude.rs @@ -9,8 +9,8 @@ pub use crate::runtime::Runtime; /// Elements pub use crate::frontend::{ - Array, ArrayHandleRef, AtomicI32, AtomicI64, AtomicU32, Float, LaunchArg, Slice, SliceMut, - Tensor, TensorArg, + Array, ArrayHandleRef, AtomicI32, AtomicI64, AtomicU32, Float, FloatExpand, LaunchArg, + NumericExpand, Slice, SliceMut, Tensor, TensorArg, }; pub use crate::pod::CubeElement; diff --git a/crates/cubecl-core/src/runtime_tests/index.rs b/crates/cubecl-core/src/runtime_tests/index.rs index a1c62f3c3..95b74899f 100644 --- a/crates/cubecl-core/src/runtime_tests/index.rs +++ b/crates/cubecl-core/src/runtime_tests/index.rs @@ -1,4 +1,4 @@ -use crate as cubecl; +use crate::{self as cubecl, as_type}; use cubecl::prelude::*; @@ -10,18 +10,20 @@ pub fn kernel_assign(output: &mut Array) { output[0] = item; // out of bounds write should not show up in the array. - output[2] = F::new(10.0); + output[3] = F::new(10.0); // out of bounds read should be read as 0. - output[1] = output[2]; + output[1] = output[3]; } } pub fn test_kernel_index_scalar( client: ComputeClient, ) { - let handle = client.create(F::as_bytes(&[F::new(0.0), F::new(1.0), F::new(123.0)])); - let handle_slice = handle.clone().offset_end(1); + let handle = client.create(F::as_bytes(as_type![F: 0.0, 1.0, 123.0, 6.0])); + let handle_slice = handle + .clone() + .offset_end(F::as_elem_native_unchecked().size() as u64); let vectorization = 1; kernel_assign::launch::( diff --git a/crates/cubecl-core/src/runtime_tests/launch.rs b/crates/cubecl-core/src/runtime_tests/launch.rs index 0cb0d034a..f6020412a 100644 --- a/crates/cubecl-core/src/runtime_tests/launch.rs +++ b/crates/cubecl-core/src/runtime_tests/launch.rs @@ -2,19 +2,19 @@ use crate::{self as cubecl, as_bytes}; use cubecl::prelude::*; #[derive(CubeLaunch)] -pub struct ComptimeTag { - array: Array, +pub struct ComptimeTag { + array: Array, #[cube(comptime)] tag: String, } #[cube(launch)] -pub fn kernel_with_comptime_tag(output: &mut ComptimeTag) { +pub fn kernel_with_comptime_tag(output: &mut ComptimeTag) { if UNIT_POS == 0 { if comptime![&output.tag == "zero"] { - output.array[0] = F::new(0.0); + output.array[0] = f32::new(0.0); } else { - output.array[0] = F::new(1.0); + output.array[0] = f32::new(1.0); } } } @@ -33,13 +33,11 @@ pub fn kernel_without_generics(output: &mut Array) { } } -pub fn test_kernel_with_comptime_tag( - client: ComputeClient, -) { - let handle = client.create(as_bytes![F: 5.0]); - let array_arg = unsafe { ArrayArg::from_raw_parts::(&handle, 1, 1) }; +pub fn test_kernel_with_comptime_tag(client: ComputeClient) { + let handle = client.create(f32::as_bytes(&[5.0])); + let array_arg = unsafe { ArrayArg::from_raw_parts::(&handle, 1, 1) }; - kernel_with_comptime_tag::launch::( + kernel_with_comptime_tag::launch::( &client, CubeCount::Static(1, 1, 1), CubeDim::default(), @@ -47,14 +45,14 @@ pub fn test_kernel_with_comptime_tag( ); let actual = client.read_one(handle.binding()); - let actual = F::from_bytes(&actual); + let actual = f32::from_bytes(&actual); - assert_eq!(actual[0], F::new(0.0)); + assert_eq!(actual[0], f32::new(0.0)); - let handle = 
client.create(as_bytes![F: 5.0]); - let array_arg = unsafe { ArrayArg::from_raw_parts::(&handle, 1, 1) }; + let handle = client.create(f32::as_bytes(&[5.0])); + let array_arg = unsafe { ArrayArg::from_raw_parts::(&handle, 1, 1) }; - kernel_with_comptime_tag::launch::( + kernel_with_comptime_tag::launch::( &client, CubeCount::Static(1, 1, 1), CubeDim::default(), @@ -62,9 +60,9 @@ pub fn test_kernel_with_comptime_tag( ); let actual = client.read_one(handle.binding()); - let actual = F::from_bytes(&actual); + let actual = f32::from_bytes(&actual); - assert_eq!(actual[0], F::new(1.0)); + assert_eq!(actual[0], f32::new(1.0)); } pub fn test_kernel_with_generics( @@ -124,10 +122,9 @@ macro_rules! testgen_launch { #[test] fn test_launch_with_comptime_tag() { let client = TestRuntime::client(&Default::default()); - cubecl_core::runtime_tests::launch::test_kernel_with_comptime_tag::< - TestRuntime, - FloatType, - >(client); + cubecl_core::runtime_tests::launch::test_kernel_with_comptime_tag::( + client, + ); } }; } diff --git a/crates/cubecl-core/src/runtime_tests/line.rs b/crates/cubecl-core/src/runtime_tests/line.rs index 3ba67ff3d..6454cd36b 100644 --- a/crates/cubecl-core/src/runtime_tests/line.rs +++ b/crates/cubecl-core/src/runtime_tests/line.rs @@ -12,7 +12,7 @@ pub fn kernel_line_index(output: &mut Array, #[comptime] line_size: pub fn test_line_index( client: ComputeClient, ) { - for line_size in R::line_size_elem(&F::as_elem()) { + for line_size in R::line_size_elem(&F::as_elem_native().unwrap()) { let handle = client.create(F::as_bytes(&vec![F::new(0.0); line_size as usize])); unsafe { kernel_line_index::launch_unchecked::( @@ -45,7 +45,7 @@ pub fn kernel_line_index_assign(output: &mut Array>) { pub fn test_line_index_assign( client: ComputeClient, ) { - for line_size in R::line_size_elem(&F::as_elem()) { + for line_size in R::line_size_elem(&F::as_elem_native().unwrap()) { let handle = client.create(F::as_bytes(&vec![F::new(0.0); line_size as usize])); unsafe { kernel_line_index_assign::launch_unchecked::( @@ -66,6 +66,44 @@ pub fn test_line_index_assign( } } +#[cube(launch_unchecked)] +pub fn kernel_line_loop_unroll(output: &mut Array>, #[comptime] line_size: u32) { + if UNIT_POS == 0 { + let mut line = output[0]; + #[unroll] + for k in 0..line_size { + line[k] += F::cast_from(k); + } + output[0] = line; + } +} + +pub fn test_line_loop_unroll( + client: ComputeClient, +) { + for line_size in R::line_size_elem(&F::as_elem(&Default::default())) { + let handle = client.create(F::as_bytes(&vec![F::new(0.0); line_size as usize])); + unsafe { + kernel_line_loop_unroll::launch_unchecked::( + &client, + CubeCount::new_single(), + CubeDim::new_single(), + ArrayArg::from_raw_parts::(&handle, 1, line_size), + line_size as u32, + ); + } + + let actual = client.read_one(handle.binding()); + let actual = F::from_bytes(&actual); + + let expected = (0..line_size as i64) + .map(|x| F::from_int(x)) + .collect::>(); + + assert_eq!(actual, expected); + } +} + macro_rules! impl_line_comparison { ($cmp:ident, $expected:expr) => { ::paste::paste! { @@ -134,6 +172,14 @@ macro_rules! 
testgen_line { ); } + #[test] + fn test_line_loop_unroll() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::line::test_line_loop_unroll::( + client, + ); + } + #[test] fn test_line_equal() { let client = TestRuntime::client(&Default::default()); diff --git a/crates/cubecl-core/src/runtime_tests/mod.rs b/crates/cubecl-core/src/runtime_tests/mod.rs index 8b010b7c1..e4381153c 100644 --- a/crates/cubecl-core/src/runtime_tests/mod.rs +++ b/crates/cubecl-core/src/runtime_tests/mod.rs @@ -15,6 +15,7 @@ pub mod sequence; pub mod slice; pub mod tensor; pub mod topology; +pub mod traits; pub mod unary; #[allow(missing_docs)] @@ -36,6 +37,8 @@ macro_rules! testgen_all { ::paste::paste! { $(mod [<$float _ty>] { + use super::*; + type FloatType = $float; type IntType = $i_def; type UintType = $u_def; @@ -43,6 +46,8 @@ macro_rules! testgen_all { $crate::testgen_float!(); })* $(mod [<$int _ty>] { + use super::*; + type FloatType = $f_def; type IntType = $int; type UintType = $u_def; @@ -50,6 +55,8 @@ macro_rules! testgen_all { $crate::testgen_int!(); })* $(mod [<$uint _ty>] { + use super::*; + type FloatType = $f_def; type IntType = $i_def; type UintType = $uint; @@ -83,7 +90,9 @@ macro_rules! testgen_float { #[allow(missing_docs)] #[macro_export] macro_rules! testgen_int { - () => {}; + () => { + cubecl_core::testgen_unary_int!(); + }; } #[allow(missing_docs)] diff --git a/crates/cubecl-core/src/runtime_tests/traits.rs b/crates/cubecl-core/src/runtime_tests/traits.rs new file mode 100644 index 000000000..f57273fde --- /dev/null +++ b/crates/cubecl-core/src/runtime_tests/traits.rs @@ -0,0 +1,10 @@ +use crate::{self as cubecl}; +use cubecl::prelude::*; + +#[cube] +pub(crate) trait UnaryOp: 'static + Send + Sync { + type Options: LaunchArg; +} + +#[cube(launch)] +pub(crate) fn associated_type_input>(_options: &O::Options) {} diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs b/crates/cubecl-core/src/runtime_tests/unary.rs index 90eb14ddd..544ee5386 100644 --- a/crates/cubecl-core/src/runtime_tests/unary.rs +++ b/crates/cubecl-core/src/runtime_tests/unary.rs @@ -70,7 +70,98 @@ macro_rules! test_unary_impl { ) }; - assert_equals_approx::(&client, output_handle, $expected, $float_type::new(0.02)); + assert_equals_approx::(&client, output_handle, $expected, $float_type::new(0.02)); + } + )* + } + }; +} + +macro_rules! 
test_unary_impl_int { + ( + $test_name:ident, + $int_type:ident, + $unary_func:expr, + [$({ + input_vectorization: $input_vectorization:expr, + out_vectorization: $out_vectorization:expr, + input: $input:expr, + expected: $expected:expr + }),*]) => { + pub fn $test_name(client: ComputeClient) { + #[cube(launch_unchecked)] + fn test_function<$int_type: Int>(input: &Array<$int_type>, output: &mut Array<$int_type>) { + if ABSOLUTE_POS < input.len() { + output[ABSOLUTE_POS] = $unary_func(input[ABSOLUTE_POS]); + } + } + + $( + { + let input = $input; + let output_handle = client.empty(input.len() * core::mem::size_of::<$int_type>()); + let input_handle = client.create($int_type::as_bytes(input)); + + unsafe { + test_function::launch_unchecked::<$int_type, R>( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new((input.len() / $input_vectorization as usize) as u32, 1, 1), + ArrayArg::from_raw_parts::<$int_type>(&input_handle, input.len(), $input_vectorization), + ArrayArg::from_raw_parts::<$int_type>(&output_handle, $expected.len(), $out_vectorization), + ) + }; + + let actual = client.read_one(output_handle.binding()); + let actual = $int_type::from_bytes(&actual); + + assert_eq!(actual, $expected); + } + )* + } + }; +} + +macro_rules! test_unary_impl_int_fixed { + ( + $test_name:ident, + $int_type:ident, + $out_type:ident, + $unary_func:expr, + [$({ + input_vectorization: $input_vectorization:expr, + out_vectorization: $out_vectorization:expr, + input: $input:expr, + expected: $expected:expr + }),*]) => { + pub fn $test_name(client: ComputeClient) { + #[cube(launch_unchecked)] + fn test_function<$int_type: Int>(input: &Array<$int_type>, output: &mut Array<$out_type>) { + if ABSOLUTE_POS < input.len() { + output[ABSOLUTE_POS] = $unary_func(input[ABSOLUTE_POS]); + } + } + + $( + { + let input = $input; + let output_handle = client.empty(input.len() * core::mem::size_of::<$out_type>()); + let input_handle = client.create($int_type::as_bytes(input)); + + unsafe { + test_function::launch_unchecked::<$int_type, R>( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new((input.len() / $input_vectorization as usize) as u32, 1, 1), + ArrayArg::from_raw_parts::<$int_type>(&input_handle, input.len(), $input_vectorization), + ArrayArg::from_raw_parts::<$out_type>(&output_handle, $expected.len(), $out_vectorization), + ) + }; + + let actual = client.read_one(output_handle.binding()); + let actual = $out_type::from_bytes(&actual); + + assert_eq!(actual, $expected); } )* } @@ -147,6 +238,55 @@ test_unary_impl!( ] ); +test_unary_impl_int_fixed!(test_count_ones, I, u32, I::count_ones, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![I: 0b1110_0010, 0b1000_0000, 0b1111_1111], + expected: &[4, 1, 8] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![I: 0b1110_0010, 0b1000_0000, 0b1111_1111, 0b1100_0001], + expected: &[4, 1, 8, 3] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![I: 0b1110_0010, 0b1000_0000, 0b1111_1111, 0b1100_0001], + expected: &[4, 1, 8, 3] + } +]); + +macro_rules! 
shift { + ($value:expr) => {{ + let shift = (size_of::() - 1) * 8; + $value << shift + }}; +} + +test_unary_impl_int!(test_reverse_bits, I, I::reverse_bits, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![I: 0b1110_0010, 0b1000_0000, 0b1111_1111], + expected: as_type![I: shift!(0b0100_0111), shift!(0b0000_0001), shift!(0b1111_1111)] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![I: 0b1110_0010, 0b1000_0000, 0b1111_1111, 0b1100_0001], + expected: as_type![I: shift!(0b0100_0111), shift!(0b0000_0001), shift!(0b1111_1111), shift!(0b1000_0011)] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![I: 0b1110_0010, 0b1000_0000, 0b1111_1111, 0b1100_0001], + expected: as_type![I: shift!(0b0100_0111), shift!(0b0000_0001), shift!(0b1111_1111), shift!(0b1000_0011)] + } +]); + #[allow(missing_docs)] #[macro_export] macro_rules! testgen_unary { @@ -171,3 +311,28 @@ macro_rules! testgen_unary { } }; } + +#[allow(missing_docs)] +#[macro_export] +macro_rules! testgen_unary_int { + () => { + mod unary_int { + use super::*; + + macro_rules! add_test { + ($test_name:ident) => { + #[test] + fn $test_name() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::unary::$test_name::( + client, + ); + } + }; + } + + add_test!(test_count_ones); + add_test!(test_reverse_bits); + } + }; +} diff --git a/crates/cubecl-core/tests/frontend/array.rs b/crates/cubecl-core/tests/frontend/array.rs deleted file mode 100644 index 5f66e2b9c..000000000 --- a/crates/cubecl-core/tests/frontend/array.rs +++ /dev/null @@ -1,196 +0,0 @@ -use cubecl::prelude::*; -use cubecl_core as cubecl; - -#[cube] -pub fn array_read_write(#[comptime] array_size: u32) { - let mut array = Array::::new(array_size); - array[0] = T::from_int(3); - let _a = array[0]; -} - -#[cube] -pub fn array_to_vectorized_variable() -> T { - let mut array = Array::::new(2); - array[0] = T::from_int(0); - array[1] = T::from_int(1); - array.to_vectorized(2) -} - -#[cube] -pub fn array_of_one_to_vectorized_variable() -> T { - let mut array = Array::::new(1); - array[0] = T::from_int(3); - array.to_vectorized(1) -} - -#[cube] -pub fn array_add_assign_simple(array: &mut Array) { - array[1] += 1; -} - -#[cube] -pub fn array_add_assign_expr(array: &mut Array) { - array[1 + 5] += 1; -} - -mod tests { - use pretty_assertions::assert_eq; - use std::num::NonZero; - - use super::*; - use cubecl_core::{ - cpa, - ir::{self, Item, Variable, VariableKind}, - }; - - type ElemType = f32; - - #[test] - fn cube_support_array() { - let mut context = CubeContext::default(); - - array_read_write::expand::(&mut context, 512); - assert_eq!( - context.into_scope().operations, - inline_macro_ref_read_write() - ) - } - - #[test] - fn array_add_assign() { - let mut context = CubeContext::default(); - let array = context.input(0, Item::new(u32::as_elem())); - - array_add_assign_simple::expand(&mut context, array.into()); - let scope = context.into_scope(); - - assert_eq!(scope.operations, inline_macro_array_add_assign_simple()); - } - - #[test] - fn cube_array_to_vectorized() { - let mut context = CubeContext::default(); - - array_to_vectorized_variable::expand::(&mut context); - assert_eq!( - context.into_scope().operations, - inline_macro_ref_to_vectorized() - ); - } - - #[test] - fn cube_array_of_one_to_vectorized() { - let mut context = CubeContext::default(); - - array_of_one_to_vectorized_variable::expand::(&mut context); - assert_eq!( - context.into_scope().operations, - 
inline_macro_ref_one_to_vectorized() - ); - } - - fn inline_macro_ref_read_write() -> Vec { - let context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - - let mut scope = context.into_scope(); - let var = scope.create_local(item); - let pos: Variable = 0u32.into(); - - // Create - let array = scope.create_local_array(item, 512); - - // Write - cpa!(scope, array[pos] = 3.0_f32); - - // Read - cpa!(scope, var = array[pos]); - - scope.operations - } - - #[test] - fn array_add_assign_expr() { - let mut context = CubeContext::default(); - let array = context.input(0, Item::new(u32::as_elem())); - - array_add_assign_expr::expand(&mut context, array.into()); - let scope = context.into_scope(); - - assert_eq!(scope.operations, inline_macro_array_add_assign_expr()); - } - - fn inline_macro_array_add_assign_simple() -> Vec { - let context = CubeContext::default(); - - let mut scope = context.into_scope(); - let local = scope.create_local(Item::new(u32::as_elem())); - - let array = Variable::new(VariableKind::GlobalInputArray(0), Item::new(u32::as_elem())); - let index: Variable = 1u32.into(); - let value: Variable = 1u32.into(); - - cpa!(scope, local = array[index]); - cpa!(scope, local += value); - cpa!(scope, array[index] = local); - - scope.operations - } - - fn inline_macro_ref_to_vectorized() -> Vec { - let context = CubeContext::default(); - let scalar_item = Item::new(ElemType::as_elem()); - let vectorized_item = Item::vectorized(ElemType::as_elem(), NonZero::new(2)); - - let mut scope = context.into_scope(); - let pos0: Variable = 0u32.into(); - let pos1: Variable = 1u32.into(); - let array = scope.create_local_array(scalar_item, 2); - cpa!(scope, array[pos0] = 0.0_f32); - cpa!(scope, array[pos1] = 1.0_f32); - - let vectorized_var = scope.create_local(vectorized_item); - let tmp = scope.create_local(scalar_item); - cpa!(scope, tmp = array[pos0]); - cpa!(scope, vectorized_var[pos0] = tmp); - cpa!(scope, tmp = array[pos1]); - cpa!(scope, vectorized_var[pos1] = tmp); - - scope.operations - } - - fn inline_macro_ref_one_to_vectorized() -> Vec { - let context = CubeContext::default(); - let scalar_item = Item::new(ElemType::as_elem()); - let unvectorized_item = Item::vectorized(ElemType::as_elem(), NonZero::new(1)); - - let mut scope = context.into_scope(); - let pos0: Variable = 0u32.into(); - let array = scope.create_local_array(scalar_item, 1); - cpa!(scope, array[pos0] = 3.0_f32); - - let unvectorized_var = scope.create_local(unvectorized_item); - let tmp = scope.create_local(scalar_item); - cpa!(scope, tmp = array[pos0]); - cpa!(scope, unvectorized_var = tmp); - - scope.operations - } - - fn inline_macro_array_add_assign_expr() -> Vec { - let context = CubeContext::default(); - - let mut scope = context.into_scope(); - let local = scope.create_local(Item::new(u32::as_elem())); - - let array = Variable::new(VariableKind::GlobalInputArray(0), Item::new(u32::as_elem())); - let index: Variable = 6u32.into(); - let value: Variable = 1u32.into(); - - cpa!(scope, local = array[index]); - cpa!(scope, local += value); - cpa!(scope, array[index] = local); - - scope.operations - } -} diff --git a/crates/cubecl-core/tests/frontend/assign.rs b/crates/cubecl-core/tests/frontend/assign.rs deleted file mode 100644 index aeb2a4b98..000000000 --- a/crates/cubecl-core/tests/frontend/assign.rs +++ /dev/null @@ -1,185 +0,0 @@ -#![allow(unused)] - -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -pub fn mut_assign() { - let mut x: u32 = 0; - x += 1; -} - -#[cube] -pub 
fn mut_assign_input(y: u32) -> u32 { - let mut x = y; - x += 1; - y + 2 -} - -#[cube] -pub fn assign_mut_input(mut y: u32) -> u32 { - let x = y; - y += 1; - x + 2 -} - -#[cube] -pub fn assign_deref(y: &mut u32) -> u32 { - *y = 1; - *y -} - -#[derive(CubeType)] -struct StructWithComptime { - index: u32, - #[cube(comptime)] - tag: String, -} - -#[cube] -fn new_struct(index: u32, #[comptime] tag: String) -> StructWithComptime { - StructWithComptime { index, tag } -} - -mod tests { - use pretty_assertions::assert_eq; - use std::num::NonZero; - - use super::*; - use cubecl_core::{ - cpa, - ir::{Elem, Instruction, Item, Variable}, - }; - - #[test] - fn cube_mut_assign_test() { - let mut context = CubeContext::default(); - - mut_assign::expand(&mut context); - let scope = context.into_scope(); - - assert_eq!(scope.operations, inline_macro_ref_mut_assign()); - } - - #[test] - fn cube_mut_assign_input_test() { - let mut context = CubeContext::default(); - - let y = context.create_local_binding(Item::new(u32::as_elem())); - - mut_assign_input::expand(&mut context, y.into()); - let scope = context.into_scope(); - - assert_eq!(scope.operations, inline_macro_ref_mut_assign_input()); - } - - #[test] - fn cube_assign_mut_input_test() { - let mut context = CubeContext::default(); - - let y = context.create_local_binding(Item::new(u32::as_elem())); - - assign_mut_input::expand(&mut context, y.into()); - let scope = context.into_scope(); - - assert_eq!(scope.operations, inline_macro_ref_assign_mut_input()); - } - - #[test] - fn cube_assign_deref_test() { - let mut context = CubeContext::default(); - - let y = context.create_local_binding(Item::new(u32::as_elem())); - assign_deref::expand(&mut context, y.into()); - - let scope = context.into_scope(); - - assert_eq!(scope.operations, inline_macro_ref_assign_deref()); - } - - fn inline_macro_ref_mut_assign() -> Vec { - let context = CubeContext::default(); - - let mut scope = context.into_scope(); - let x = scope.create_local(Item::new(u32::as_elem())); - - let zero: Variable = 0u32.into(); - let one: Variable = 1u32.into(); - - cpa!(scope, x = zero); - cpa!(scope, x = x + one); - - scope.operations - } - - fn inline_macro_ref_mut_assign_input() -> Vec { - let mut context = CubeContext::default(); - let item = Item::new(u32::as_elem()); - let y = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let y: Variable = y.into(); - let x = scope.create_local(item); - - let one: Variable = 1u32.into(); - let two: Variable = 2u32.into(); - - cpa!(scope, x = y); - cpa!(scope, x = x + one); - cpa!(scope, x = y + two); - - scope.operations - } - - fn inline_macro_ref_assign_mut_input() -> Vec { - let mut context = CubeContext::default(); - let item = Item::new(u32::as_elem()); - let y = context.create_local_variable(item); - println!("{:?}", y.index()); - - let mut scope = context.into_scope(); - let y: Variable = y.into(); - let x = scope.create_local(item); - - let one: Variable = 1u32.into(); - let two: Variable = 2u32.into(); - - cpa!(scope, x = y); - cpa!(scope, y = y + one); - cpa!(scope, x = x + two); - - scope.operations - } - - fn inline_macro_ref_assign_vectorized() -> Vec { - let mut context = CubeContext::default(); - let item = Item::vectorized(u32::as_elem(), NonZero::new(4)); - let y = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let y: Variable = y.into(); - let x = scope.create_local(item); - - let zero: Variable = 0u32.into(); - let one: Variable = 1u32.into(); - let two: Variable = 
2u32.into(); - let three: Variable = 3u32.into(); - - cpa!(scope, x = cast(one)); - cpa!(scope, x = x + y); - - scope.operations - } - - fn inline_macro_ref_assign_deref() -> Vec { - let context = CubeContext::default(); - let mut scope = context.into_scope(); - let y = scope.create_local(Item::new(u32::as_elem())); - - let one: Variable = 1u32.into(); - - cpa!(scope, y = one); - - scope.operations - } -} diff --git a/crates/cubecl-core/tests/frontend/cast_elem.rs b/crates/cubecl-core/tests/frontend/cast_elem.rs deleted file mode 100644 index 1168ca6df..000000000 --- a/crates/cubecl-core/tests/frontend/cast_elem.rs +++ /dev/null @@ -1,258 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::{ - cube, - frontend::{Cast, Numeric}, -}; - -// From float -#[cube] -pub fn float_to_float(x: f32) { - let y = x + f32::from_int(2); - let _ = f32::cast_from(y) + f32::from_int(34); -} - -#[cube] -pub fn float_to_int(x: f32) { - let y = x + f32::from_int(2); - let _ = i32::cast_from(y) + i32::from_int(34); -} - -#[cube] -pub fn float_to_u32(x: f32) { - let y = x + f32::from_int(2); - let _ = u32::cast_from(y) + u32::from_int(34); -} - -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -pub fn float_to_bool(x: f32) { - let y = x + f32::from_int(2); - let _ = bool::cast_from(y) || true; -} - -// From int -#[cube] -pub fn int_to_float(x: i32) { - let y = x + i32::from_int(2); - let _ = f32::cast_from(y) + f32::from_int(34); -} - -#[cube] -#[allow(clippy::useless_conversion)] -pub fn int_to_int(x: i32) { - let y = x + i32::from_int(2); - let _ = i32::cast_from(y) + i32::from_int(34); -} - -#[cube] -pub fn int_to_u32(x: i32) { - let y = x + i32::from_int(2); - let _ = u32::cast_from(y) + u32::from_int(34); -} - -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -pub fn int_to_bool(x: i32) { - let y = x + i32::from_int(2); - let _ = bool::cast_from(y) || true; -} - -// // From u32 -#[cube] -pub fn u32_to_float(x: u32) { - let y = x + u32::from_int(2); - let _ = f32::cast_from(y) + f32::from_int(34); -} - -#[cube] -pub fn u32_to_int(x: u32) { - let y = x + u32::from_int(2); - let _ = i32::cast_from(y) + i32::from_int(34); -} - -#[cube] -#[allow(clippy::useless_conversion)] -pub fn u32_to_u32(x: u32) { - let y = x + u32::from_int(2); - let _ = u32::cast_from(y) + u32::from_int(34); -} - -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -pub fn u32_to_bool(x: u32) { - let y = x + u32::from_int(2); - let _ = bool::cast_from(y) || true; -} - -// From bool -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -pub fn bool_to_float(x: bool) { - let y = x && false; - let _ = f32::cast_from(y) + f32::from_int(34); -} - -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -pub fn bool_to_int(x: bool) { - let y = x && false; - let _ = i32::cast_from(y) + i32::from_int(34); -} - -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -pub fn bool_to_u32(x: bool) { - let y = x && false; - let _ = u32::cast_from(y) + u32::from_int(34); -} - -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -#[allow(clippy::useless_conversion)] -pub fn bool_to_bool(x: bool) { - let y = x && false; - let _ = bool::cast_from(y) || true; -} - -mod tests { - use super::*; - use cubecl_core::{ - cpa, - frontend::{CubeContext, CubePrimitive}, - ir::{Elem, Item, Variable}, - }; - - macro_rules! 
cast_test { - ($name:ident, $module:expr, $from:expr, $to:expr) => { - #[test] - fn $name() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding($from); - - $module(&mut context, x.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_cast($from, $to) - ); - } - }; - } - - cast_test!( - cube_float_to_int_test, - float_to_int::expand, - Item::new(f32::as_elem()), - Item::new(i32::as_elem()) - ); - - cast_test!( - cube_float_to_u32_test, - float_to_u32::expand, - Item::new(f32::as_elem()), - Item::new(u32::as_elem()) - ); - - cast_test!( - cube_float_to_bool_test, - float_to_bool::expand, - Item::new(f32::as_elem()), - Item::new(Elem::Bool) - ); - - cast_test!( - cube_int_to_float_test, - int_to_float::expand, - Item::new(i32::as_elem()), - Item::new(f32::as_elem()) - ); - - cast_test!( - cube_int_to_u32_test, - int_to_u32::expand, - Item::new(i32::as_elem()), - Item::new(u32::as_elem()) - ); - - cast_test!( - cube_int_to_bool_test, - int_to_bool::expand, - Item::new(i32::as_elem()), - Item::new(Elem::Bool) - ); - - cast_test!( - cube_u32_to_float_test, - u32_to_float::expand, - Item::new(u32::as_elem()), - Item::new(f32::as_elem()) - ); - - cast_test!( - cube_u32_to_int_test, - u32_to_int::expand, - Item::new(u32::as_elem()), - Item::new(i32::as_elem()) - ); - - cast_test!( - cube_u32_to_bool_test, - u32_to_bool::expand, - Item::new(u32::as_elem()), - Item::new(Elem::Bool) - ); - - cast_test!( - cube_bool_to_float_test, - bool_to_float::expand, - Item::new(Elem::Bool), - Item::new(f32::as_elem()) - ); - - cast_test!( - cube_bool_to_int_test, - bool_to_int::expand, - Item::new(Elem::Bool), - Item::new(i32::as_elem()) - ); - - cast_test!( - cube_bool_to_u32_test, - bool_to_u32::expand, - Item::new(Elem::Bool), - Item::new(u32::as_elem()) - ); - - fn inline_macro_ref_cast(from_item: Item, to_item: Item) -> String { - let mut context = CubeContext::default(); - let x = context.create_local_variable(from_item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y = scope.create_local(to_item); - - match from_item.elem() { - Elem::Float(_) => cpa!(scope, x = x + 2f32), - Elem::Int(_) => cpa!(scope, x = x + 2i32), - Elem::AtomicInt(_) => cpa!(scope, x = x + 2i32), - Elem::UInt(_) => cpa!(scope, x = x + 2u32), - Elem::AtomicUInt(_) => cpa!(scope, x = x + 2u32), - Elem::Bool => cpa!(scope, x = x && false), - } - - cpa!(scope, y = cast(x)); - - match to_item.elem() { - Elem::Float(_) => cpa!(scope, y = y + 34f32), - Elem::Int(_) => cpa!(scope, y = y + 34i32), - Elem::AtomicInt(_) => cpa!(scope, y = y + 34i32), - Elem::UInt(_) => cpa!(scope, y = y + 34u32), - Elem::AtomicUInt(_) => cpa!(scope, y = y + 34u32), - Elem::Bool => cpa!(scope, y = y || true), - } - - format!("{:?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/cast_kind.rs b/crates/cubecl-core/tests/frontend/cast_kind.rs deleted file mode 100644 index 4184a33a2..000000000 --- a/crates/cubecl-core/tests/frontend/cast_kind.rs +++ /dev/null @@ -1,128 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::{ - cube, - frontend::{Cast, Float, Int, Numeric}, -}; - -#[cube] -pub fn cast_float_kind(input: F1) { - let x = input + F1::new(5.9); - let y = F2::cast_from(x); - let _ = y + F2::new(2.3); -} - -#[cube] -pub fn cast_int_kind(input: I1) { - let x = input + I1::new(5); - let y = I2::cast_from(x); - let _ = y + I2::new(2); -} - -#[cube] -pub fn cast_numeric_to_kind(input: T) { - let x = input + 
T::from_int(5); - let y = I::cast_from(x); - let _ = y + I::from_int(2); -} - -#[cube] -pub fn cast_int_to_numeric<I: Int, T: Numeric>(input: I) { - let x = input + I::from_int(5); - let y = T::cast_from(x); - let _ = y + T::from_int(2); -} - -mod tests { - use super::*; - use cubecl_core::{ - cpa, - frontend::{CubeContext, CubePrimitive}, - ir::{Item, Variable}, - }; - - #[test] - fn cube_cast_float_kind_test() { - let mut context = CubeContext::default(); - let item = Item::new(f64::as_elem()); - - let input = context.create_local_binding(item); - - cast_float_kind::expand::<f64, f32>(&mut context, input.into()); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float()); - } - - #[test] - fn cube_cast_int_kind_test() { - let mut context = CubeContext::default(); - let item = Item::new(i32::as_elem()); - - let input = context.create_local_binding(item); - - cast_int_kind::expand::<i32, i64>(&mut context, input.into()); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); - } - - #[test] - fn cube_cast_numeric_kind_test() { - let mut context = CubeContext::default(); - let item = Item::new(i32::as_elem()); - - let input = context.create_local_binding(item); - - cast_numeric_to_kind::expand::<i32, i64>(&mut context, input.into()); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); - } - - #[test] - fn cube_cast_kind_numeric_test() { - let mut context = CubeContext::default(); - let item = Item::new(i32::as_elem()); - - let input = context.create_local_binding(item); - - cast_int_to_numeric::expand::<i32, i64>(&mut context, input.into()); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); - } - - fn inline_macro_ref_float() -> String { - let mut context = CubeContext::default(); - let float_64 = Item::new(f64::as_elem()); - let float_32 = Item::new(f32::as_elem()); - let input = context.create_local_binding(float_64); - - let mut scope = context.into_scope(); - let input: Variable = input.into(); - let y = scope.create_local(float_32); - - cpa!(scope, input = input + 5.9f32 as f64); - cpa!(scope, y = cast(input)); - cpa!(scope, y = y + 2.3f32); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_int() -> String { - let mut context = CubeContext::default(); - let int_32 = Item::new(i32::as_elem()); - let int_64 = Item::new(i64::as_elem()); - let input = context.create_local_binding(int_32); - - let mut scope = context.into_scope(); - let input: Variable = input.into(); - let y = scope.create_local(int_64); - - cpa!(scope, input = input + 5i32); - cpa!(scope, y = cast(input)); - cpa!(scope, y = y + 2i64); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/comptime.rs b/crates/cubecl-core/tests/frontend/comptime.rs deleted file mode 100644 index 842bf041e..000000000 --- a/crates/cubecl-core/tests/frontend/comptime.rs +++ /dev/null @@ -1,390 +0,0 @@ -use cubecl_core::prelude::*; -use cubecl_core::{self as cubecl, comptime}; - -#[derive(Clone)] -pub struct State { - cond: bool, - bound: u32, -} - -impl Init for State { - fn init(self, _context: &mut CubeContext) -> Self { - self - } -} - -#[cube] -pub fn comptime_if_else<T: Numeric>(lhs: T, #[comptime] cond: bool) { - if cond { - let _ = lhs + T::from_int(4); - } else { - let _ = lhs - T::from_int(5); - } -} - -#[cube] -#[allow(clippy::collapsible_else_if)] -pub fn comptime_else_then_if<T: Numeric>(lhs: T, #[comptime] cond1: bool, #[comptime] 
cond2: bool) { - if cond1 { - let _ = lhs + T::from_int(4); - } else { - if cond2 { - let _ = lhs + T::from_int(5); - } else { - let _ = lhs - T::from_int(6); - } - } -} - -#[cube] -pub fn comptime_float() { - let comptime_float = 0.0f32; - let _runtime_float = comptime_float.runtime(); -} - -#[cube] -pub fn comptime_elsif(lhs: T, #[comptime] cond1: bool, #[comptime] cond2: bool) { - if cond1 { - let _ = lhs + T::from_int(4); - } else if cond2 { - let _ = lhs + T::from_int(5); - } else { - let _ = lhs - T::from_int(6); - } -} - -#[cube] -pub fn comptime_elsif_with_runtime1(lhs: T, #[comptime] comptime_cond: bool) { - let runtime_cond = lhs >= T::from_int(2); - if comptime_cond { - let _ = lhs + T::from_int(4); - } else if runtime_cond { - let _ = lhs + T::from_int(5); - } else { - let _ = lhs - T::from_int(6); - } -} - -#[cube] -pub fn comptime_elsif_with_runtime2(lhs: T, #[comptime] comptime_cond: bool) { - let runtime_cond = lhs >= T::from_int(2); - if runtime_cond { - let _ = lhs + T::from_int(4); - } else if comptime_cond { - let _ = lhs + T::from_int(5); - } else { - let _ = lhs - T::from_int(6); - } -} - -#[cube] -pub fn comptime_if_expr(lhs: T, #[comptime] x: u32, #[comptime] y: u32) { - let y2 = x + y; - - if x < y2 { - let _ = lhs + T::from_int(4); - } else { - let _ = lhs - T::from_int(5); - } -} - -#[cube] -pub fn comptime_with_map_bool(#[comptime] state: State) -> T { - let cond = state.cond; - - let mut x = T::from_int(3); - if cond { - x += T::from_int(4); - } else { - x -= T::from_int(4); - } - x -} - -#[cube] -pub fn comptime_with_map_uint(#[comptime] state: State) -> T { - let bound = state.bound; - - let mut x = T::from_int(3); - #[unroll] - for _ in 0..bound { - x += T::from_int(4); - } - - x -} - -fn rust_function(input: u32) -> u32 { - input + 2 -} - -#[cube] -pub fn comptime_block(a: T) -> T { - let comptime_val = comptime! 
{ rust_function(2) as i64 }; - - a + T::from_int(comptime_val) -} - -mod tests { - use super::*; - use cubecl_core::{ - cpa, - frontend::{CubeContext, CubePrimitive}, - ir::{Elem, Item, Variable}, - }; - use pretty_assertions::assert_eq; - - type ElemType = f32; - - #[test] - fn cube_comptime_if_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - - comptime_if_else::expand::<ElemType>(&mut context, lhs.into(), true); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_comptime(true) - ); - } - - #[test] - fn cube_comptime_if_numeric_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - - comptime_if_expr::expand::<ElemType>(&mut context, lhs.into(), 4, 5); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_comptime(true) - ); - } - - #[test] - fn cube_comptime_else_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - - comptime_if_else::expand::<ElemType>(&mut context, lhs.into(), false); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_comptime2(false) - ); - } - - #[test] - fn cube_comptime_elsif_test() { - for cond1 in [false, true] { - for cond2 in [false, true] { - let mut context1 = CubeContext::default(); - let lhs = context1.create_local_binding(Item::new(ElemType::as_elem())); - comptime_else_then_if::expand::<ElemType>(&mut context1, lhs.into(), cond1, cond2); - let scope1 = context1.into_scope(); - - let mut context2 = CubeContext::default(); - let lhs = context2.create_local_binding(Item::new(ElemType::as_elem())); - comptime_elsif::expand::<ElemType>(&mut context2, lhs.into(), cond1, cond2); - let scope2 = context2.into_scope(); - - assert_eq!( - format!("{:?}", scope1.operations), - format!("{:?}", scope2.operations), - ); - } - } - } - - #[test] - fn cube_comptime_elsif_runtime1_test() { - for cond in [false, true] { - let mut context = CubeContext::default(); - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - comptime_elsif_with_runtime1::expand::<ElemType>(&mut context, lhs.into(), cond); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_elsif_runtime1(cond) - ); - } - } - - #[test] - fn cube_comptime_elsif_runtime2_test() { - for cond in [false, true] { - let mut context = CubeContext::default(); - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - - comptime_elsif_with_runtime2::expand::<ElemType>(&mut context, lhs.into(), cond); - let scope = context.into_scope(); - - assert_eq!( - format!("{:#?}", scope.operations), - inline_macro_ref_elsif_runtime2(cond) - ); - } - } - - #[test] - fn cube_comptime_map_bool_test() { - let mut context1 = CubeContext::default(); - let mut context2 = CubeContext::default(); - - let comptime_state_true = State { - cond: true, - bound: 4, - }; - let comptime_state_false = State { - cond: false, - bound: 4, - }; - - comptime_with_map_bool::expand::<ElemType>(&mut context1, comptime_state_true); - comptime_with_map_bool::expand::<ElemType>(&mut context2, comptime_state_false); - - let scope1 = context1.into_scope(); - let scope2 = context2.into_scope(); - - assert_ne!( - format!("{:?}", scope1.operations), - format!("{:?}", scope2.operations) - ); - } - - #[test] - fn cube_comptime_map_uint_test() { - let mut context = 
CubeContext::default(); - - let comptime_state = State { - cond: true, - bound: 4, - }; - - comptime_with_map_uint::expand::(&mut context, comptime_state); - - let scope = context.into_scope(); - - assert!(!format!("{:?}", scope.operations).contains("RangeLoop")); - } - - #[test] - fn cube_comptime_block_test() { - let mut context = CubeContext::default(); - - let a = context.create_local_binding(Item::new(ElemType::as_elem())); - - comptime_block::expand::(&mut context, a.into()); - - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_comptime_block() - ); - } - - fn inline_macro_ref_comptime(cond: bool) -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y = scope.create_local(item); - - if cond { - cpa!(scope, y = x + 4.0f32); - } else { - cpa!(scope, y = x - 5.0f32); - }; - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_comptime2(cond: bool) -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - - if cond { - cpa!(scope, x = x + 4.0f32); - } else { - cpa!(scope, x = x - 5.0f32); - }; - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_elsif_runtime1(comptime_cond: bool) -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let runtime_cond = scope.create_local(Item::new(Elem::Bool)); - let y = scope.create_local(item); - cpa!(scope, runtime_cond = x >= 2.0f32); - - if comptime_cond { - cpa!(scope, y = x + 4.0f32); - } else { - cpa!(&mut scope, if(runtime_cond).then(|scope| { - cpa!(scope, y = x + 5.0f32); - }).else(|scope| { - cpa!(scope, y = x - 6.0f32); - })); - }; - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_elsif_runtime2(comptime_cond: bool) -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let runtime_cond = scope.create_local(Item::new(Elem::Bool)); - let y = scope.create_local(item); - cpa!(scope, runtime_cond = x >= 2.0f32); - - cpa!(&mut scope, if(runtime_cond).then(|scope| { - cpa!(scope, y = x + 4.0f32); - }).else(|scope| { - if comptime_cond { - cpa!(scope, y = x + 5.0f32); - } else { - cpa!(scope, y = x - 6.0f32); - } - })); - - format!("{:#?}", scope.operations) - } - - fn inline_macro_ref_comptime_block() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let a = context.create_local_variable(item); - let comptime_var: Variable = ElemType::from_int(4).into(); - - let mut scope = context.into_scope(); - let x: Variable = a.into(); - cpa!(scope, x = x + comptime_var); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/constants.rs b/crates/cubecl-core/tests/frontend/constants.rs deleted file mode 100644 index 1026101c4..000000000 --- a/crates/cubecl-core/tests/frontend/constants.rs +++ /dev/null @@ -1,114 +0,0 @@ -use half::{bf16, f16}; -use paste::paste; - -use cubecl_core::{self as cubecl, prelude::*}; - 
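// For orientation: the `gen_cube!` macro below generates one #[cube] function per
// (trait, constant) pair, and `gen_tests!` generates the matching test. A rough
// sketch (an approximation, not the literal macro output) of what
// `gen_cube!(Numeric, [MAX])` produces via `paste`:
//
//     #[cube]
//     pub fn numeric_max<T: Numeric>() -> T {
//         T::MAX
//     }
//
// Through its inner arms, `gen_tests!(Numeric, f32, MAX)` then expands
// `numeric_max::expand::<f32>` into a fresh context and compares the recorded
// operations against a scope built with `scope1.create_with_value(f32::MAX, item)`.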
-macro_rules! gen_cube { - ($trait:ident, [ $($constant:ident $(| $ret_type:ty)?),* ]) => { - $( - gen_cube!($trait, $constant, $($ret_type)?); - )* - }; - ($trait:ident, $constant:ident,) => { - gen_cube!($trait, $constant, T); - }; - ($trait:ident, $constant:ident, $ret_type:ty) => { - paste! { - gen_cube!([< $trait:lower _ $constant:lower >], $trait, $constant, $ret_type); - } - }; - ($func_name:ident, $trait:ident, $constant:ident, $ret_type:ty) => { - #[cube] - pub fn $func_name() -> $ret_type { - T::$constant - } - }; -} - -macro_rules! gen_tests { - ($trait:ident, [ $($type:ident),* ], $constants:tt) => { - $( - gen_tests!($trait, $type, $constants); - )* - }; - ($trait:ident, $type:ident, [ $($constant:ident $(| $ret_type:ty)?),* ]) => { - $( - gen_tests!($trait, $type, $constant, $($ret_type)?); - )* - }; - ($trait:ident, $type:ident, $constant:ident,) => { - gen_tests!($trait, $type, $constant, $type); - }; - ($trait:ident, $type:ident, $constant:ident, $ret_type:ty) => { - paste! { - gen_tests!([< cube_ $trait:lower _ $constant:lower _ $type _test >], [< $trait:lower _ $constant:lower >], $type, $constant, $ret_type); - } - }; - ($test_name:ident, $func_name:ident, $type:ty, $constant:ident, $ret_type:ty) => { - #[test] - fn $test_name() { - let mut context = CubeContext::default(); - $func_name::expand::<$type>(&mut context); - let scope = context.into_scope(); - - let mut scope1 = CubeContext::default().into_scope(); - let item = Item::new(<$ret_type>::as_elem()); - scope1.create_with_value(<$type>::$constant, item); - - assert_eq!( - format!("{:?}", scope.operations), - format!("{:?}", scope1.operations) - ); - } - }; -} - -gen_cube!(Numeric, [MAX, MIN]); -gen_cube!(Int, [BITS | u32]); -gen_cube!( - Float, - [ - DIGITS | u32, - EPSILON, - INFINITY, - MANTISSA_DIGITS | u32, - MAX_10_EXP | i32, - MAX_EXP | i32, - MIN_10_EXP | i32, - MIN_EXP | i32, - MIN_POSITIVE, - NAN, - NEG_INFINITY, - RADIX | u32 - ] -); - -mod tests { - use super::*; - use cubecl_core::{ - frontend::{CubeContext, CubePrimitive}, - ir::Item, - }; - use pretty_assertions::assert_eq; - - gen_tests!(Numeric, [bf16, f16, f32, f64, i32, i64, u32], [MAX, MIN]); - gen_tests!(Int, [i32, i64, u32], [BITS | u32]); - gen_tests!( - Float, - [bf16, f16, f32, f64], - [ - DIGITS | u32, - EPSILON, - INFINITY, - MANTISSA_DIGITS | u32, - MAX_10_EXP | i32, - MAX_EXP | i32, - MIN_10_EXP | i32, - MIN_EXP | i32, - MIN_POSITIVE, - NAN, - NEG_INFINITY, - RADIX | u32 - ] - ); -} diff --git a/crates/cubecl-core/tests/frontend/cube_impl.rs b/crates/cubecl-core/tests/frontend/cube_impl.rs deleted file mode 100644 index 41bcf06b3..000000000 --- a/crates/cubecl-core/tests/frontend/cube_impl.rs +++ /dev/null @@ -1,103 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[derive(CubeType)] -struct SimpleType { - a: u32, -} - -#[cube] -impl SimpleType { - #[allow(dead_code)] - fn simple_method(&self, lhs: u32) -> u32 { - self.a * lhs - } - - #[allow(dead_code)] - pub fn call_method_inner(&self) -> u32 { - self.simple_method(5u32) - } - - #[allow(dead_code)] - pub fn call_method_as_function_inner(&self) -> u32 { - Self::simple_method(self, 5u32) - } - - #[allow(dead_code)] - pub fn return_self(self) -> Self { - self - } - - #[allow(dead_code)] - pub fn with_other(self, other: Self) -> u32 { - self.call_method_inner() + other.call_method_inner() - } - - #[allow(dead_code)] - pub fn with_generic(self, rhs: E) -> u32 { - self.simple_method(u32::cast_from(rhs)) - } -} - -#[derive(CubeType)] -struct TypeGeneric { - a: C, -} - 
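// Reading aid, not part of the original tests: the #[cube] macro keeps the
// original method bodies next to the generated expansions, so these impls can
// also be read (and in principle called) as ordinary Rust on the host. A hedged
// sketch with hypothetical values, using `SimpleType` above:
//
//     let t = SimpleType { a: 3 };
//     assert_eq!(t.simple_method(4), 12);    // 3 * 4
//     assert_eq!(t.call_method_inner(), 15); // 3 * 5, via simple_method(5)
//
// The generated `__expand_*` companions (the form the cube_trait.rs tests below
// invoke directly as `__expand_test`) are what record IR into a CubeContext instead.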
-#[cube] -impl TypeGeneric { - #[allow(dead_code)] - fn value(&self, lhs: u32) -> C { - self.a * C::cast_from(lhs) - } - - #[allow(dead_code)] - pub fn call_inner(&self) -> C { - let val1 = self.value(5u32); - let val2 = Self::value(self, 2u32); - val1 + val2 - } -} - -#[derive(CubeType)] -struct ComplexType { - a: C, - t: T, -} - -#[cube] -impl ComplexType { - #[allow(dead_code)] - pub fn complex_method(&mut self, lhs: f32, rhs: C) -> f32 { - let tmp = self.a + (C::cast_from(lhs) / rhs); - - Self::simple_function(lhs, tmp) - } - - fn simple_function(lhs: f32, rhs: C) -> f32 { - lhs * f32::cast_from(rhs) - } -} - -mod foo { - use super::*; - - #[derive(CubeType)] - pub struct TypeInModule { - pub a: u32, - } - - #[cube] - impl TypeInModule { - #[allow(dead_code)] - pub fn simple_method(&self, lhs: u32) -> u32 { - self.a * lhs - } - } -} - -#[cube] -fn call_from_outside_module() { - let bar = foo::TypeInModule { a: 0u32 }; - let _ = bar.simple_method(5); -} diff --git a/crates/cubecl-core/tests/frontend/cube_trait.rs b/crates/cubecl-core/tests/frontend/cube_trait.rs deleted file mode 100644 index d9dfa3208..000000000 --- a/crates/cubecl-core/tests/frontend/cube_trait.rs +++ /dev/null @@ -1,114 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -trait FunctionGeneric { - #[allow(unused)] - fn test(lhs: C, rhs: C) -> C; -} - -#[cube] -trait TraitGeneric { - #[allow(unused)] - fn test(lhs: C, rhs: C) -> C; -} - -#[cube] -trait CombinedTraitFunctionGeneric { - #[allow(unused)] - fn test(lhs: C, rhs: C) -> O; -} - -struct Test; - -#[cube] -impl FunctionGeneric for Test { - fn test(lhs: C, rhs: C) -> C { - lhs + rhs - } -} - -#[cube] -impl TraitGeneric for Test { - fn test(lhs: C, rhs: C) -> C { - lhs + rhs - } -} - -#[cube] -impl CombinedTraitFunctionGeneric for Test { - fn test(lhs: C, rhs: C) -> O { - O::cast_from(lhs + rhs) - } -} - -#[cube] -pub fn simple(lhs: C, rhs: C) -> C { - lhs + rhs -} - -#[cube] -pub fn with_cast(lhs: C, rhs: C) -> O { - O::cast_from(lhs + rhs) -} - -mod tests { - use cubecl_core::ir::{Item, Scope}; - - use super::*; - - #[test] - fn test_function_generic() { - let mut context = CubeContext::default(); - let lhs = context.create_local_binding(Item::new(f32::as_elem())); - let rhs = context.create_local_binding(Item::new(f32::as_elem())); - - ::__expand_test::(&mut context, lhs.into(), rhs.into()); - - assert_eq!(simple_scope(), context.into_scope()); - } - - #[test] - fn test_trait_generic() { - let mut context = CubeContext::default(); - let lhs = context.create_local_binding(Item::new(f32::as_elem())); - let rhs = context.create_local_binding(Item::new(f32::as_elem())); - - >::__expand_test(&mut context, lhs.into(), rhs.into()); - - assert_eq!(simple_scope(), context.into_scope()); - } - - #[test] - fn test_combined_function_generic() { - let mut context = CubeContext::default(); - let lhs = context.create_local_binding(Item::new(f32::as_elem())); - let rhs = context.create_local_binding(Item::new(f32::as_elem())); - - >::__expand_test::( - &mut context, - lhs.into(), - rhs.into(), - ); - - assert_eq!(with_cast_scope(), context.into_scope()); - } - - fn simple_scope() -> Scope { - let mut context_ref = CubeContext::default(); - let lhs = context_ref.create_local_binding(Item::new(f32::as_elem())); - let rhs = context_ref.create_local_binding(Item::new(f32::as_elem())); - - simple::expand::(&mut context_ref, lhs.into(), rhs.into()); - context_ref.into_scope() - } - - fn with_cast_scope() -> Scope { - let mut context_ref = 
CubeContext::default(); - let lhs = context_ref.create_local_binding(Item::new(f32::as_elem())); - let rhs = context_ref.create_local_binding(Item::new(f32::as_elem())); - - with_cast::expand::(&mut context_ref, lhs.into(), rhs.into()); - context_ref.into_scope() - } -} diff --git a/crates/cubecl-core/tests/frontend/enum_type.rs b/crates/cubecl-core/tests/frontend/enum_type.rs deleted file mode 100644 index 549c085a6..000000000 --- a/crates/cubecl-core/tests/frontend/enum_type.rs +++ /dev/null @@ -1,31 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[allow(dead_code)] -mod test_compilation { - use super::*; - - #[derive(CubeType)] - enum VariantNoInput { - Add, - Min, - } - - #[derive(CubeType)] - enum SingleVariant { - Add(u32), - } - - #[derive(CubeType)] - enum MultipleVariants { - Add(u32), - Min(u32, u32), - } - - #[derive(CubeType)] - enum MultipleVariantsNamed { - Add(u32), - Min(u32, u32), - Mul { lhs: u32, rhs: u32 }, - } -} diff --git a/crates/cubecl-core/tests/frontend/for_loop.rs b/crates/cubecl-core/tests/frontend/for_loop.rs deleted file mode 100644 index 70748968f..000000000 --- a/crates/cubecl-core/tests/frontend/for_loop.rs +++ /dev/null @@ -1,136 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::{ - cube, - frontend::{Array, CubeContext, CubePrimitive, Float}, -}; - -type ElemType = f32; - -#[cube] -pub fn for_loop(mut lhs: Array, rhs: F, end: u32, #[comptime] unroll: bool) { - let tmp1 = rhs * rhs; - let tmp2 = tmp1 + rhs; - - #[unroll(unroll)] - for i in 0..end { - lhs[i] = tmp2 + lhs[i]; - } -} - -#[cube] -pub fn for_in_loop(input: &Array) -> F { - let mut sum = F::new(0.0); - - for item in input { - sum += item; - } - sum -} - -mod tests { - use cubecl::frontend::ExpandElement; - use cubecl_core::{ - cpa, - ir::{Item, Variable}, - }; - use pretty_assertions::assert_eq; - - use super::*; - - #[test] - fn test_for_loop_with_unroll() { - let mut context = CubeContext::default(); - let unroll = true; - - let lhs = context.create_local_array(Item::new(ElemType::as_elem()), 4u32); - let rhs = context.create_local_binding(Item::new(ElemType::as_elem())); - let end: ExpandElement = 4u32.into(); - - for_loop::expand::(&mut context, lhs.into(), rhs.into(), end.into(), unroll); - let scope = context.into_scope(); - - assert_eq!(format!("{:#?}", scope.operations), inline_macro_ref(unroll)); - } - - #[test] - fn test_for_loop_no_unroll() { - let mut context = CubeContext::default(); - let unroll = false; - - let lhs = context.create_local_array(Item::new(ElemType::as_elem()), 4u32); - let rhs = context.create_local_binding(Item::new(ElemType::as_elem())); - let end: ExpandElement = 4u32.into(); - - for_loop::expand::(&mut context, lhs.into(), rhs.into(), end.into(), unroll); - let scope = context.into_scope(); - - assert_eq!(format!("{:#?}", scope.operations), inline_macro_ref(unroll)); - } - - #[test] - fn test_for_in_loop() { - let mut context = CubeContext::default(); - - let input = context.create_local_array(Item::new(ElemType::as_elem()), 4u32); - - for_in_loop::expand::(&mut context, input.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:#?}", scope.operations), - inline_macro_ref_for_in() - ); - } - - fn inline_macro_ref(unroll: bool) -> String { - let context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - - let mut scope = context.into_scope(); - let lhs = scope.create_local_array(item, 4u32); - let rhs = scope.create_local(item); - let end = 4u32; - - // Kernel - let tmp1 = 
scope.create_local(item); - cpa!(scope, tmp1 = rhs * rhs); - cpa!(scope, tmp1 = tmp1 + rhs); - - cpa!( - &mut scope, - range(0u32, end, unroll).for_each(|i, scope| { - cpa!(scope, rhs = lhs[i]); - cpa!(scope, rhs = tmp1 + rhs); - cpa!(scope, lhs[i] = rhs); - }) - ); - - format!("{:#?}", scope.operations) - } - - fn inline_macro_ref_for_in() -> String { - let context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - - let mut scope = context.into_scope(); - let input = scope.create_local_array(item, 4u32); - let sum = scope.create_local(item); - let end = scope.create_local(Item::new(u32::as_elem())); - let zero: Variable = ElemType::new(0.0).into(); - - // Kernel - let tmp1 = scope.create_local(item); - cpa!(scope, sum = zero); - cpa!(scope, end = len(input)); - - cpa!( - &mut scope, - range(0u32, end).for_each(|i, scope| { - cpa!(scope, tmp1 = input[i]); - cpa!(scope, sum = sum + tmp1); - }) - ); - - format!("{:#?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/function_call.rs b/crates/cubecl-core/tests/frontend/function_call.rs deleted file mode 100644 index fab9ccb11..000000000 --- a/crates/cubecl-core/tests/frontend/function_call.rs +++ /dev/null @@ -1,111 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::{cube, frontend::Numeric}; - -#[cube] -pub fn caller_no_arg(x: u32) { - let _ = x + callee_no_arg(); -} - -#[cube] -pub fn callee_no_arg() -> u32 { - u32::from_int(8) -} - -#[cube] -pub fn no_call_no_arg(x: u32) { - let _ = x + u32::from_int(8); -} - -#[cube] -pub fn caller_with_arg(x: u32) { - let _ = x + callee_with_arg(x); -} - -#[cube] -pub fn callee_with_arg(x: u32) -> u32 { - x * u32::from_int(8) -} - -#[cube] -pub fn no_call_with_arg(x: u32) { - let _ = x + x * u32::from_int(8); -} - -#[cube] -pub fn caller_with_generics(x: T) { - let _ = x + callee_with_generics::(x); -} - -#[cube] -pub fn callee_with_generics(x: T) -> T { - x * T::from_int(8) -} - -#[cube] -pub fn no_call_with_generics(x: T) { - let _ = x + x * T::from_int(8); -} - -mod tests { - use super::*; - use cubecl_core::{ - frontend::{CubeContext, CubePrimitive}, - ir::Item, - }; - - #[test] - fn cube_call_equivalent_to_no_call_no_arg_test() { - let mut caller_context = CubeContext::default(); - let x = caller_context.create_local_binding(Item::new(u32::as_elem())); - caller_no_arg::expand(&mut caller_context, x.into()); - let caller_scope = caller_context.into_scope(); - - let mut no_call_context = CubeContext::default(); - let x = no_call_context.create_local_binding(Item::new(u32::as_elem())); - no_call_no_arg::expand(&mut no_call_context, x.into()); - let no_call_scope = no_call_context.into_scope(); - - assert_eq!( - format!("{:?}", caller_scope.operations), - format!("{:?}", no_call_scope.operations) - ); - } - - #[test] - fn cube_call_equivalent_to_no_call_with_arg_test() { - let mut caller_context = CubeContext::default(); - - let x = caller_context.create_local_binding(Item::new(u32::as_elem())); - caller_with_arg::expand(&mut caller_context, x.into()); - let caller_scope = caller_context.into_scope(); - - let mut no_call_context = CubeContext::default(); - let x = no_call_context.create_local_binding(Item::new(u32::as_elem())); - no_call_with_arg::expand(&mut no_call_context, x.into()); - let no_call_scope = no_call_context.into_scope(); - - assert_eq!( - format!("{:?}", caller_scope.operations), - format!("{:?}", no_call_scope.operations) - ); - } - - #[test] - fn cube_call_equivalent_to_no_call_with_generics_test() { - let mut caller_context = 
CubeContext::default(); - type ElemType = i64; - let x = caller_context.create_local_binding(Item::new(ElemType::as_elem())); - caller_with_generics::expand::(&mut caller_context, x.into()); - let caller_scope = caller_context.into_scope(); - - let mut no_call_context = CubeContext::default(); - let x = no_call_context.create_local_binding(Item::new(ElemType::as_elem())); - no_call_with_generics::expand::(&mut no_call_context, x.into()); - let no_call_scope = no_call_context.into_scope(); - - assert_eq!( - format!("{:?}", caller_scope.operations), - format!("{:?}", no_call_scope.operations) - ); - } -} diff --git a/crates/cubecl-core/tests/frontend/generic_kernel.rs b/crates/cubecl-core/tests/frontend/generic_kernel.rs deleted file mode 100644 index 29ede3cfc..000000000 --- a/crates/cubecl-core/tests/frontend/generic_kernel.rs +++ /dev/null @@ -1,65 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::{cube, frontend::Numeric}; - -#[cube] -pub fn generic_kernel(lhs: T) { - let _ = lhs + T::from_int(5); -} - -mod tests { - use cubecl_core::{ - cpa, - frontend::{CubeContext, CubePrimitive}, - ir::{Item, Variable}, - }; - - use super::*; - - #[test] - fn cube_generic_float_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(f32::as_elem())); - - generic_kernel::expand::(&mut context, lhs.into()); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float()); - } - - #[test] - fn cube_generic_int_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(i32::as_elem())); - - generic_kernel::expand::(&mut context, lhs.into()); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); - } - - fn inline_macro_ref_float() -> String { - let mut context = CubeContext::default(); - let item = Item::new(f32::as_elem()); - let var = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let var: Variable = var.into(); - cpa!(scope, var = var + 5.0f32); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_int() -> String { - let mut context = CubeContext::default(); - let item = Item::new(i32::as_elem()); - let var = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let var: Variable = var.into(); - cpa!(scope, var = var + 5); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/if.rs b/crates/cubecl-core/tests/frontend/if.rs deleted file mode 100644 index 15924ba05..000000000 --- a/crates/cubecl-core/tests/frontend/if.rs +++ /dev/null @@ -1,210 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -pub fn if_greater(lhs: T) { - if lhs > T::from_int(0) { - let _ = lhs + T::from_int(4); - } -} - -#[cube] -pub fn if_greater_var(lhs: T) { - let x = lhs > T::from_int(0); - if x { - let _ = lhs + T::from_int(4); - } -} - -#[cube] -pub fn if_then_else(lhs: F) { - if lhs < F::from_int(0) { - let _ = lhs + F::from_int(4); - } else { - let _ = lhs - F::from_int(5); - } -} - -#[cube] -pub fn elsif(lhs: F) { - if lhs < F::new(0.) { - let _ = lhs + F::new(2.); - } else if lhs > F::new(0.) { - let _ = lhs + F::new(1.); - } else { - let _ = lhs + F::new(0.); - } -} - -#[cube] -pub fn elsif_assign(lhs: F) { - let _ = if lhs < F::new(0.) { - lhs + F::new(2.) - } else if lhs > F::new(0.) { - lhs + F::new(1.) - } else { - lhs + F::new(0.) 
- }; -} - -mod tests { - use cubecl_core::{ - cpa, - frontend::{CubeContext, CubePrimitive}, - ir::{Elem, Item, Variable}, - }; - use pretty_assertions::assert_eq; - - use super::*; - - type ElemType = f32; - - #[test] - fn cube_if_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - - if_greater::expand::(&mut context, lhs.into()); - let scope = context.into_scope(); - - assert_eq!(format!("{:#?}", scope.operations), inline_macro_ref_if()); - } - - #[test] - fn cube_if_else_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - - if_then_else::expand::(&mut context, lhs.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:#?}", scope.operations), - inline_macro_ref_if_else() - ); - } - - #[test] - fn cube_elsif_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - - elsif::expand::(&mut context, lhs.into()); - let scope = context.into_scope(); - - assert_eq!(format!("{:#?}", scope.operations), inline_macro_ref_elsif()); - } - - #[test] - fn cube_elsif_assign_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - - elsif_assign::expand::(&mut context, lhs.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:#?}", scope.operations), - inline_macro_ref_elsif_assign() - ); - } - - fn inline_macro_ref_if() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let lhs = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::new(Elem::Bool)); - let lhs: Variable = lhs.into(); - - cpa!(scope, cond = lhs > 0f32); - cpa!(&mut scope, if(cond).then(|scope| { - cpa!(scope, lhs = lhs + 4.0f32); - })); - - format!("{:#?}", scope.operations) - } - - fn inline_macro_ref_if_else() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let lhs = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::new(Elem::Bool)); - let lhs: Variable = lhs.into(); - let y = scope.create_local(item); - - cpa!(scope, cond = lhs < 0f32); - cpa!(&mut scope, if(cond).then(|scope| { - cpa!(scope, y = lhs + 4.0f32); - }).else(|scope|{ - cpa!(scope, y = lhs - 5.0f32); - })); - - format!("{:#?}", scope.operations) - } - - fn inline_macro_ref_elsif() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let lhs = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let cond1 = scope.create_local(Item::new(Elem::Bool)); - let lhs: Variable = lhs.into(); - let y = scope.create_local(item); - let cond2 = scope.create_local(Item::new(Elem::Bool)); - - cpa!(scope, cond1 = lhs < 0f32); - cpa!(&mut scope, if(cond1).then(|scope| { - cpa!(scope, y = lhs + 2.0f32); - }).else(|mut scope|{ - cpa!(scope, cond2 = lhs > 0f32); - cpa!(&mut scope, if(cond2).then(|scope| { - cpa!(scope, y = lhs + 1.0f32); - }).else(|scope|{ - cpa!(scope, y = lhs + 0.0f32); - })); - })); - - format!("{:#?}", scope.operations) - } - - fn inline_macro_ref_elsif_assign() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let lhs = context.create_local_binding(item); - - let mut scope = 
context.into_scope(); - let lhs: Variable = lhs.into(); - let cond1 = scope.create_local(Item::new(Elem::Bool)); - let y = scope.create_local(item); - let out = scope.create_local(item); - let cond2 = scope.create_local(Item::new(Elem::Bool)); - let out2 = scope.create_local(item); - - cpa!(scope, cond1 = lhs < 0f32); - cpa!(&mut scope, if(cond1).then(|scope| { - cpa!(scope, y = lhs + 2.0f32); - cpa!(scope, out = y); - }).else(|mut scope|{ - cpa!(scope, cond2 = lhs > 0f32); - cpa!(&mut scope, if(cond2).then(|scope| { - cpa!(scope, y = lhs + 1.0f32); - cpa!(scope, out2 = y); - }).else(|scope|{ - cpa!(scope, y = lhs + 0.0f32); - cpa!(scope, out2 = y); - })); - cpa!(scope, out = out2); - })); - - format!("{:#?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/kernel.rs b/crates/cubecl-core/tests/frontend/kernel.rs deleted file mode 100644 index f82b2562a..000000000 --- a/crates/cubecl-core/tests/frontend/kernel.rs +++ /dev/null @@ -1,9 +0,0 @@ -use cubecl::prelude::*; -use cubecl_core as cubecl; - -#[cube(launch)] -pub fn with_kernel<F: Float>(kernel: &mut Array<F>) { - if ABSOLUTE_POS < kernel.len() { - kernel[ABSOLUTE_POS] = F::cast_from(5.0); - } -} diff --git a/crates/cubecl-core/tests/frontend/literal.rs b/crates/cubecl-core/tests/frontend/literal.rs deleted file mode 100644 index f2db2a981..000000000 --- a/crates/cubecl-core/tests/frontend/literal.rs +++ /dev/null @@ -1,58 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -pub fn literal<F: Float>(lhs: F) { - let _ = lhs + F::from_int(5); -} - -#[cube] -pub fn literal_float_no_decimals<F: Float>(lhs: F) { - let _ = lhs + F::new(5.); -} - -mod tests { - use super::*; - use cubecl_core::{ - cpa, - ir::{Item, Variable}, - }; - - type ElemType = f32; - - #[test] - fn cube_literal_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - - literal::expand::<ElemType>(&mut context, lhs.into()); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); - } - - #[test] - fn cube_literal_float_no_decimal_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - - literal_float_no_decimals::expand::<ElemType>(&mut context, lhs.into()); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); - } - - fn inline_macro_ref() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let lhs = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let lhs: Variable = lhs.into(); - cpa!(scope, lhs = lhs + 5.0f32); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/loop.rs b/crates/cubecl-core/tests/frontend/loop.rs deleted file mode 100644 index 39e77ae3c..000000000 --- a/crates/cubecl-core/tests/frontend/loop.rs +++ /dev/null @@ -1,136 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -pub fn while_not<I: Int>(lhs: I) { - while lhs != I::from_int(0) { - let _ = lhs % I::from_int(1); - } -} - -#[cube] -pub fn manual_loop_break<I: Int>(lhs: I) { - loop { - if lhs == I::from_int(0) { - break; - } - let _ = lhs % I::from_int(1); - } -} - -#[cube] -pub fn loop_with_return<I: Int>(lhs: I) { - loop { - if lhs == I::from_int(0) { - return; - } - let _ = lhs % I::from_int(1); - } -} - -mod tests { - use super::*; - use cubecl_core::{ - cpa, - ir::{Branch, Elem, Item, Variable}, - }; - use 
pretty_assertions::assert_eq; - - type ElemType = i32; - - #[test] - fn cube_while_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - - while_not::expand::(&mut context, lhs.into()); - let scope = context.into_scope(); - - assert_eq!(format!("{:#?}", scope.operations), inline_macro_ref_while()); - } - - #[test] - fn cube_loop_break_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - - manual_loop_break::expand::(&mut context, lhs.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_loop(false) - ); - } - - #[test] - fn cube_loop_with_return_test() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::new(ElemType::as_elem())); - - loop_with_return::expand::(&mut context, lhs.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_loop(true) - ); - } - - fn inline_macro_ref_while() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let lhs = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::new(Elem::Bool)); - let y = scope.create_local(item); - let lhs: Variable = lhs.into(); - - cpa!( - &mut scope, - loop(|scope| { - cpa!(scope, cond = lhs != 0); - cpa!(scope, cond = !cond); - cpa!(scope, if(cond).then(|scope|{ - scope.register(Branch::Break) - })); - // Must not mutate `lhs` because it is used in every iteration - cpa!(scope, y = lhs % 1i32); - }) - ); - - format!("{:#?}", scope.operations) - } - - fn inline_macro_ref_loop(is_return: bool) -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let lhs = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::new(Elem::Bool)); - let y = scope.create_local(item); - let lhs: Variable = lhs.into(); - - cpa!( - &mut scope, - loop(|scope| { - cpa!(scope, cond = lhs == 0); - cpa!(scope, if(cond).then(|scope|{ - match is_return { - true => scope.register(Branch::Return), - false => scope.register(Branch::Break) - } - })); - // Must not mutate `lhs` because it is used in every iteration - cpa!(scope, y = lhs % 1i32); - }) - ); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/mod.rs b/crates/cubecl-core/tests/frontend/mod.rs deleted file mode 100644 index 8fc01cabc..000000000 --- a/crates/cubecl-core/tests/frontend/mod.rs +++ /dev/null @@ -1,28 +0,0 @@ -mod array; -mod assign; -mod cast_elem; -mod cast_kind; -mod comptime; -mod constants; -mod cube_impl; -mod cube_trait; -mod enum_type; -mod for_loop; -mod function_call; -mod generic_kernel; -mod r#if; -mod kernel; -mod literal; -mod r#loop; -mod module_import; -mod ops; -mod parenthesis; -mod redeclare; -mod reuse; -mod shared_memory; -mod r#struct; -mod tensor; -mod topology; -mod r#trait; -mod tuple; -mod vectorization; diff --git a/crates/cubecl-core/tests/frontend/module_import.rs b/crates/cubecl-core/tests/frontend/module_import.rs deleted file mode 100644 index 1d656ed48..000000000 --- a/crates/cubecl-core/tests/frontend/module_import.rs +++ /dev/null @@ -1,50 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -mod elsewhere { - use super::*; - - #[cube] - pub fn my_func(x: F) -> F { - x * 
F::from_int(2) - } -} - -mod here { - use super::*; - - #[cube] - pub fn caller(x: F) { - let _ = x + elsewhere::my_func::(x); - } - - #[cube] - pub fn no_call_ref(x: F) { - let _ = x + x * F::from_int(2); - } -} - -mod tests { - use super::*; - use cubecl_core::ir::Item; - - type ElemType = f32; - - #[test] - fn cube_call_equivalent_to_no_call_no_arg_test() { - let mut caller_context = CubeContext::default(); - let x = caller_context.create_local_binding(Item::new(ElemType::as_elem())); - here::caller::expand::(&mut caller_context, x.into()); - let caller_scope = caller_context.into_scope(); - - let mut no_call_context = CubeContext::default(); - let x = no_call_context.create_local_binding(Item::new(ElemType::as_elem())); - here::no_call_ref::expand::(&mut no_call_context, x.into()); - let no_call_scope = no_call_context.into_scope(); - - assert_eq!( - format!("{:?}", caller_scope.operations), - format!("{:?}", no_call_scope.operations) - ); - } -} diff --git a/crates/cubecl-core/tests/frontend/ops.rs b/crates/cubecl-core/tests/frontend/ops.rs deleted file mode 100644 index 7034d67ed..000000000 --- a/crates/cubecl-core/tests/frontend/ops.rs +++ /dev/null @@ -1,536 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -pub fn add_op(a: T, b: T) -> T { - a + b -} - -#[cube] -pub fn sub_op(a: T, b: T) -> T { - a - b -} - -#[cube] -pub fn mul_op(a: T, b: T) -> T { - a * b -} - -#[cube] -pub fn div_op(a: T, b: T) -> T { - a / b -} - -#[cube] -pub fn abs_op(a: T) -> T { - T::abs(a) -} - -#[cube] -pub fn exp_op(a: F) -> F { - F::exp(a) -} - -#[cube] -pub fn log_op(a: F) -> F { - F::log(a) -} - -#[cube] -pub fn log1p_op(a: F) -> F { - F::log1p(a) -} - -#[cube] -pub fn cos_op(a: F) -> F { - F::cos(a) -} - -#[cube] -pub fn sin_op(a: F) -> F { - F::sin(a) -} - -#[cube] -pub fn tanh_op(a: F) -> F { - F::tanh(a) -} - -#[cube] -pub fn powf_op(a: F, b: F) -> F { - F::powf(a, b) -} - -#[cube] -pub fn sqrt_op(a: F) -> F { - F::sqrt(a) -} - -#[cube] -pub fn round_op(a: F) -> F { - F::round(a) -} - -#[cube] -pub fn floor_op(a: F) -> F { - F::floor(a) -} - -#[cube] -pub fn ceil_op(a: F) -> F { - F::ceil(a) -} - -#[cube] -pub fn erf_op(a: F) -> F { - F::erf(a) -} - -#[cube] -pub fn recip_op(a: F) -> F { - F::recip(a) -} - -#[cube] -pub fn equal_op(a: T, b: T) -> bool { - a == b -} - -#[cube] -pub fn not_equal_op(a: T, b: T) -> bool { - a != b -} - -#[cube] -pub fn lower_op(a: T, b: T) -> bool { - a < b -} - -#[cube] -pub fn greater_op(a: T, b: T) -> bool { - a > b -} - -#[cube] -pub fn lower_equal_op(a: T, b: T) -> bool { - a <= b -} - -#[cube] -pub fn greater_equal_op(a: T, b: T) -> bool { - a >= b -} - -#[cube] -pub fn modulo_op(a: u32, b: u32) -> u32 { - a % b -} - -#[cube] -pub fn remainder_op(a: T, b: T) -> T { - T::rem(a, b) -} - -#[cube] -pub fn max_op(a: T, b: T) -> T { - T::max(a, b) -} - -#[cube] -pub fn min_op(a: T, b: T) -> T { - T::min(a, b) -} - -#[cube] -pub fn and_op(a: bool, b: bool) -> bool { - a && b -} - -#[cube] -pub fn or_op(a: bool, b: bool) -> bool { - a || b -} - -#[cube] -pub fn not_op(a: bool) -> bool { - !a -} - -#[cube] -pub fn bitand_op(a: u32, b: u32) -> u32 { - a & b -} - -#[cube] -pub fn bitor_op(a: u32, b: u32) -> u32 { - a | b -} - -#[cube] -pub fn bitxor_op(a: u32, b: u32) -> u32 { - a ^ b -} - -#[cube] -pub fn shl_op(a: u32, b: u32) -> u32 { - a << b -} - -#[cube] -pub fn shr_op(a: u32, b: u32) -> u32 { - a >> b -} - -#[cube] -pub fn add_assign_op(mut a: T, b: T) { - a += b; -} - -#[cube] -pub fn sub_assign_op(mut a: T, b: T) { - a -= b; -} - 
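// The *_assign variants are exercised separately because they should lower with
// the output reusing the left-hand variable, while the plain binary ops write to
// a distinct local. A condensed sketch of the expected instruction shapes
// (paraphrasing `ref_ops_template` in the tests below; the ids are illustrative):
//
//     a += b     ->  out: Local { id: 0 }   // same id as the lhs operand
//     c = a + b  ->  out: Local { id: 1 }   // separate output id
//
// Both record the same Operator payload, e.g. Add(BinaryOperator { lhs: id 0, rhs: id 1 }).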
-#[cube] -pub fn mul_assign_op(mut a: T, b: T) { - a *= b; -} - -#[cube] -pub fn div_assign_op(mut a: T, b: T) { - a /= b; -} - -#[cube] -pub fn rem_assign_op(mut a: T, b: T) { - a %= b; -} - -#[cube] -pub fn bitor_assign_op(mut a: T, b: T) { - a |= b; -} - -#[cube] -pub fn bitand_assign_op(mut a: T, b: T) { - a &= b; -} - -#[cube] -pub fn bitxor_assign_op(mut a: T, b: T) { - a ^= b; -} - -#[cube] -pub fn shl_assign_op(mut a: T, b: u32) { - a <<= b; -} - -#[cube] -pub fn shr_assign_op(mut a: T, b: u32) { - a >>= b; -} - -mod tests { - use super::*; - use cubecl_core::ir::{Elem, FloatKind, Item}; - use pretty_assertions::assert_eq; - - macro_rules! binary_test { - ($test_name:ident, $op_expand:expr, $op_name:expr, $func:ident) => { - #[test] - fn $test_name() { - let mut context = CubeContext::default(); - let x = context.create_local_binding(Item::new(Elem::Float(FloatKind::F32))); - let y = context.create_local_binding(Item::new(Elem::Float(FloatKind::F32))); - - $op_expand(&mut context, x.into(), y.into()); - - assert_eq!( - format!("{:?}", context.into_scope().process().operations), - $func($op_name) - ); - } - }; - } - - macro_rules! unary_test { - ($test_name:ident, $op_expand:expr, $op_name:expr) => { - #[test] - fn $test_name() { - let mut context = CubeContext::default(); - let x = context.create_local_binding(Item::new(Elem::Float(FloatKind::F32))); - - $op_expand(&mut context, x.into()); - - assert_eq!( - format!("{:?}", context.into_scope().operations), - ref_ops_unary($op_name) - ); - } - }; - } - - macro_rules! binary_boolean_test { - ($test_name:ident, $op_expand:expr, $op_name:expr) => { - #[test] - fn $test_name() { - let mut context = CubeContext::default(); - let x = context.create_local_binding(Item::new(Elem::Bool)); - let y = context.create_local_binding(Item::new(Elem::Bool)); - - $op_expand(&mut context, x.into(), y.into()); - - assert_eq!( - format!("{:?}", context.into_scope().operations), - ref_ops_binary_boolean($op_name) - ); - } - }; - } - - macro_rules! 
binary_u32_test { - ($test_name:ident, $op_expand:expr, $op_name:expr) => { - #[test] - fn $test_name() { - let mut context = CubeContext::default(); - let x = context.create_local_binding(Item::new(u32::as_elem())); - let y = context.create_local_binding(Item::new(u32::as_elem())); - - $op_expand(&mut context, x.into(), y.into()); - - assert_eq!( - format!("{:?}", context.into_scope().operations), - ref_ops_binary_u32($op_name) - ); - } - }; - } - - binary_test!(cube_can_add, add_op::expand::<f32>, "Add", ref_ops_binary); - binary_test!(cube_can_sub, sub_op::expand::<f32>, "Sub", ref_ops_binary); - binary_test!(cube_can_mul, mul_op::expand::<f32>, "Mul", ref_ops_binary); - binary_test!(cube_can_div, div_op::expand::<f32>, "Div", ref_ops_binary); - unary_test!(cube_can_abs, abs_op::expand::<f32>, "Abs"); - unary_test!(cube_can_exp, exp_op::expand::<f32>, "Exp"); - unary_test!(cube_can_log, log_op::expand::<f32>, "Log"); - unary_test!(cube_can_log1p, log1p_op::expand::<f32>, "Log1p"); - unary_test!(cube_can_cos, cos_op::expand::<f32>, "Cos"); - unary_test!(cube_can_sin, sin_op::expand::<f32>, "Sin"); - unary_test!(cube_can_tanh, tanh_op::expand::<f32>, "Tanh"); - binary_test!( - cube_can_powf, - powf_op::expand::<f32>, - "Powf", - ref_ops_binary - ); - unary_test!(cube_can_sqrt, sqrt_op::expand::<f32>, "Sqrt"); - unary_test!(cube_can_erf, erf_op::expand::<f32>, "Erf"); - unary_test!(cube_can_recip, recip_op::expand::<f32>, "Recip"); - unary_test!(cube_can_round, round_op::expand::<f32>, "Round"); - unary_test!(cube_can_floor, floor_op::expand::<f32>, "Floor"); - unary_test!(cube_can_ceil, ceil_op::expand::<f32>, "Ceil"); - binary_test!(cube_can_eq, equal_op::expand::<f32>, "Equal", ref_ops_cmp); - binary_test!( - cube_can_ne, - not_equal_op::expand::<f32>, - "NotEqual", - ref_ops_cmp - ); - binary_test!(cube_can_lt, lower_op::expand::<f32>, "Lower", ref_ops_cmp); - binary_test!( - cube_can_le, - lower_equal_op::expand::<f32>, - "LowerEqual", - ref_ops_cmp - ); - binary_test!( - cube_can_ge, - greater_equal_op::expand::<f32>, - "GreaterEqual", - ref_ops_cmp - ); - binary_test!( - cube_can_gt, - greater_op::expand::<f32>, - "Greater", - ref_ops_cmp - ); - binary_test!(cube_can_max, max_op::expand::<f32>, "Max", ref_ops_binary); - binary_test!(cube_can_min, min_op::expand::<f32>, "Min", ref_ops_binary); - binary_test!( - cube_can_add_assign, - add_assign_op::expand::<f32>, - "Add", - ref_ops_binary_assign - ); - binary_test!( - cube_can_sub_assign, - sub_assign_op::expand::<f32>, - "Sub", - ref_ops_binary_assign - ); - binary_test!( - cube_can_mul_assign, - mul_assign_op::expand::<f32>, - "Mul", - ref_ops_binary_assign - ); - binary_test!( - cube_can_div_assign, - div_assign_op::expand::<f32>, - "Div", - ref_ops_binary_assign - ); - binary_test!( - cube_can_rem_assign, - rem_assign_op::expand::<f32>, - "Modulo", - ref_ops_binary_assign - ); - binary_test!( - cube_can_bitor_assign, - bitor_assign_op::expand::, - "BitwiseOr", - ref_ops_binary_assign - ); - binary_test!( - cube_can_bitand_assign, - bitand_assign_op::expand::, - "BitwiseAnd", - ref_ops_binary_assign - ); - binary_test!( - cube_can_bitxor_assign, - bitxor_assign_op::expand::, - "BitwiseXor", - ref_ops_binary_assign - ); - binary_test!( - cube_can_shl_assign, - shl_assign_op::expand::, - "ShiftLeft", - ref_ops_binary_assign - ); - binary_test!( - cube_can_shr_assign, - shr_assign_op::expand::, - "ShiftRight", - ref_ops_binary_assign - ); - binary_boolean_test!(cube_can_and, and_op::expand, "And"); - binary_boolean_test!(cube_can_or, or_op::expand, "Or"); - binary_u32_test!(cube_can_bitand, bitand_op::expand, "BitwiseAnd"); - binary_u32_test!(cube_can_bitor, bitor_op::expand, 
"BitwiseOr"); - binary_u32_test!(cube_can_bitxor, bitxor_op::expand, "BitwiseXor"); - binary_u32_test!(cube_can_shl, shl_op::expand, "ShiftLeft"); - binary_u32_test!(cube_can_shr, shr_op::expand, "ShiftRight"); - binary_u32_test!(cube_can_mod, modulo_op::expand, "Modulo"); - binary_test!( - cube_can_rem, - remainder_op::expand::, - "Remainder", - ref_ops_binary - ); - - #[test] - fn cube_can_not() { - let mut context = CubeContext::default(); - let x = context.create_local_binding(Item::new(Elem::Bool)); - - not_op::expand(&mut context, x.into()); - - assert_eq!( - format!("{:?}", context.into_scope().operations), - ref_ops_unary_boolean("Not") - ); - } - - fn ref_ops_binary_assign(ops_name: &str) -> String { - ref_ops_template(ops_name, "Float(F32)", "Float(F32)", true, true) - } - - fn ref_ops_binary(ops_name: &str) -> String { - ref_ops_template(ops_name, "Float(F32)", "Float(F32)", true, false) - } - - fn ref_ops_unary(ops_name: &str) -> String { - ref_ops_template(ops_name, "Float(F32)", "Float(F32)", false, false) - } - - fn ref_ops_cmp(ops_name: &str) -> String { - ref_ops_template(ops_name, "Float(F32)", "Bool", true, false) - } - - fn ref_ops_unary_boolean(ops_name: &str) -> String { - ref_ops_template(ops_name, "Bool", "Bool", false, false) - } - - fn ref_ops_binary_boolean(ops_name: &str) -> String { - ref_ops_template(ops_name, "Bool", "Bool", true, false) - } - - fn ref_ops_binary_u32(ops_name: &str) -> String { - ref_ops_template(ops_name, "UInt(U32)", "UInt(U32)", true, false) - } - - fn ref_ops_template( - ops_name: &str, - in_type: &str, - out_type: &str, - binary: bool, - is_assign: bool, - ) -> String { - if binary { - let out_number = match (in_type == out_type, is_assign) { - (true, true) => 0, - (true, false) => binary as i32, - _ => 2, - }; - format!( - "[Instruction {{ out: Some(Variable {{ \ - kind: Local {{ id: {out_number}, depth: 0 }}, \ - item: Item {{ \ - elem: {out_type}, \ - vectorization: None \ - }} \ - }}), \ - operation: Operator({ops_name}(BinaryOperator {{ \ - lhs: Variable {{ \ - kind: Local {{ id: 0, depth: 0 }}, \ - item: Item {{ \ - elem: {in_type}, \ - vectorization: None \ - }} \ - }}, \ - rhs: Variable {{ \ - kind: Local {{ id: 1, depth: 0 }}, \ - item: Item {{ \ - elem: {in_type}, \ - vectorization: None \ - }} \ - }} \ - }})) }}]" - ) - } else { - format!( - "[Instruction {{ out: Some(Variable {{ \ - kind: Local {{ id: 0, depth: 0 }}, \ - item: Item {{ \ - elem: {out_type}, \ - vectorization: None \ - }} \ - }}), \ - operation: Operator({ops_name}(UnaryOperator {{ \ - input: Variable {{ \ - kind: Local {{ id: 0, depth: 0 }}, \ - item: Item {{ \ - elem: {in_type}, \ - vectorization: None \ - }} \ - }} \ - }})) }}]" - ) - } - } -} diff --git a/crates/cubecl-core/tests/frontend/parenthesis.rs b/crates/cubecl-core/tests/frontend/parenthesis.rs deleted file mode 100644 index 77368e8e3..000000000 --- a/crates/cubecl-core/tests/frontend/parenthesis.rs +++ /dev/null @@ -1,50 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -pub fn parenthesis(x: T, y: T, z: T) -> T { - x * (y + z) -} - -mod tests { - use super::*; - use cubecl_core::{ - cpa, - ir::{Item, Variable}, - }; - use pretty_assertions::assert_eq; - - type ElemType = f32; - - #[test] - fn cube_parenthesis_priority_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - let y = context.create_local_binding(Item::new(ElemType::as_elem())); - let z = 
context.create_local_binding(Item::new(ElemType::as_elem())); - - parenthesis::expand::(&mut context, x.into(), y.into(), z.into()); - let scope = context.into_scope(); - - assert_eq!(format!("{:#?}", scope.operations), inline_macro_ref()); - } - - fn inline_macro_ref() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local_binding(item); - let y = context.create_local_binding(item); - let z = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y: Variable = y.into(); - let z: Variable = z.into(); - - cpa!(scope, z = y + z); - cpa!(scope, z = x * z); - - format!("{:#?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/redeclare.rs b/crates/cubecl-core/tests/frontend/redeclare.rs deleted file mode 100644 index db9a81039..000000000 --- a/crates/cubecl-core/tests/frontend/redeclare.rs +++ /dev/null @@ -1,205 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -pub fn redeclare_same_scope(mut x: I) { - let i = I::new(1); - x += i; - let i = I::new(2); - x += i; -} - -#[cube] -pub fn redeclare_same_scope_other_type(mut x: I) -> F { - let i = I::new(1); - x += i; - let i = F::new(2.); - i + i -} - -#[cube] -pub fn redeclare_different_scope(mut x: I) { - let y = I::new(1); - x += y; - for _ in 0..2u32 { - let y = I::new(2); - x += y; - } -} - -#[cube] -#[allow(unused)] -pub fn redeclare_two_for_loops(mut x: u32) { - for i in 0..2 { - x += i; - } - for i in 0..2 { - x += i; - x += i; - } -} - -mod tests { - use cubecl_core::{ - cpa, - ir::{Item, Variable}, - }; - use pretty_assertions::assert_eq; - - use super::*; - - type ElemType = i32; - - #[test] - fn cube_redeclare_same_scope_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - - redeclare_same_scope::expand::(&mut context, x.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:#?}", scope.operations), - inline_macro_ref_same_scope() - ); - } - - #[test] - fn cube_redeclare_same_scope_other_type_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - - redeclare_same_scope_other_type::expand::(&mut context, x.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_same_scope_other_type() - ); - } - - #[test] - fn cube_redeclare_different_scope_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - - redeclare_different_scope::expand::(&mut context, x.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:#?}", scope.operations), - inline_macro_ref_different() - ); - } - - #[test] - fn cube_redeclare_two_for_loops_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(u32::as_elem())); - - redeclare_two_for_loops::expand(&mut context, x.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_two_for_loops() - ); - } - - fn inline_macro_ref_same_scope() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - - let x = context.create_local_binding(item); - let mut scope = context.into_scope(); - let x: Variable = x.into(); - - let value: ExpandElement = ElemType::from(1).into(); - let 
value: Variable = *value; - - cpa!(scope, x += value); - - let value: ExpandElement = ElemType::from(2).into(); - let value: Variable = *value; - - cpa!(scope, x += value); - - format!("{:#?}", scope.operations) - } - - fn inline_macro_ref_same_scope_other_type() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - - let x = context.create_local_binding(item); - let mut scope = context.into_scope(); - let x: Variable = x.into(); - - let i: ExpandElement = ElemType::new(1).into(); - let i = *i; - cpa!(scope, x += i); - let i: ExpandElement = 2f32.into(); - let i = *i; - let y = scope.create_local(Item::new(f32::as_elem())); - cpa!(scope, y = i + i); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_different() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - - let x = context.create_local_binding(item); - let end = 2u32; - let mut scope = context.into_scope(); - let x: Variable = x.into(); - - let y: ExpandElement = ElemType::new(1).into(); - let y = *y; - cpa!(scope, x += y); - - cpa!( - &mut scope, - range(0u32, end, false).for_each(|_, scope| { - let value: ExpandElement = ElemType::new(2).into(); - let value: Variable = *value; - - cpa!(scope, x += value); - }) - ); - - format!("{:#?}", scope.operations) - } - - fn inline_macro_ref_two_for_loops() -> String { - let mut context = CubeContext::default(); - let item = Item::new(u32::as_elem()); - - let x = context.create_local_binding(item); - let end = 2u32; - let mut scope = context.into_scope(); - let x: Variable = x.into(); - - cpa!( - &mut scope, - range(0u32, end, false).for_each(|i, scope| { - cpa!(scope, x += i); - }) - ); - - cpa!( - &mut scope, - range(0u32, end, false).for_each(|i, scope| { - cpa!(scope, x += i); - cpa!(scope, x += i); - }) - ); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/reuse.rs b/crates/cubecl-core/tests/frontend/reuse.rs deleted file mode 100644 index 6ff636f89..000000000 --- a/crates/cubecl-core/tests/frontend/reuse.rs +++ /dev/null @@ -1,109 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -#[allow(clippy::assign_op_pattern)] -pub fn reuse(mut x: I) { - // a += b is more efficient than a = a + b - // Because the latter does not assume that a is the same in lhs and rhs - // Normally clippy should detect it - while x < I::from_int(10) { - x = x + I::from_int(1); - } -} - -#[cube] -pub fn reuse_incr(mut x: I) { - while x < I::from_int(10) { - x += I::from_int(1); - } -} - -mod tests { - use super::*; - use cubecl_core::{ - cpa, - ir::{Branch, Elem, Item, Variable}, - }; - use pretty_assertions::assert_eq; - - type ElemType = i32; - #[test] - fn cube_reuse_assign_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - - reuse::expand::(&mut context, x.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:#?}", scope.operations), - inline_macro_ref_assign() - ); - } - - #[test] - fn cube_reuse_incr_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - - reuse_incr::expand::(&mut context, x.into()); - let scope = context.into_scope(); - - assert_eq!(format!("{:#?}", scope.operations), inline_macro_ref_incr()); - } - - fn inline_macro_ref_assign() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let x = 
context.create_local_binding(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::new(Elem::Bool)); - let x: Variable = x.into(); - let tmp = scope.create_local(item); - - cpa!( - &mut scope, - loop(|scope| { - cpa!(scope, cond = x < 10); - cpa!(scope, cond = !cond); - cpa!(scope, if(cond).then(|scope|{ - scope.register(Branch::Break); - })); - - cpa!(scope, tmp = x + 1); - cpa!(scope, x = tmp); - }) - ); - - format!("{:#?}", scope.operations) - } - - fn inline_macro_ref_incr() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::new(Elem::Bool)); - let x: Variable = x.into(); - - cpa!( - &mut scope, - loop(|scope| { - cpa!(scope, cond = x < 10); - cpa!(scope, cond = !cond); - cpa!(scope, if(cond).then(|scope|{ - scope.register(Branch::Break); - })); - - cpa!(scope, x = x + 1); - }) - ); - - format!("{:#?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/shared_memory.rs b/crates/cubecl-core/tests/frontend/shared_memory.rs deleted file mode 100644 index fe7b4046b..000000000 --- a/crates/cubecl-core/tests/frontend/shared_memory.rs +++ /dev/null @@ -1,50 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -pub fn shared_memory_read_write(#[comptime] sm_size: u32) { - let mut shared = SharedMemory::::new(sm_size); - shared[0] = T::from_int(3); - let _ = shared[0]; -} - -mod tests { - use super::*; - use cubecl_core::{ - cpa, - ir::{Item, Variable}, - }; - - type ElemType = f32; - - #[test] - fn cube_support_shared_memory() { - let mut context = CubeContext::default(); - - shared_memory_read_write::expand::(&mut context, 512); - assert_eq!( - format!("{:?}", context.into_scope().operations), - inline_macro_ref() - ); - } - - fn inline_macro_ref() -> String { - let context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - - let mut scope = context.into_scope(); - let var = scope.create_local(item); - let pos: Variable = 0u32.into(); - - // Create - let shared = scope.create_shared(item, 512); - - // Write - cpa!(scope, shared[pos] = 3.0_f32); - - // Read - cpa!(scope, var = shared[pos]); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/struct.rs b/crates/cubecl-core/tests/frontend/struct.rs deleted file mode 100644 index ce4ed6c6d..000000000 --- a/crates/cubecl-core/tests/frontend/struct.rs +++ /dev/null @@ -1,191 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[derive(CubeLaunch)] -pub struct EmptyLaunch {} - -#[derive(CubeType)] -pub struct EmptyType {} - -#[derive(CubeLaunch)] -pub struct UnitLaunch; - -#[derive(CubeLaunch)] -pub struct WithField { - lhs: Array, - rhs: Array, -} - -#[derive(CubeLaunch)] -pub struct WithFieldGeneric { - lhs: Array, - rhs: Array, -} - -#[derive(CubeLaunch)] -pub struct WithFieldGenericAndComptime { - lhs: Array, - #[cube(comptime)] - my_tag: String, -} - -#[derive(CubeType)] -pub struct UnitType; - -#[derive(CubeType)] -pub struct State { - first: T, - second: T, -} - -#[cube] -pub fn state_receiver_with_reuse(state: State) -> T { - let x = state.first + state.second; - state.second + x + state.first -} - -#[cube] -pub fn attribute_modifier_reuse_field(mut state: State) -> T { - state.first = T::from_int(4); - state.first -} - -#[cube] -pub fn attribute_modifier_reuse_struct(mut state: State) -> State { - state.first = 
T::from_int(4); - state -} - -#[cube] -fn creator(x: T, second: T) -> State { - let mut state = State:: { first: x, second }; - state.second = state.first; - - state -} - -mod tests { - use super::*; - use cubecl_core::{ - cpa, - ir::{Item, Variable}, - }; - - type ElemType = f32; - - #[test] - fn cube_new_struct_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - let y = context.create_local_binding(Item::new(ElemType::as_elem())); - - creator::expand::(&mut context, x.into(), y.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - creator_inline_macro_ref() - ); - } - - #[test] - fn cube_struct_as_arg_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - let y = context.create_local_binding(Item::new(ElemType::as_elem())); - - let expanded_state = StateExpand { - first: x.into(), - second: y.into(), - }; - state_receiver_with_reuse::expand::(&mut context, expanded_state); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - receive_state_with_reuse_inline_macro_ref() - ); - } - - #[test] - fn cube_struct_assign_to_field_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - let y = context.create_local_binding(Item::new(ElemType::as_elem())); - - let expanded_state = StateExpand { - first: x.into(), - second: y.into(), - }; - attribute_modifier_reuse_field::expand::(&mut context, expanded_state); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - field_modifier_inline_macro_ref() - ); - } - - #[test] - fn cube_struct_assign_to_field_reuse_struct_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - let y = context.create_local_binding(Item::new(ElemType::as_elem())); - - let expanded_state = StateExpand { - first: x.into(), - second: y.into(), - }; - attribute_modifier_reuse_struct::expand::(&mut context, expanded_state); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - field_modifier_inline_macro_ref() - ); - } - - fn creator_inline_macro_ref() -> String { - let context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - - let mut scope = context.into_scope(); - let x = scope.create_local(item); - let y = scope.create_local(item); - cpa!(scope, y = x); - - format!("{:?}", scope.operations) - } - - fn field_modifier_inline_macro_ref() -> String { - let context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - - let mut scope = context.into_scope(); - scope.create_with_value(4, item); - - format!("{:?}", scope.operations) - } - - fn receive_state_with_reuse_inline_macro_ref() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local_binding(item); - let y = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y: Variable = y.into(); - let z = scope.create_local(item); - - cpa!(scope, z = x + y); - cpa!(scope, z = y + z); - cpa!(scope, z = z + x); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/tensor.rs b/crates/cubecl-core/tests/frontend/tensor.rs deleted file mode 100644 index 274f53d2d..000000000 --- 
a/crates/cubecl-core/tests/frontend/tensor.rs +++ /dev/null @@ -1,46 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -pub fn kernel(input: &Tensor) { - let _shape = input.shape(1); - let _stride = input.stride(1); - let _length = input.len(); -} - -mod tests { - use super::*; - use cubecl_core::{ - cpa, - ir::{Instruction, Item, Variable}, - }; - - type ElemType = f32; - - #[test] - fn cube_support_tensor_metadata() { - let mut context = CubeContext::default(); - let input = context.input(0, Item::new(ElemType::as_elem())); - - kernel::expand::(&mut context, input.into()); - assert_eq!(context.into_scope().operations, inline_macro_ref()); - } - - fn inline_macro_ref() -> Vec { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let input = context.input(0, item); - - let mut scope = context.into_scope(); - let input: Variable = input.into(); - let x = scope.create_local(Item::new(u32::as_elem())); - let y = scope.create_local(Item::new(u32::as_elem())); - let z = scope.create_local(Item::new(u32::as_elem())); - - cpa!(&mut scope, x = shape(input, 1u32)); - cpa!(&mut scope, y = stride(input, 1u32)); - cpa!(&mut scope, z = len(input)); - - scope.operations - } -} diff --git a/crates/cubecl-core/tests/frontend/topology.rs b/crates/cubecl-core/tests/frontend/topology.rs deleted file mode 100644 index ceae93bfb..000000000 --- a/crates/cubecl-core/tests/frontend/topology.rs +++ /dev/null @@ -1,47 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -pub fn topology_kernel(input: Tensor) { - let x = ABSOLUTE_POS + 4; - let _ = input[x]; -} - -mod tests { - use super::*; - use cubecl_core::{ - cpa, - ir::{Builtin, Item, Variable}, - }; - - type ElemType = f32; - - #[test] - fn cube_support_topology() { - let mut context = CubeContext::default(); - let input = context.input(0, Item::new(ElemType::as_elem())); - - topology_kernel::expand::(&mut context, input.into()); - assert_eq!( - format!("{:?}", context.into_scope().operations), - inline_macro_ref() - ); - } - - fn inline_macro_ref() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let input = context.input(0, item); - - let mut scope = context.into_scope(); - let input: Variable = input.into(); - let x = scope.create_local(Item::new(u32::as_elem())); - let y = scope.create_local(item); - - let id = Variable::builtin(Builtin::AbsolutePos); - cpa!(&mut scope, x = id + 4u32); - cpa!(&mut scope, y = input[x]); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/trait.rs b/crates/cubecl-core/tests/frontend/trait.rs deleted file mode 100644 index 0108f691e..000000000 --- a/crates/cubecl-core/tests/frontend/trait.rs +++ /dev/null @@ -1,187 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -/// Traits used in Cube kernels must expose an _expand variant -/// for all their methods. However, one does not need to provide its -/// implementation, see examples below. 
-#[cube] -pub trait Strategy { - fn operation(input_1: T, input_2: T) -> T; -} - -struct AddStrategy; - -#[cube] -/// The actual implementation of AddStrategy's operation -/// Automatically generated an _expand variant -pub fn add_strategy_operation(input_1: T, input_2: T) -> T { - input_1 + input_2 -} - -#[cube] -impl Strategy for AddStrategy { - fn operation(input_1: T, input_2: T) -> T { - add_strategy_operation::(input_1, input_2) - } -} - -struct SubStrategy; - -#[cube] -impl Strategy for SubStrategy { - fn operation(input_1: T, input_2: T) -> T { - input_1 - input_2 - } -} - -#[cube] -pub fn with_strategy_trait, T: Numeric>(x: T, y: T) -> T { - S::operation(x, y) -} - -#[cube] -pub fn two_strategy_traits, S2: Strategy, F: Float>(x: F, y: F) -> F { - let z = S1::operation(x, y); - S2::operation(z, y) -} - -pub trait MethodTypedStrategy { - fn operation(input_1: T, input_2: T) -> T; - fn __expand_operation( - _context: &mut CubeContext, - input_1: ::ExpandType, - input_2: ::ExpandType, - ) -> ::ExpandType; -} - -impl MethodTypedStrategy for AddStrategy { - fn operation(input_1: T, input_2: T) -> T { - add_strategy_operation(input_1, input_2) - } - - fn __expand_operation( - context: &mut CubeContext, - input_1: ::ExpandType, - input_2: ::ExpandType, - ) -> ::ExpandType { - add_strategy_operation::expand::(context, input_1, input_2) - } -} - -#[cube] -pub fn with_trait_generic_method(x: T, y: T) -> T { - S::operation::(x, y) -} - -mod tests { - use super::*; - use cubecl_core::{ - cpa, - ir::{Item, Variable}, - }; - use pretty_assertions::assert_eq; - - type ElemType = f32; - #[test] - fn cube_strategy_trait_add_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - let y = context.create_local_binding(Item::new(ElemType::as_elem())); - - with_strategy_trait::expand::(&mut context, x.into(), y.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:#?}", scope.operations), - inline_macro_ref_one(true) - ); - } - - #[test] - fn cube_strategy_trait_sub_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - let y = context.create_local_binding(Item::new(ElemType::as_elem())); - - with_strategy_trait::expand::(&mut context, x.into(), y.into()); - let scope = context.into_scope(); - - assert_eq!( - format!("{:#?}", scope.operations), - inline_macro_ref_one(false) - ); - } - - #[test] - fn cube_two_strategy_traits_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - let y = context.create_local_binding(Item::new(ElemType::as_elem())); - - two_strategy_traits::expand::( - &mut context, - x.into(), - y.into(), - ); - let scope = context.into_scope(); - - assert_eq!(format!("{:#?}", scope.operations), inline_macro_ref_two()); - } - - #[test] - fn cube_trait_generic_method_test() { - let mut context = CubeContext::default(); - - let x = context.create_local_binding(Item::new(ElemType::as_elem())); - let y = context.create_local_binding(Item::new(ElemType::as_elem())); - - with_trait_generic_method::expand::( - &mut context, - x.into(), - y.into(), - ); - let scope = context.into_scope(); - - assert_eq!( - format!("{:#?}", scope.operations), - inline_macro_ref_one(true) - ); - } - - fn inline_macro_ref_one(is_add_strategy: bool) -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let x = 
context.create_local_binding(item); - let y = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y: Variable = y.into(); - - match is_add_strategy { - true => cpa!(scope, y = x + y), - false => cpa!(scope, y = x - y), - } - - format!("{:#?}", scope.operations) - } - - fn inline_macro_ref_two() -> String { - let mut context = CubeContext::default(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local_binding(item); - let y = context.create_local_binding(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y: Variable = y.into(); - - cpa!(scope, x = x - y); - cpa!(scope, y = x + y); - - format!("{:#?}", scope.operations) - } -} diff --git a/crates/cubecl-core/tests/frontend/tuple.rs b/crates/cubecl-core/tests/frontend/tuple.rs deleted file mode 100644 index 94f3af057..000000000 --- a/crates/cubecl-core/tests/frontend/tuple.rs +++ /dev/null @@ -1,76 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -pub fn tuple_const() -> (u32, u32) { - let x = 0u32; - let y = 1u32; - (x, y) -} - -#[cube] -pub fn tuple_destructuring() -> (u32, u32) { - let x = (0u32, 1u32); - let (a, b) = x; - (a + 1, b) -} - -mod tests { - use super::*; - use cubecl_core::{ - cpa, - ir::{Instruction, Item, Variable}, - }; - use pretty_assertions::assert_eq; - - #[test] - fn cube_tuple_const_test() { - let mut context = CubeContext::default(); - - tuple_const::expand(&mut context); - let scope = context.into_scope(); - - assert_eq!(scope.operations, inline_macro_ref_tuple_const()); - } - - fn inline_macro_ref_tuple_const() -> Vec { - let context = CubeContext::default(); - - let mut scope = context.into_scope(); - let x = scope.create_local(Item::new(u32::as_elem())); - let y = scope.create_local(Item::new(u32::as_elem())); - - let zero: Variable = 0u32.into(); - let one: Variable = 1u32.into(); - - cpa!(scope, x = zero); - cpa!(scope, y = one); - - scope.operations - } - - #[test] - fn cube_tuple_destructuring() { - let mut context = CubeContext::default(); - - tuple_destructuring::expand(&mut context); - let scope = context.into_scope(); - - assert_eq!(scope.operations, inline_macro_ref_tuple_destructuring()); - } - - fn inline_macro_ref_tuple_destructuring() -> Vec { - let context = CubeContext::default(); - - let mut scope = context.into_scope(); - let a = scope.create_local(Item::new(u32::as_elem())); - let b = scope.create_local(Item::new(u32::as_elem())); - - let one: Variable = 1u32.into(); - - cpa!(scope, a = one); - cpa!(scope, b = one); - - scope.operations - } -} diff --git a/crates/cubecl-core/tests/frontend/vectorization.rs b/crates/cubecl-core/tests/frontend/vectorization.rs deleted file mode 100644 index 5e695522d..000000000 --- a/crates/cubecl-core/tests/frontend/vectorization.rs +++ /dev/null @@ -1,72 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[cube] -pub fn vectorization_binary(lhs: T) { - let _ = lhs + T::from_vec([4, 5]); -} - -#[cube] -pub fn vectorization_cmp(rhs: T) { - let _ = T::from_vec([4, 5]) > rhs; -} - -mod tests { - use std::num::NonZero; - - use super::*; - use cubecl_core::ir::Item; - - type ElemType = f32; - - #[test] - fn cube_vectorization_binary_op_with_same_scheme_does_not_fail() { - let mut context = CubeContext::default(); - - let lhs = - context.create_local_binding(Item::vectorized(ElemType::as_elem(), NonZero::new(2))); - - vectorization_binary::expand::(&mut context, lhs.into()); - } - - #[test] - 
#[should_panic] - fn cube_vectorization_binary_op_with_different_scheme_fails() { - let mut context = CubeContext::default(); - - let lhs = - context.create_local_binding(Item::vectorized(ElemType::as_elem(), NonZero::new(4))); - - vectorization_binary::expand::(&mut context, lhs.into()); - } - - #[test] - fn cube_vectorization_cmp_op_with_same_scheme_does_not_fail() { - let mut context = CubeContext::default(); - - let lhs = - context.create_local_binding(Item::vectorized(ElemType::as_elem(), NonZero::new(2))); - - vectorization_cmp::expand::(&mut context, lhs.into()); - } - - #[test] - #[should_panic] - fn cube_vectorization_cmp_op_with_different_scheme_fails() { - let mut context = CubeContext::default(); - - let lhs = - context.create_local_binding(Item::vectorized(ElemType::as_elem(), NonZero::new(4))); - - vectorization_cmp::expand::(&mut context, lhs.into()); - } - - #[test] - fn cube_vectorization_can_be_broadcasted() { - let mut context = CubeContext::default(); - - let lhs = context.create_local_binding(Item::vectorized(ElemType::as_elem(), None)); - - vectorization_cmp::expand::(&mut context, lhs.into()); - } -} diff --git a/crates/cubecl-core/tests/mod.rs b/crates/cubecl-core/tests/mod.rs index 40398e64c..d73fdc3a8 100644 --- a/crates/cubecl-core/tests/mod.rs +++ b/crates/cubecl-core/tests/mod.rs @@ -1,5 +1,3 @@ -mod frontend; - #[test] fn compile_fail_tests() { let t = trybuild::TestCases::new(); diff --git a/crates/cubecl-cpp/src/cuda/dialect.rs b/crates/cubecl-cpp/src/cuda/dialect.rs index 4da8c82bc..136243b73 100644 --- a/crates/cubecl-cpp/src/cuda/dialect.rs +++ b/crates/cubecl-cpp/src/cuda/dialect.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use crate::shared::{Dialect, IndexedVariable, Variable, WmmaCompiler}; +use crate::shared::{Dialect, WmmaCompiler}; #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] pub struct CudaDialect { @@ -74,20 +74,19 @@ impl> Dialect for CudaDialect { fn bfloat162_type_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str("__nv_bfloat162") } - - fn warp_shuffle(input: &IndexedVariable, id: &Variable) -> String { - format!("__shfl_sync(-1, {input}, {id})") + fn warp_shuffle(var: &str, source: &str) -> String { + format!("__shfl_sync(-1, {var}, {source})") } - fn warp_shuffle_xor(out: &IndexedVariable) -> String { - format!("__shfl_xor_sync(-1, {out}, offset)") + fn warp_shuffle_xor(var: &str, offset: &str) -> String { + format!("__shfl_xor_sync(-1, {var}, {offset})") } - fn warp_shuffle_down(out: &IndexedVariable) -> String { - format!("__shfl_down_sync(-1, {out}, offset)") + fn warp_shuffle_down(var: &str, offset: &str) -> String { + format!("__shfl_down_sync(-1, {var}, {offset})") } - fn warp_all(out: &IndexedVariable) -> String { - format!("__all_sync(-1, {out})") + fn warp_all(var: &str) -> String { + format!("__all_sync(-1, {var})") } - fn warp_any(out: &IndexedVariable) -> String { - format!("__any_sync(-1, {out})") + fn warp_any(var: &str) -> String { + format!("__any_sync(-1, {var})") } } diff --git a/crates/cubecl-cpp/src/hip/dialect.rs b/crates/cubecl-cpp/src/hip/dialect.rs index 38d0c4d5e..2f5ad02eb 100644 --- a/crates/cubecl-cpp/src/hip/dialect.rs +++ b/crates/cubecl-cpp/src/hip/dialect.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use crate::shared::{Dialect, IndexedVariable, Variable, WmmaCompiler}; +use crate::shared::{Dialect, WmmaCompiler}; #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] pub struct HipDialect { @@ -75,20 +75,19 @@ impl> Dialect for HipDialect { // 
"hip_bfloat16.h" has no "hip_bfloat162" type f.write_str("hip_bfloat16") } - - fn warp_shuffle(input: &IndexedVariable, id: &Variable) -> String { - format!("__shfl({input}, {id})") + fn warp_shuffle(var: &str, source: &str) -> String { + format!("__shfl({var}, {source})") } - fn warp_shuffle_xor(out: &IndexedVariable) -> String { - format!("__shfl_xor({out}, offset)") + fn warp_shuffle_xor(var: &str, offset: &str) -> String { + format!("__shfl_xor_sync({var}, {offset})") } - fn warp_shuffle_down(out: &IndexedVariable) -> String { - format!("__shfl_down({out}, offset)") + fn warp_shuffle_down(var: &str, offset: &str) -> String { + format!("__shfl_down_sync({var}, {offset})") } - fn warp_all(out: &IndexedVariable) -> String { - format!("__all({out})") + fn warp_all(var: &str) -> String { + format!("__all({var})") } - fn warp_any(out: &IndexedVariable) -> String { + fn warp_any(out: &str) -> String { format!("__any({out})") } } diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index eae8fac37..2be8a197f 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -1,18 +1,17 @@ use std::hash::Hash; use std::{collections::HashSet, fmt::Debug, num::NonZero}; -use cubecl_core::ir::{expand_checked_index, expand_checked_index_assign}; +use cubecl_core::ir::{expand_checked_index_assign, Allocator}; use cubecl_core::{ ir::{self as gpu}, - prelude::CubePrimitive, Compiler, Feature, }; use cubecl_runtime::{DeviceProperties, ExecutionMode}; use super::{ AtomicKind, BinaryInstruction, Binding, Body, ComputeKernel, ConstArray, Elem, Fragment, - FragmentIdent, FragmentLayout, IndexedVariable, Instruction, Item, LocalArray, SharedMemory, - UnaryInstruction, Variable, VariableSettings, WarpInstruction, WmmaCompiler, WmmaInstruction, + FragmentIdent, FragmentLayout, Instruction, Item, LocalArray, SharedMemory, UnaryInstruction, + Variable, VariableSettings, WarpInstruction, WmmaCompiler, WmmaInstruction, }; pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 = @@ -29,11 +28,11 @@ pub trait Dialect: fn bfloat16_type_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result; fn bfloat162_type_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result; // warp instructions (all threads participating) - fn warp_shuffle(input: &IndexedVariable, id: &Variable) -> String; - fn warp_shuffle_xor(out: &IndexedVariable) -> String; - fn warp_shuffle_down(out: &IndexedVariable) -> String; - fn warp_all(out: &IndexedVariable) -> String; - fn warp_any(out: &IndexedVariable) -> String; + fn warp_shuffle(var: &str, source: &str) -> String; + fn warp_shuffle_xor(var: &str, offset: &str) -> String; + fn warp_shuffle_down(var: &str, offset: &str) -> String; + fn warp_all(var: &str) -> String; + fn warp_any(var: &str) -> String; } #[derive(Clone, Debug)] @@ -95,8 +94,8 @@ impl Compiler for CppCompiler { 49152 } - fn local_allocator() -> impl gpu::LocalAllocator { - gpu::ReusingAllocator::default() + fn local_allocator() -> Allocator { + Allocator::new() } } @@ -546,7 +545,8 @@ impl CppCompiler { gpu::Operator::Slice(op) => { if matches!(self.strategy, ExecutionMode::Checked) && op.input.has_length() { let input = op.input; - let input_len = scope.create_local(gpu::Item::new(u32::as_elem())); + let input_len = + scope.create_local_mut(gpu::Item::new(gpu::Elem::UInt(gpu::UIntKind::U32))); instructions.extend(self.compile_scope(scope)); let length = match input.has_buffer_length() { @@ -572,14 +572,29 @@ impl CppCompiler { } } 
             gpu::Operator::Index(op) => {
-                if let ExecutionMode::Checked = self.strategy {
-                    if op.lhs.has_length() {
-                        expand_checked_index(scope, op.lhs, op.rhs, out);
-                        instructions.extend(self.compile_scope(scope));
-                        return;
-                    }
-                };
-                instructions.push(Instruction::Index(self.compile_binary(op, out)));
+                if matches!(self.strategy, ExecutionMode::Checked) && op.lhs.has_length() {
+                    let lhs = op.lhs;
+                    let rhs = op.rhs;
+
+                    let array_len =
+                        scope.create_local(gpu::Item::new(gpu::Elem::UInt(gpu::UIntKind::U32)));
+
+                    instructions.extend(self.compile_scope(scope));
+
+                    let length = match lhs.has_buffer_length() {
+                        true => gpu::Metadata::BufferLength { var: lhs },
+                        false => gpu::Metadata::Length { var: lhs },
+                    };
+                    instructions.push(self.compile_metadata(length, Some(array_len)));
+                    instructions.push(Instruction::CheckedIndex {
+                        len: self.compile_variable(array_len),
+                        lhs: self.compile_variable(lhs),
+                        rhs: self.compile_variable(rhs),
+                        out: self.compile_variable(out),
+                    });
+                } else {
+                    instructions.push(Instruction::Index(self.compile_binary(op, out)));
+                }
             }
             gpu::Operator::UncheckedIndex(op) => {
                 instructions.push(Instruction::Index(self.compile_binary(op, out)))
@@ -672,6 +687,12 @@ impl CppCompiler {
             gpu::Operator::BitwiseXor(op) => {
                 instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
             }
+            gpu::Operator::CountOnes(op) => {
+                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
+            }
+            gpu::Operator::ReverseBits(op) => {
+                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
+            }
             gpu::Operator::ShiftLeft(op) => {
                 instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
             }
@@ -800,17 +821,17 @@ impl CppCompiler {
             gpu::VariableKind::GlobalScalar(id) => {
                 Variable::GlobalScalar(id, self.compile_item(item).elem, item.elem)
             }
-            gpu::VariableKind::Local { id, depth } => Variable::Local {
+            gpu::VariableKind::LocalMut { id, depth } => Variable::LocalMut {
                 id,
                 item: self.compile_item(item),
                 depth,
             },
-            gpu::VariableKind::Versioned { id, depth, .. } => Variable::Local {
+            gpu::VariableKind::Versioned { id, depth, .. } => Variable::LocalMut {
                 id,
                 item: self.compile_item(item),
                 depth,
             },
-            gpu::VariableKind::LocalBinding { id, depth } => Variable::ConstLocal {
+            gpu::VariableKind::LocalConst { id, depth } => Variable::LocalConst {
                 id,
                 item: self.compile_item(item),
                 depth,
diff --git a/crates/cubecl-cpp/src/shared/binary.rs b/crates/cubecl-cpp/src/shared/binary.rs
index 29d2f2349..f8e8d8512 100644
--- a/crates/cubecl-cpp/src/shared/binary.rs
+++ b/crates/cubecl-cpp/src/shared/binary.rs
@@ -276,7 +276,7 @@ impl Binary for IndexAssign {
         rhs: &Variable,
         out: &Variable,
     ) -> std::fmt::Result {
-        if matches!(out, Variable::Local { .. } | Variable::ConstLocal { .. }) {
+        if matches!(out, Variable::LocalMut { .. } | Variable::LocalConst { .. }) {
             return IndexAssignVector::format(f, lhs, rhs, out);
         };

@@ -299,7 +299,7 @@ impl Binary for Index {
         rhs: &Variable,
         out: &Variable,
     ) -> std::fmt::Result {
-        if matches!(lhs, Variable::Local { .. } | Variable::ConstLocal { ..
}) { return IndexVector::format(f, lhs, rhs, out); } @@ -386,8 +386,12 @@ impl IndexVector { Variable::ConstantScalar(value, _elem) => value.as_usize(), _ => { let elem = out.elem(); + let qualifier = out.const_qualifier(); let out = out.fmt_left(); - return writeln!(f, "{out} = reinterpret_cast<{elem}*>(&{lhs})[{rhs}];"); + return writeln!( + f, + "{out} = reinterpret_cast<{elem}{qualifier}*>(&{lhs})[{rhs}];" + ); } }; diff --git a/crates/cubecl-cpp/src/shared/element.rs b/crates/cubecl-cpp/src/shared/element.rs index cd80b2b70..d6a28d0f4 100644 --- a/crates/cubecl-cpp/src/shared/element.rs +++ b/crates/cubecl-cpp/src/shared/element.rs @@ -104,7 +104,7 @@ impl Component for IndexedVariable { } fn is_const(&self) -> bool { - matches!(self.var, Variable::ConstLocal { .. }) + matches!(self.var, Variable::LocalConst { .. }) } } @@ -119,8 +119,9 @@ impl Component for Variable { Variable::GlobalOutputArray(_, e) => *e, Variable::SharedMemory(_, e, _) => *e, Variable::ConstantArray(_, e, _) => *e, - Variable::Local { item, .. } => *item, - Variable::ConstLocal { item, .. } => *item, + Variable::LocalMut { item, .. } => *item, + Variable::LocalConst { item, .. } => *item, + Variable::Named { item, .. } => *item, Variable::Slice { item, .. } => *item, Variable::ConstantScalar(_, e) => Item::scalar(*e), Variable::GlobalScalar(_, e, _) => Item::scalar(*e), @@ -157,7 +158,7 @@ impl Component for Variable { } fn is_const(&self) -> bool { - matches!(self, Variable::ConstLocal { .. }) + matches!(self, Variable::LocalConst { .. }) } } @@ -170,16 +171,20 @@ pub enum Variable { GlobalScalar(u16, Elem, gpu::Elem), ConstantArray(u16, Item, u32), ConstantScalar(ConstantScalarValue, Elem), - Local { + LocalMut { id: u16, item: Item, depth: u8, }, - ConstLocal { + LocalConst { id: u16, item: Item, depth: u8, }, + Named { + name: &'static str, + item: Item, + }, Slice { id: u16, item: Item, @@ -222,8 +227,9 @@ impl Display for Variable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")), - Variable::Local { id, depth, .. } => f.write_fmt(format_args!("l_{depth}_{id}")), - Variable::ConstLocal { id, depth, .. } => f.write_fmt(format_args!("ssa_{depth}_{id}")), + Variable::LocalMut { id, depth, .. } => f.write_fmt(format_args!("l_mut_{depth}_{id}")), + Variable::LocalConst { id, depth, .. } => f.write_fmt(format_args!("l_{depth}_{id}")), + Variable::Named { name, .. } => f.write_fmt(format_args!("{name}")), Variable::Slice { id, item: _, depth } => { write!(f, "slice_{depth}_{id}") } @@ -362,12 +368,12 @@ impl Variable { Variable::GlobalOutputArray(id, item) => { Variable::GlobalOutputArray(*id, item.optimized()) } - Variable::Local { id, item, depth } => Variable::Local { + Variable::LocalMut { id, item, depth } => Variable::LocalMut { id: *id, item: item.optimized(), depth: *depth, }, - Variable::ConstLocal { id, item, depth } => Variable::ConstLocal { + Variable::LocalConst { id, item, depth } => Variable::LocalConst { id: *id, item: item.optimized(), depth: *depth, @@ -410,8 +416,9 @@ impl Variable { Variable::GlobalOutputArray(_, _) => false, Variable::SharedMemory(_, _, _) => false, Variable::ConstantArray(_, _, _) => false, - Variable::Local { .. } => false, - Variable::ConstLocal { .. } => false, + Variable::LocalMut { .. } => false, + Variable::LocalConst { .. } => false, + Variable::Named { .. } => false, Variable::Slice { .. 
} => false, Variable::BlockIdxX => true, Variable::BlockIdxY => true, @@ -443,6 +450,14 @@ impl Variable { optimized: self.is_optimized(), } } + + pub fn const_qualifier(&self) -> &str { + if self.is_const() { + " const" + } else { + "" + } + } } pub trait FmtLeft: Display { @@ -452,7 +467,7 @@ pub trait FmtLeft: Display { impl FmtLeft for Variable { fn fmt_left(&self) -> String { match self { - Self::ConstLocal { item, .. } => format!("const {item} {self}"), + Self::LocalConst { item, .. } => format!("const {item} {self}"), Variable::Tmp { item, .. } => format!("{item} {self}"), var => format!("{var}"), } @@ -462,7 +477,7 @@ impl FmtLeft for Variable { impl FmtLeft for IndexedVariable { fn fmt_left(&self) -> String { match self.var { - Variable::ConstLocal { item, .. } => format!("const {item} {self}"), + Variable::LocalConst { item, .. } => format!("const {item} {self}"), Variable::Tmp { item, .. } => format!("{item} {self}"), _ => format!("{self}"), } @@ -485,7 +500,7 @@ pub struct IndexedVariable { impl Display for IndexedVariable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let var = &self.var; - let ref_ = matches!(var, Variable::ConstLocal { .. }) + let ref_ = matches!(var, Variable::LocalConst { .. }) .then_some("const&") .unwrap_or("&"); diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index 66fd9f09f..2c58b2b15 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -50,6 +50,12 @@ pub enum Instruction { Sub(BinaryInstruction), Index(BinaryInstruction), IndexAssign(BinaryInstruction), + CheckedIndex { + len: Variable, + lhs: Variable, + rhs: Variable, + out: Variable, + }, Assign(UnaryInstruction), RangeLoop { i: Variable, @@ -111,6 +117,8 @@ pub enum Instruction { BitwiseOr(BinaryInstruction), BitwiseAnd(BinaryInstruction), BitwiseXor(BinaryInstruction), + CountBits(UnaryInstruction), + ReverseBits(UnaryInstruction), ShiftLeft(BinaryInstruction), ShiftRight(BinaryInstruction), Abs(UnaryInstruction), @@ -224,10 +232,27 @@ impl Display for Instruction { Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out), Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out), Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out), + Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out), Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out), Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Index(it) => Index::format(f, &it.lhs, &it.rhs, &it.out), Instruction::IndexAssign(it) => IndexAssign::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::CheckedIndex { len, lhs, rhs, out } => { + let item_out = out.item(); + if let Elem::Atomic(inner) = item_out.elem { + write!(f, "{inner}* {out} = &{lhs}[{rhs}];") + } else { + let out = out.fmt_left(); + write!(f, "{out} = ({rhs} < {len}) ? 
")?; + Index::format_scalar(f, *lhs, *rhs, item_out)?; + if item_out.vectorization == 1 { + writeln!(f, " : {item_out}(0);") + } else { + writeln!(f, " : {item_out}{{}};") + } + } + } Instruction::Copy { input, in_index, @@ -311,20 +336,27 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ or_else, out, } => { - let vf_then = then.item().vectorization; - let vf_or_else = or_else.item().vectorization; - let vf_out = out.item().vectorization; - let vf_cond = cond.item().vectorization; + let item_or_else = or_else.item(); + let item_then = then.item(); + let item_out = out.item(); - let vf = usize::max(vf_cond, vf_out); - let vf = usize::max(vf, vf_then); - let vf = usize::max(vf, vf_or_else); + let vf_then = item_then.vectorization; + let vf_or_else = item_or_else.vectorization; + let vf_out = item_out.vectorization; + let vf_cond = cond.item().vectorization; let item_out = out.item(); let cond_elem = cond.item().elem; let out = out.fmt_left(); - if vf > 1 { + let should_broadcast = + vf_cond > 1 || item_out != item_or_else || item_out != item_then; + + if should_broadcast { + let vf = usize::max(vf_cond, vf_out); + let vf = usize::max(vf, vf_then); + let vf = usize::max(vf, vf_or_else); + writeln!(f, "{out} = {item_out} {{")?; for i in 0..vf { let theni = then.index(i); @@ -457,7 +489,10 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ (Elem::U32, Elem::BF16) => { writeln!(f, "{out} = __ushort_as_bfloat16({input});") } - _ => panic!("Unsupported type for bitcasting"), + (Elem::I32, Elem::U32) => { + writeln!(f, "{out} = reinterpret_cast({input});") + } + elem => panic!("Unsupported type for bitcasting {elem:?}"), } } Instruction::AtomicCAS { @@ -694,11 +729,12 @@ impl Remainder { write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?; + let qualifier = out.const_qualifier(); let out = out.fmt_left(); writeln!( f, - "{out} = reinterpret_cast<{item_out_original}&>({out_tmp});\n" + "{out} = reinterpret_cast<{item_out_original}{qualifier}&>({out_tmp});\n" )?; Ok(()) diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index eca29baf3..3849a9388 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -62,10 +62,11 @@ pub trait Unary { let out_tmp = Variable::tmp(item_out_optimized); write_op(index, elem, &input, &out_tmp)?; - + let qualifier = out.const_qualifier(); + let out_fmt = out.fmt_left(); writeln!( f, - "{out} = reinterpret_cast<{item_out_original}&>({out_tmp});\n" + "{out_fmt} = reinterpret_cast<{item_out_original}{qualifier}&>({out_tmp});\n" ) } else { write_op(index, elem, &input, &out_optimized) @@ -159,6 +160,53 @@ function!(Tanh, "tanh", false); function!(Erf, "erf", false); function!(Abs, "abs", false); +fn zero_extend(input: impl Component) -> String { + match input.elem() { + Elem::I8 => format!("{}({}({input}))", Elem::::U32, Elem::::U8), + Elem::I16 => format!("{}({}({input}))", Elem::::U32, Elem::::U16), + Elem::U8 => format!("{}({input})", Elem::::U32), + Elem::U16 => format!("{}({input})", Elem::::U32), + _ => unreachable!("zero extend only supports integer < 32 bits"), + } +} + +pub struct CountBits; + +impl Unary for CountBits { + fn format_scalar>( + f: &mut std::fmt::Formatter<'_>, + input: Input, + _elem: Elem, + ) -> std::fmt::Result { + match input.elem() { + Elem::I32 | Elem::U32 => write!(f, "__popc({input})"), + Elem::I64 | Elem::U64 => write!(f, "__popcll({input})"), + _ => write!(f, "__popc({})", zero_extend(input)), + } + } +} + +pub struct ReverseBits; + 
+impl Unary for ReverseBits { + fn format_scalar>( + f: &mut std::fmt::Formatter<'_>, + input: Input, + elem: Elem, + ) -> std::fmt::Result { + match elem { + Elem::I32 | Elem::U32 => write!(f, "__brev({input})"), + Elem::I64 | Elem::U64 => write!(f, "__brevll({input})"), + _ => write!( + f, + "{elem}(__brev({}) >> {})", + zero_extend(input), + (size_of::() - elem.size()) * 8 + ), + } + } +} + pub struct Not; impl Unary for Not { diff --git a/crates/cubecl-cpp/src/shared/warp.rs b/crates/cubecl-cpp/src/shared/warp.rs index 4023d3e5e..3043f5e1e 100644 --- a/crates/cubecl-cpp/src/shared/warp.rs +++ b/crates/cubecl-cpp/src/shared/warp.rs @@ -1,8 +1,8 @@ use std::fmt::Display; -use crate::shared::{Component, Elem}; +use crate::shared::{Component, Elem, FmtLeft}; -use super::{Dialect, IndexedVariable, Variable}; +use super::{Dialect, Item, Variable}; #[derive(Clone, Debug)] pub enum WarpInstruction { @@ -47,6 +47,9 @@ impl Display for WarpInstruction { WarpInstruction::ReduceProd { input, out } => reduce_operator(f, input, out, "*="), WarpInstruction::ReduceMax { input, out } => reduce_comparison(f, input, out, "max"), WarpInstruction::ReduceMin { input, out } => reduce_comparison(f, input, out, "min"), + WarpInstruction::All { input, out } => reduce_quantifier(f, input, out, D::warp_all), + WarpInstruction::Any { input, out } => reduce_quantifier(f, input, out, D::warp_any), + WarpInstruction::Broadcast { input, id, out } => reduce_broadcast(f, input, out, id), WarpInstruction::Elect { out } => write!( f, " @@ -55,27 +58,6 @@ unsigned int leader = __ffs(mask) - 1; {out} = threadIdx.x % warpSize == leader; " ), - WarpInstruction::All { input, out } => reduce_quantifier(f, input, out, D::warp_all), - WarpInstruction::Any { input, out } => reduce_quantifier(f, input, out, D::warp_any), - WarpInstruction::Broadcast { input, id, out } => { - let input_optimized = input.optimized(); - let out_optimized = out.optimized(); - for k in 0..out_optimized.item().vectorization { - let __shfl = D::warp_shuffle(&input_optimized.index(k), id); - let indexed = out_optimized.index(k); - write!( - f, - " - {{ - for (int offset = 1; offset < warpSizeChecked; offset *=2 ) {{ - {indexed} = {__shfl}; - }} - }} - " - )?; - } - Ok(()) - } } } } @@ -86,30 +68,14 @@ fn reduce_operator( out: &Variable, op: &str, ) -> core::fmt::Result { - write!( - f, - " - {out} = {input}; - " - )?; + let in_optimized = input.optimized(); + let acc_item = in_optimized.item(); - let optimized = out.optimized(); - - for k in 0..optimized.item().vectorization { - let indexed = optimized.index(k); - let __shfl_xor = D::warp_shuffle_xor(&indexed); - write!( - f, - " - {{ - for (int offset = 1; offset < warpSizeChecked; offset *=2 ) {{ - {indexed} {op} {__shfl_xor}; - }} - }} - " - )?; - } - Ok(()) + reduce_with_loop(f, input, out, acc_item, |acc, index| { + let acc_indexed = maybe_index(acc, index); + let shfl_xor = D::warp_shuffle_xor(&acc_indexed, "offset"); + format!("{acc_indexed} {op} {shfl_xor};") + }) } fn reduce_comparison( @@ -118,64 +84,89 @@ fn reduce_comparison( out: &Variable, cmp: &str, ) -> core::fmt::Result { - write!( - f, - " - {out} = {input}; - " - )?; - - let optimized = out.optimized(); - - let instruction = match optimized.elem() { + let in_optimized = input.optimized(); + let acc_item = in_optimized.item(); + let instruction = match in_optimized.elem() { Elem::F16 | Elem::BF16 => format!("__h{cmp}"), Elem::F162 | Elem::BF162 => format!("__h{cmp}2"), _ => cmp.to_string(), }; - for k in 0..optimized.item().vectorization 
{ - let indexed = optimized.index(k); - let __shfl_xor = D::warp_shuffle_xor(&indexed); - write!( - f, - " - {{ - for (int offset = warpSizeChecked / 2; offset > 0; offset /= 2) {{ - {indexed} = {instruction}({indexed}, {__shfl_xor}); - }} - }} - " - )?; - } - Ok(()) + reduce_with_loop(f, input, out, acc_item, |acc, index| { + let acc_indexed = maybe_index(acc, index); + let shfl_xor = D::warp_shuffle_xor(&acc_indexed, "offset"); + format!("{acc_indexed} = {instruction}({acc_indexed}, {shfl_xor});") + }) } -fn reduce_quantifier) -> String>( +fn reduce_broadcast( f: &mut core::fmt::Formatter<'_>, input: &Variable, out: &Variable, - quantifier: Q, + id: &Variable, ) -> core::fmt::Result { - write!( + let rhs = (0..input.item().vectorization) + .map(|k| D::warp_shuffle(&format!("{}", input.index(k)), &format!("{id}"))) + .collect::>() + .join(","); + let out_fmt = out.fmt_left(); + writeln!(f, "{out_fmt} = {{ {rhs} }};") +} + +fn reduce_with_loop, usize) -> String>( + f: &mut core::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + acc_item: Item, + instruction: I, +) -> core::fmt::Result { + let acc = Variable::Named { + name: "acc", + item: acc_item, + }; + + writeln!(f, "auto plane_{out} = [&]() -> {} {{", out.item())?; + writeln!(f, " {} {} = {};", acc_item, acc, cast(input, acc_item))?; + writeln!( f, - " - {out} = {input}; - " + " for (int offset = 1; offset < warpSizeChecked; offset *=2 ) {{" )?; - let optimized = out.optimized(); - for k in 0..optimized.item().vectorization { - let indexed = optimized.index(k); - let __all = quantifier(&indexed); - write!( - f, - " - {{ - for (int offset = 1; offset < warpSizeChecked; offset *=2 ) {{ - {indexed} = {__all}; - }} - }} - " - )?; + for k in 0..acc_item.vectorization { + writeln!(f, " {}", instruction(&acc, k))?; + } + writeln!(f, " }};")?; + writeln!(f, " return {};", cast(&acc, out.item()))?; + writeln!(f, "}};")?; + writeln!(f, "{} = plane_{}();", out.fmt_left(), out) +} + +fn reduce_quantifier String>( + f: &mut core::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + quantifier: Q, +) -> core::fmt::Result { + let rhs = (0..input.item().vectorization) + .map(|k| quantifier(&format!("{}", input.index(k)))) + .collect::>() + .join(","); + let out_fmt = out.fmt_left(); + writeln!(f, "{out_fmt} = {{ {rhs} }};") +} + +fn cast(input: &Variable, target: Item) -> String { + if target != input.item() { + let qualifier = input.const_qualifier(); + format!("reinterpret_cast<{}{}&>({})", target, qualifier, input) + } else { + format!("{}", input) + } +} + +fn maybe_index(var: &Variable, k: usize) -> String { + if var.item().vectorization > 1 { + format!("{var}.i_{k}") + } else { + format!("{var}") } - Ok(()) } diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs index c5e5a02bc..06f2a8162 100644 --- a/crates/cubecl-cuda/src/compute/server.rs +++ b/crates/cubecl-cuda/src/compute/server.rs @@ -469,6 +469,10 @@ fn cuda_path() -> Option { #[cfg(target_os = "linux")] { + // If it is installed as part of the distribution + if std::fs::metadata("/usr/bin/nvcc").is_ok() { + return Some(PathBuf::from("/usr")); + } return Some(PathBuf::from("/usr/local/cuda")); } diff --git a/crates/cubecl-cuda/src/lib.rs b/crates/cubecl-cuda/src/lib.rs index fd19e100d..8b1e36e0b 100644 --- a/crates/cubecl-cuda/src/lib.rs +++ b/crates/cubecl-cuda/src/lib.rs @@ -15,9 +15,7 @@ mod tests { pub use half::{bf16, f16}; cubecl_core::testgen_all!(f32: [f16, bf16, f32, f64], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, 
u64]); - cubecl_linalg::testgen_matmul_cmma!([f16]); - cubecl_linalg::testgen_matmul_plane_mma!([f16], f16); - cubecl_linalg::testgen_matmul_plane_mma!([f16], f32); + cubecl_linalg::testgen_matmul_accelerated!([f16]); cubecl_linalg::testgen_matmul_simple!([f16, bf16, f32]); cubecl_linalg::testgen_matmul_tiling2d!([f16, bf16, f32]); cubecl_reduce::testgen_reduce!([f16, bf16, f32, f64]); diff --git a/crates/cubecl-cuda/tests/common.rs b/crates/cubecl-cuda/tests/common.rs deleted file mode 100644 index 537866bf5..000000000 --- a/crates/cubecl-cuda/tests/common.rs +++ /dev/null @@ -1,69 +0,0 @@ -use std::{io::Write, num::NonZero, process::Command}; - -use cubecl_core::{ - prelude::{ArrayCompilationArg, TensorCompilationArg}, - Compiler, CubeDim, ExecutionMode, Kernel, KernelSettings, Runtime, -}; -use cubecl_cuda::CudaRuntime; - -pub fn settings() -> KernelSettings { - KernelSettings::default().cube_dim(CubeDim::default()) -} - -#[allow(unused)] -pub fn tensor() -> TensorCompilationArg { - TensorCompilationArg { - inplace: None, - vectorisation: NonZero::new(1), - } -} - -#[allow(unused)] -pub fn tensor_vec(vec: u8) -> TensorCompilationArg { - TensorCompilationArg { - inplace: None, - vectorisation: NonZero::new(vec), - } -} - -#[allow(unused)] -pub fn array() -> ArrayCompilationArg { - ArrayCompilationArg { - inplace: None, - vectorisation: NonZero::new(1), - } -} - -pub fn compile(kernel: impl Kernel) -> String { - let kernel = <::Compiler as Compiler>::compile( - kernel.define(), - &Default::default(), - ExecutionMode::Checked, - ) - .to_string(); - format_cpp_code(&kernel).unwrap() -} - -/// Format C++ code, useful when debugging. -fn format_cpp_code(code: &str) -> Result { - let mut child = Command::new("clang-format") - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .spawn()?; - - { - let stdin = child.stdin.as_mut().expect("Failed to open stdin"); - stdin.write_all(code.as_bytes())?; - } - - let output = child.wait_with_output()?; - - if output.status.success() { - Ok(String::from_utf8_lossy(&output.stdout).into_owned()) - } else { - Err(std::io::Error::new( - std::io::ErrorKind::Other, - "clang-format failed", - )) - } -} diff --git a/crates/cubecl-cuda/tests/constant_array.cu b/crates/cubecl-cuda/tests/constant_array.cu deleted file mode 100644 index 5f40ca3a3..000000000 --- a/crates/cubecl-cuda/tests/constant_array.cu +++ /dev/null @@ -1,37 +0,0 @@ -typedef unsigned char uint8; -typedef unsigned short uint16; -typedef unsigned int uint; -typedef unsigned long long int uint64; -typedef long long int int64; - -extern "C" __global__ void constant_array_kernel_f32(float output_0[], - uint info[]) { - - int3 absoluteIdx = make_int3(blockIdx.x * blockDim.x + threadIdx.x, - blockIdx.y * blockDim.y + threadIdx.y, - blockIdx.z * blockDim.z + threadIdx.z); - - uint idxGlobal = - (absoluteIdx.z * gridDim.x * blockDim.x * gridDim.y * blockDim.y) + - (absoluteIdx.y * gridDim.x * blockDim.x) + absoluteIdx.x; - const float arrays_0[3] = { - float(3), - float(5), - float(1), - }; - uint l_0_0; - bool l_0_1; - float l_0_2; - l_0_0 = info[uint(1)]; - l_0_1 = idxGlobal < l_0_0; - if (l_0_1) { - l_0_2 = arrays_0[idxGlobal]; - uint l_1_0; - bool l_1_1; - l_1_0 = info[uint(0)]; - l_1_1 = idxGlobal < l_1_0; - if (l_1_1) { - output_0[idxGlobal] = l_0_2; - } - } -} diff --git a/crates/cubecl-cuda/tests/main.rs b/crates/cubecl-cuda/tests/main.rs deleted file mode 100644 index 404f4c2d2..000000000 --- a/crates/cubecl-cuda/tests/main.rs +++ /dev/null @@ -1,136 +0,0 @@ -use 
common::*; -use constant_array_kernel::ConstantArrayKernel; -use cubecl_core as cubecl; -use cubecl_core::prelude::*; -use cubecl_cuda::CudaRuntime; -use execute_unary_kernel::ExecuteUnaryKernel; -use half::bf16; -use kernel_sum::KernelSum; -use naming_kernel::NamingKernel; -use pretty_assertions::assert_eq; -use sequence_for_loop_kernel::SequenceForLoopKernel; -use slice_assign_kernel::SliceAssignKernel; - -mod common; - -#[cube(launch_unchecked, create_dummy_kernel)] -pub fn slice_assign_kernel(input: &Tensor, output: &mut Tensor) { - if UNIT_POS == 0 { - let mut slice_1 = output.slice_mut(2, 3); - slice_1[0] = input[0]; - } -} - -#[test] -pub fn slice_assign() { - let kernel = SliceAssignKernel::::new(settings(), tensor(), tensor()); - let expected = include_str!("slice_assign.cu").replace("\r\n", "\n"); - let expected = expected.trim(); - assert_eq!(compile(kernel), expected); -} - -#[cube(launch, create_dummy_kernel)] -pub fn kernel_sum(output: &mut Tensor) { - let val = output[UNIT_POS]; - let val2 = cubecl_core::prelude::plane_sum(val); - - if UNIT_POS == 0 { - output[0] = val2; - } -} - -#[test] -pub fn plane_sum() { - let kernel = KernelSum::::new(settings(), tensor()); - - let expected = include_str!("plane_sum.cu").replace("\r\n", "\n"); - let expected = expected.trim(); - assert_eq!(compile(kernel), expected); -} - -#[cube(launch, create_dummy_kernel)] -pub fn sequence_for_loop_kernel(output: &mut Array) { - if UNIT_POS != 0 { - return; - } - - let mut sequence = Sequence::::new(); - sequence.push(1.0); - sequence.push(4.0); - - for value in sequence { - output[0] += value; - } -} - -#[test] -pub fn sequence_for_loop() { - let kernel = SequenceForLoopKernel::::new(settings(), array()); - let expected = include_str!("sequence_for_loop.cu").replace("\r\n", "\n"); - let expected = expected.trim(); - assert_eq!(compile(kernel), expected); -} - -#[cube(launch, create_dummy_kernel)] -fn execute_unary_kernel(lhs: &Tensor, rhs: &Tensor, out: &mut Tensor) { - if ABSOLUTE_POS < out.len() { - for i in 0..256u32 { - if i % 2 == 0 { - out[ABSOLUTE_POS] -= F::cos(lhs[ABSOLUTE_POS] * rhs[ABSOLUTE_POS]); - } else { - out[ABSOLUTE_POS] += F::cos(lhs[ABSOLUTE_POS] * rhs[ABSOLUTE_POS]); - } - } - } -} - -#[test] -pub fn unary_bench() { - let kernel = ExecuteUnaryKernel::::new( - settings(), - tensor_vec(4), - tensor_vec(4), - tensor_vec(4), - ); - let expected = include_str!("unary_bench.cu").replace("\r\n", "\n"); - let expected = expected.trim(); - - assert_eq!(compile(kernel), expected); -} - -#[cube(launch, create_dummy_kernel)] -fn constant_array_kernel(out: &mut Tensor, #[comptime] data: Vec) { - let array = Array::::from_data(data); - - if ABSOLUTE_POS < out.len() { - out[ABSOLUTE_POS] = array[ABSOLUTE_POS]; - } -} - -#[test] -pub fn constant_array() { - let data: Vec = vec![3, 5, 1]; - - let kernel = ConstantArrayKernel::::new(settings(), tensor(), data); - let expected = include_str!("constant_array.cu").replace("\r\n", "\n"); - let expected = expected.trim(); - assert_eq!(compile(kernel), expected); -} - -// This kernel just exists to have a few generics in order to observe -// that the generics get propagated into the WGSL kernel name -#[allow(clippy::extra_unused_type_parameters)] -#[cube(launch, create_dummy_kernel)] -fn naming_kernel(out: &mut Array) { - if ABSOLUTE_POS < out.len() { - out[ABSOLUTE_POS] = F1::from_int(0); - } -} - -#[test] -pub fn naming() { - let kernel = NamingKernel::::new(settings(), array()); - let expected = include_str!("naming.cu").replace("\r\n", "\n"); - 
let expected = expected.trim(); - assert_eq!(compile(kernel), expected); -} diff --git a/crates/cubecl-cuda/tests/naming.cu b/crates/cubecl-cuda/tests/naming.cu deleted file mode 100644 index 65d63ea5b..000000000 --- a/crates/cubecl-cuda/tests/naming.cu +++ /dev/null @@ -1,30 +0,0 @@ -typedef unsigned char uint8; -typedef unsigned short uint16; -typedef unsigned int uint; -typedef unsigned long long int uint64; -typedef long long int int64; - -extern "C" __global__ void naming_kernel_f32_u8_bf16_i64(float output_0[], - uint info[]) { - - int3 absoluteIdx = make_int3(blockIdx.x * blockDim.x + threadIdx.x, - blockIdx.y * blockDim.y + threadIdx.y, - blockIdx.z * blockDim.z + threadIdx.z); - - uint idxGlobal = - (absoluteIdx.z * gridDim.x * blockDim.x * gridDim.y * blockDim.y) + - (absoluteIdx.y * gridDim.x * blockDim.x) + absoluteIdx.x; - uint l_0_0; - bool l_0_1; - l_0_0 = info[uint(1)]; - l_0_1 = idxGlobal < l_0_0; - if (l_0_1) { - uint l_1_0; - bool l_1_1; - l_1_0 = info[uint(0)]; - l_1_1 = idxGlobal < l_1_0; - if (l_1_1) { - output_0[idxGlobal] = float(0.0); - } - } -} diff --git a/crates/cubecl-cuda/tests/plane_sum.cu b/crates/cubecl-cuda/tests/plane_sum.cu deleted file mode 100644 index 604851453..000000000 --- a/crates/cubecl-cuda/tests/plane_sum.cu +++ /dev/null @@ -1,41 +0,0 @@ -typedef unsigned char uint8; -typedef unsigned short uint16; -typedef unsigned int uint; -typedef unsigned long long int uint64; -typedef long long int int64; - -extern "C" __global__ void kernel_sum(float output_0[], uint info[]) { - - int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x + - threadIdx.z * (blockDim.x * blockDim.y); - - int warpSizeChecked = min(warpSize, blockDim.x * blockDim.y * blockDim.z); - float l_0_0; - float l_0_1; - bool l_0_2; - uint l_0_3; - bool l_0_4; - float l_0_5; - l_0_3 = info[uint(0)]; - l_0_4 = threadIdxGlobal < l_0_3; - l_0_5 = output_0[threadIdxGlobal]; - l_0_0 = (l_0_4) ? l_0_5 : float(0.0); - - l_0_1 = l_0_0; - - { - for (int offset = 1; offset < warpSizeChecked; offset *= 2) { - l_0_1 += __shfl_xor_sync(-1, l_0_1, offset); - } - } - l_0_2 = threadIdxGlobal == uint(0); - if (l_0_2) { - uint l_1_0; - bool l_1_1; - l_1_0 = info[uint(0)]; - l_1_1 = uint(0) < l_1_0; - if (l_1_1) { - output_0[uint(0)] = l_0_1; - } - } -} diff --git a/crates/cubecl-cuda/tests/sequence_for_loop.cu b/crates/cubecl-cuda/tests/sequence_for_loop.cu deleted file mode 100644 index a54f52f5c..000000000 --- a/crates/cubecl-cuda/tests/sequence_for_loop.cu +++ /dev/null @@ -1,48 +0,0 @@ -typedef unsigned char uint8; -typedef unsigned short uint16; -typedef unsigned int uint; -typedef unsigned long long int uint64; -typedef long long int int64; - -extern "C" __global__ void sequence_for_loop_kernel(float output_0[], - uint info[]) { - - int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x + - threadIdx.z * (blockDim.x * blockDim.y); - bool l_0_0; - float l_0_1; - l_0_0 = threadIdxGlobal != uint(0); - if (l_0_0) { - return; - } - uint l_0_2; - bool l_0_3; - float l_0_4; - l_0_2 = info[uint(0)]; - l_0_3 = uint(0) < l_0_2; - l_0_4 = output_0[uint(0)]; - l_0_1 = (l_0_3) ? l_0_4 : float(0.0); - l_0_1 = l_0_1 + float(1.0); - uint l_0_5; - bool l_0_6; - l_0_5 = info[uint(0)]; - l_0_6 = uint(0) < l_0_5; - if (l_0_6) { - output_0[uint(0)] = l_0_1; - } - uint l_0_7; - bool l_0_8; - float l_0_9; - l_0_7 = info[uint(0)]; - l_0_8 = uint(0) < l_0_7; - l_0_9 = output_0[uint(0)]; - l_0_1 = (l_0_8) ? 
l_0_9 : float(0.0); - l_0_1 = l_0_1 + float(4.0); - uint l_0_10; - bool l_0_11; - l_0_10 = info[uint(0)]; - l_0_11 = uint(0) < l_0_10; - if (l_0_11) { - output_0[uint(0)] = l_0_1; - } -} diff --git a/crates/cubecl-cuda/tests/slice_assign.cu b/crates/cubecl-cuda/tests/slice_assign.cu deleted file mode 100644 index 2fbbc707f..000000000 --- a/crates/cubecl-cuda/tests/slice_assign.cu +++ /dev/null @@ -1,35 +0,0 @@ -typedef unsigned char uint8; -typedef unsigned short uint16; -typedef unsigned int uint; -typedef unsigned long long int uint64; -typedef long long int int64; - -extern "C" __global__ void slice_assign_kernel(float input_0[], - float output_0[], uint info[]) { - - int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x + - threadIdx.z * (blockDim.x * blockDim.y); - bool l_0_0; - float l_0_1; - l_0_0 = threadIdxGlobal == uint(0); - if (l_0_0) { - uint l_1_0; - l_1_0 = info[uint(1)]; - const uint slice_1_0_length = min(l_1_0, uint(3)) - uint(2); - float *slice_1_0 = output_0 + uint(2); - uint l_1_1; - bool l_1_2; - float l_1_3; - l_1_1 = info[uint(0)]; - l_1_2 = uint(0) < l_1_1; - l_1_3 = input_0[uint(0)]; - l_0_1 = (l_1_2) ? l_1_3 : float(0.0); - uint l_1_4; - bool l_1_5; - l_1_4 = slice_1_0_length; - l_1_5 = uint(0) < l_1_4; - if (l_1_5) { - slice_1_0[uint(0)] = l_0_1; - } - } -} diff --git a/crates/cubecl-cuda/tests/unary_bench.cu b/crates/cubecl-cuda/tests/unary_bench.cu deleted file mode 100644 index c2a0fde2f..000000000 --- a/crates/cubecl-cuda/tests/unary_bench.cu +++ /dev/null @@ -1,165 +0,0 @@ -typedef unsigned char uint8; -typedef unsigned short uint16; -typedef unsigned int uint; -typedef unsigned long long int uint64; -typedef long long int int64; - -struct __align__(16) float_4 { - float i_0; - float i_1; - float i_2; - float i_3; -}; - -extern "C" __global__ void execute_unary_kernel_f32(float_4 input_0[], - float_4 input_1[], - float_4 output_0[], - uint info[]) { - - int3 absoluteIdx = make_int3(blockIdx.x * blockDim.x + threadIdx.x, - blockIdx.y * blockDim.y + threadIdx.y, - blockIdx.z * blockDim.z + threadIdx.z); - - uint idxGlobal = - (absoluteIdx.z * gridDim.x * blockDim.x * gridDim.y * blockDim.y) + - (absoluteIdx.y * gridDim.x * blockDim.x) + absoluteIdx.x; - uint l_0_0; - bool l_0_1; - bool l_0_2; - float_4 l_0_3; - float_4 l_0_4; - l_0_0 = info[uint(5)]; - l_0_1 = idxGlobal < l_0_0; - if (l_0_1) { - - for (uint l_2_0 = uint(0); l_2_0 < uint(256); ++l_2_0) { - l_0_0 = l_2_0 % uint(2); - l_0_2 = l_0_0 == uint(0); - if (l_0_2) { - uint l_3_0; - bool l_3_1; - float_4 l_3_2; - l_3_0 = info[uint(0)]; - l_3_1 = idxGlobal < l_3_0; - l_3_2 = input_0[idxGlobal]; - l_0_3 = float_4{ - (l_3_1) ? l_3_2.i_0 : float(0.0), - (l_3_1) ? l_3_2.i_1 : float(0.0), - (l_3_1) ? l_3_2.i_2 : float(0.0), - (l_3_1) ? l_3_2.i_3 : float(0.0), - }; - uint l_3_3; - bool l_3_4; - float_4 l_3_5; - l_3_3 = info[uint(1)]; - l_3_4 = idxGlobal < l_3_3; - l_3_5 = input_1[idxGlobal]; - l_0_4 = float_4{ - (l_3_4) ? l_3_5.i_0 : float(0.0), - (l_3_4) ? l_3_5.i_1 : float(0.0), - (l_3_4) ? l_3_5.i_2 : float(0.0), - (l_3_4) ? l_3_5.i_3 : float(0.0), - }; - l_0_4 = float_4{ - l_0_3.i_0 * l_0_4.i_0, - l_0_3.i_1 * l_0_4.i_1, - l_0_3.i_2 * l_0_4.i_2, - l_0_3.i_3 * l_0_4.i_3, - }; - l_0_4 = float_4{ - cos(l_0_4.i_0), - cos(l_0_4.i_1), - cos(l_0_4.i_2), - cos(l_0_4.i_3), - }; - uint l_3_6; - bool l_3_7; - float_4 l_3_8; - l_3_6 = info[uint(2)]; - l_3_7 = idxGlobal < l_3_6; - l_3_8 = output_0[idxGlobal]; - l_0_3 = float_4{ - (l_3_7) ? l_3_8.i_0 : float(0.0), - (l_3_7) ? l_3_8.i_1 : float(0.0), - (l_3_7) ? 
l_3_8.i_2 : float(0.0), - (l_3_7) ? l_3_8.i_3 : float(0.0), - }; - l_0_3 = float_4{ - l_0_3.i_0 - l_0_4.i_0, - l_0_3.i_1 - l_0_4.i_1, - l_0_3.i_2 - l_0_4.i_2, - l_0_3.i_3 - l_0_4.i_3, - }; - uint l_3_9; - bool l_3_10; - l_3_9 = info[uint(2)]; - l_3_10 = idxGlobal < l_3_9; - if (l_3_10) { - output_0[idxGlobal] = l_0_3; - } - } else { - uint l_3_0; - bool l_3_1; - float_4 l_3_2; - l_3_0 = info[uint(0)]; - l_3_1 = idxGlobal < l_3_0; - l_3_2 = input_0[idxGlobal]; - l_0_4 = float_4{ - (l_3_1) ? l_3_2.i_0 : float(0.0), - (l_3_1) ? l_3_2.i_1 : float(0.0), - (l_3_1) ? l_3_2.i_2 : float(0.0), - (l_3_1) ? l_3_2.i_3 : float(0.0), - }; - uint l_3_3; - bool l_3_4; - float_4 l_3_5; - l_3_3 = info[uint(1)]; - l_3_4 = idxGlobal < l_3_3; - l_3_5 = input_1[idxGlobal]; - l_0_3 = float_4{ - (l_3_4) ? l_3_5.i_0 : float(0.0), - (l_3_4) ? l_3_5.i_1 : float(0.0), - (l_3_4) ? l_3_5.i_2 : float(0.0), - (l_3_4) ? l_3_5.i_3 : float(0.0), - }; - l_0_4 = float_4{ - l_0_4.i_0 * l_0_3.i_0, - l_0_4.i_1 * l_0_3.i_1, - l_0_4.i_2 * l_0_3.i_2, - l_0_4.i_3 * l_0_3.i_3, - }; - l_0_4 = float_4{ - cos(l_0_4.i_0), - cos(l_0_4.i_1), - cos(l_0_4.i_2), - cos(l_0_4.i_3), - }; - uint l_3_6; - bool l_3_7; - float_4 l_3_8; - l_3_6 = info[uint(2)]; - l_3_7 = idxGlobal < l_3_6; - l_3_8 = output_0[idxGlobal]; - l_0_3 = float_4{ - (l_3_7) ? l_3_8.i_0 : float(0.0), - (l_3_7) ? l_3_8.i_1 : float(0.0), - (l_3_7) ? l_3_8.i_2 : float(0.0), - (l_3_7) ? l_3_8.i_3 : float(0.0), - }; - l_0_3 = float_4{ - l_0_3.i_0 + l_0_4.i_0, - l_0_3.i_1 + l_0_4.i_1, - l_0_3.i_2 + l_0_4.i_2, - l_0_3.i_3 + l_0_4.i_3, - }; - uint l_3_9; - bool l_3_10; - l_3_9 = info[uint(2)]; - l_3_10 = idxGlobal < l_3_9; - if (l_3_10) { - output_0[idxGlobal] = l_0_3; - } - } - } - } -} diff --git a/crates/cubecl-hip/src/lib.rs b/crates/cubecl-hip/src/lib.rs index 97b97d607..86ef5d9c4 100644 --- a/crates/cubecl-hip/src/lib.rs +++ b/crates/cubecl-hip/src/lib.rs @@ -27,6 +27,7 @@ mod tests { pub type TestRuntime = crate::HipRuntime; cubecl_core::testgen_all!(); - cubecl_linalg::testgen_matmul_cmma!(); + cubecl_linalg::testgen_matmul_plane!([f32]); + cubecl_linalg::testgen_matmul_accelerated!([f32]); cubecl_reduce::testgen_reduce!([f16, bf16, f32, f64]); } diff --git a/crates/cubecl-linalg/Cargo.toml b/crates/cubecl-linalg/Cargo.toml index facffd87d..29a61d9b5 100644 --- a/crates/cubecl-linalg/Cargo.toml +++ b/crates/cubecl-linalg/Cargo.toml @@ -24,6 +24,7 @@ cubecl-core = { path = "../cubecl-core", version = "0.4.0", default-features = f cubecl-runtime = { path = "../cubecl-runtime", version = "0.4.0", default-features = false } half = { workspace = true, features = ["bytemuck"] } pretty_assertions = { workspace = true, optional = true } +serde = { workspace = true } [dev-dependencies] trybuild = "1" diff --git a/crates/cubecl-linalg/src/matmul/base.rs b/crates/cubecl-linalg/src/matmul/base.rs index 7a7073ef8..0599993e1 100644 --- a/crates/cubecl-linalg/src/matmul/base.rs +++ b/crates/cubecl-linalg/src/matmul/base.rs @@ -6,15 +6,23 @@ use cubecl_core::{ use crate::tensor::TensorHandle; -use super::kernels::{ - matmul, simple, - tiling2d::{self, Tiling2dConfig}, - MatmulLaunchError, +use super::{ + components::tile::accelerated::Accelerated, + kernels::{ + matmul::{self, PipelinedSelector, SpecializedSelector, StandardSelector}, + simple, + tiling2d::{self, Tiling2dConfig}, + MatmulLaunchError, + }, }; #[derive(Debug, Clone, Default)] pub enum Strategy { - Accelerated, + Standard, + Pipelined, + Specialized, + #[cfg(any(test, feature = "export_tests"))] + // Very slow, only use for 
testing. PlaneMma, Simple, Tiling2D(Tiling2dConfig), @@ -46,8 +54,21 @@ pub fn launch_ref<R: Runtime, EG: Numeric>( out: &TensorHandleRef<'_, R>, ) -> Result<(), MatmulLaunchError> { match strategy { - Strategy::Accelerated => matmul::launch_ref::<R, EG>(client, lhs, rhs, out, false), - Strategy::PlaneMma => matmul::launch_ref::<R, EG>(client, lhs, rhs, out, true), + Strategy::Standard => { + matmul::launch_ref::<R, EG, StandardSelector<Accelerated>>(client, lhs, rhs, out) + } + Strategy::Pipelined => { + matmul::launch_ref::<R, EG, PipelinedSelector<Accelerated>>(client, lhs, rhs, out) + } + Strategy::Specialized => { + matmul::launch_ref::<R, EG, SpecializedSelector<Accelerated>>(client, lhs, rhs, out) + } + #[cfg(any(test, feature = "export_tests"))] + Strategy::PlaneMma => { + matmul::launch_ref::<R, EG, StandardSelector<PlaneMma>>( + client, lhs, rhs, out, + ) + } Strategy::Tiling2D(config) => { tiling2d::launch_ref::<R, EG>(client, lhs, rhs, out, config.clone()); Ok(()) @@ -57,7 +78,9 @@ pub fn launch_ref<R: Runtime, EG: Numeric>( Ok(()) } Strategy::Auto => { - if let Err(err) = matmul::launch_ref::<R, EG>(client, lhs, rhs, out, false) { + if let Err(err) = + matmul::launch_ref::<R, EG, StandardSelector<Accelerated>>(client, lhs, rhs, out) + { match err { super::kernels::MatmulLaunchError::Unavailable(_) => { tiling2d::launch_ref::<R, EG>( diff --git a/crates/cubecl-linalg/src/matmul/components/base.rs b/crates/cubecl-linalg/src/matmul/components/base.rs index b2fe21404..52922c5a0 100644 --- a/crates/cubecl-linalg/src/matmul/components/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/base.rs @@ -1,45 +1,34 @@ +use cubecl_core as cubecl; use cubecl_core::prelude::*; -use super::{InputRuntimeArg, MatmulSpec, OutputRuntimeArg}; -use crate::matmul::kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}; +use super::{InputRuntimeArg, MatmulConfigFactory, MatmulSpec, OutputRuntimeArg}; -use super::{config::MatmulConfig, MatmulProblem}; - -/// Provides configuration for a matmul kernel at any level -pub trait MatmulKernel { - /// Configuration tailored to the matmul implementation - type Config: MatmulConfig; - - /// Asserts that the configuration for this matmul will lead to a valid computation - fn check_config(config: Self::Config); - - /// Checks if the client can handle the features used in this computation - fn check_availability<R: Runtime>( - client: &ComputeClient<R::Server, R::Channel>, - ) -> Result<(), MatmulAvailabilityError>; +#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)] +pub struct MatmulSize { + pub m: u32, + pub n: u32, + pub k: u32, +} - /// Create config for this matmul, given launch information - fn make_config( - problem: &MatmulProblem, - cube_dim: &CubeDim, - cube_count: &CubeCount, - advanced_config: &AdvancedConfig, - ) -> Self::Config; +pub struct MatmulSelection { + pub tile: MatmulSize, + pub num_stages: MatmulSize, + pub plane_dim: u32, } /// Provides launch entry point to solve a matmul -pub trait MatmulLaunch<MS: MatmulSpec>: MatmulKernel { +pub trait MatmulLaunch: MatmulConfigFactory { /// Entry point /// /// # Safety /// /// Out-of-bounds can happen - unsafe fn launch_unchecked<'a, R: Runtime>( + unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>( client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>, cube_dim: CubeDim, cube_count: CubeCount, input: InputRuntimeArg<'a, MS, R>, output: OutputRuntimeArg<'a, MS, R>, - config: <Self as MatmulKernel>::Config, + config: <Self as MatmulConfigFactory>::Config, ); } diff --git a/crates/cubecl-linalg/src/matmul/components/batch/base.rs b/crates/cubecl-linalg/src/matmul/components/batch/base.rs index a98227534..a4ac9b207 100644 --- a/crates/cubecl-linalg/src/matmul/components/batch/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/batch/base.rs @@ -2,8 +2,14 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; use 
crate::matmul::components::global::args::{self, MatmulArgs, TensorInput, TensorOutput}; -use crate::matmul::components::{batch, InputArg, MatmulSpec, OutputArg}; +use crate::matmul::components::MatmulPrecision; use crate::matmul::components::{config::MatmulConfig, global, Ident, MatmulLaunch, StageDim}; +use crate::tensor::{ReadWrite, VirtualTensor}; + +/// A family of [matmuls](BatchMatmul) working with any [precision](MatmulPrecision). +pub trait BatchMatmulFamily: 'static + Send + Sync + MatmulLaunch { + type Matmul: BatchMatmul; +} #[cube] /// Provides matrix multiplication operations at the batch level. @@ -23,20 +29,22 @@ use crate::matmul::components::{config::MatmulConfig, global, Ident, MatmulLaunc /// It is not assumed that the matmul's dimensions match its inputs dimensions perfectly. /// It is therefore important to use an underlying global matmul that performs check bounds, /// and to not launch more Cubes than necessary. -pub trait Matmul: 'static + Send + Sync + MatmulLaunch { +pub trait BatchMatmul: 'static + Send + Sync { + type Config: BatchConfig; + /// Performs batchwise matrix multiplication over tensors. fn execute( - lhs: TensorInput, - rhs: TensorInput, - out: TensorOutput, + lhs: VirtualTensor, + rhs: VirtualTensor, + out: VirtualTensor, #[comptime] config: Self::Config, ); } -/// Configuration for the Batch matmul (BMM) level -pub trait Config: MatmulConfig { +/// Configuration for the [batch matmul](BatchMatmul) level. +pub trait BatchConfig: MatmulConfig { /// Underlying Global matmul config - type GmmConfig: global::Config; + type GmmConfig: global::GlobalConfig; /// Convert itself to the underlying global matmul config fn to_gmm_config(&self) -> Self::GmmConfig; @@ -54,17 +62,30 @@ pub trait Config: MatmulConfig { fn max_batches(&self) -> u32; } +type Input = ::Input; +type Output = ::Output; + #[cube(launch_unchecked)] -pub(crate) fn batch_matmul>( - inputs: &InputArg, - output: &mut OutputArg, +pub(crate) fn matmul< + EG: Numeric, + ES: Numeric, + EA: Numeric, + Args: MatmulArgs, + BMM: BatchMatmulFamily, +>( + inputs: &Input, + output: &mut Output, #[comptime] config: BMM::Config, ) { - let mut state = MS::Args::init_state(inputs, output); + let mut state = Args::init_state(inputs, output); + + let lhs = TensorInput::::new(&state, args::TensorInputIdent::Lhs); + let rhs = TensorInput::::new(&state, args::TensorInputIdent::Rhs); + let mut out = TensorOutput::::new(&mut state); - let lhs = TensorInput::::new(&state, args::TensorInputIdent::Lhs); - let rhs = TensorInput::::new(&state, args::TensorInputIdent::Rhs); - let out = TensorOutput::::new(&mut state); + let lhs = VirtualTensor::::new::>(&lhs); + let rhs = VirtualTensor::::new::>(&rhs); + let out = VirtualTensor::::new::>(&mut out); - BMM::execute(lhs, rhs, out, config); + BMM::Matmul::<(EG, ES, EA)>::execute(lhs, rhs, out, config); } diff --git a/crates/cubecl-linalg/src/matmul/components/batch/one_to_many.rs b/crates/cubecl-linalg/src/matmul/components/batch/one_to_many.rs index e20418c7e..d4a32eb4d 100644 --- a/crates/cubecl-linalg/src/matmul/components/batch/one_to_many.rs +++ b/crates/cubecl-linalg/src/matmul/components/batch/one_to_many.rs @@ -1,38 +1,113 @@ use std::marker::PhantomData; use crate::matmul::components::batch::span::{Span, SpanDim, SpanMatmul}; -use crate::matmul::components::global::args::{TensorInput, TensorOutput}; +use crate::matmul::components::global::GlobalMatmulFamily; use crate::matmul::components::{ - batch, config::MatmulConfig, global, Ident, MatmulKernel, MatmulLaunch, 
StageDim, + batch, config::MatmulConfig, global, Ident, MatmulConfigFactory, MatmulLaunch, StageDim, +}; +use crate::matmul::components::{ + InputRuntimeArg, InvalidConfigError, MatmulPrecision, MatmulProblem, MatmulSpec, + OutputRuntimeArg, }; -use crate::matmul::components::{InputRuntimeArg, MatmulProblem, MatmulSpec, OutputRuntimeArg}; use crate::matmul::kernels::matmul::AdvancedConfig; use crate::matmul::kernels::MatmulAvailabilityError; +use crate::tensor::{ReadWrite, VirtualTensor}; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use super::{Config as _, CubeDispatch}; +use super::{BatchConfig as _, BatchMatmulFamily, CubeDispatch}; + +pub struct OneToManyMatmulFamily { + _gmm: PhantomData, + _s: PhantomData, + _c: PhantomData, +} + +impl BatchMatmulFamily + for OneToManyMatmulFamily +{ + type Matmul = OneToManyMatmul, S, C>; +} + +impl MatmulConfigFactory + for OneToManyMatmulFamily +{ + type Config = Config; + type Input = GMM::Input; + + fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> { + GMM::check_config(&config.to_gmm_config()) + } + + fn check_availability( + client: &ComputeClient, + config: &Self::Config, + ) -> Result<(), MatmulAvailabilityError> { + GMM::check_availability::(client, &config.gmm_config) + } + + fn make_config( + input: Self::Input, + problem: &MatmulProblem, + cube_dim: &CubeDim, + cube_count: &CubeCount, + advanced_config: &AdvancedConfig, + ) -> Self::Config { + let gmm_config = GMM::make_config(input, problem, cube_dim, cube_count, advanced_config); + let cube_count = if let CubeCount::Static(x, y, z) = cube_count { + (*x, *y, *z) + } else { + panic!("Dynamic cube count unsupported") + }; + + Config::new(gmm_config, cube_count) + } +} + +impl MatmulLaunch + for OneToManyMatmulFamily +{ + unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>( + client: &ComputeClient<::Server, ::Channel>, + cube_dim: CubeDim, + cube_count: CubeCount, + input: InputRuntimeArg<'a, MS, R>, + output: OutputRuntimeArg<'a, MS, R>, + config: Self::Config, + ) { + super::matmul::launch_unchecked::( + client, cube_count, cube_dim, input, output, config, + ); + } +} /// Executes matrix multiplication at the batch level, /// assigning each cube to handle multiple global matmuls. /// /// The algorithm supports any number of cubes, /// looping as needed to process all data. 
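// A rough sketch of that looping, mirroring `RowMajorSpanMatmul` in span.rs
// further down in this patch (names are taken from there; bounds are
// illustrative):
//
//     for row_iter in range_stepped(span.row.start, span.row.end, span.row.step) {
//         for col_iter in range_stepped(span.col.start, span.col.end, span.col.step) {
//             GMM::zero_accumulator(&mut acc, config);
//             gmm_execute::<MP, GMM>(lhs, rhs, out, row_iter, col_iter, batch_iter, &mut acc, k_range, config);
//         }
//     }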
-pub struct Matmul, S: SpanMatmul, C: CubeDispatch> { - _ms: PhantomData, +pub struct OneToManyMatmul< + MP: MatmulPrecision, + GMM: global::GlobalMatmul, + S: SpanMatmul, + C: CubeDispatch, +> { + _mp: PhantomData, _gmm: PhantomData, _s: PhantomData, _c: PhantomData, } #[cube] -impl, S: SpanMatmul, C: CubeDispatch> batch::Matmul - for Matmul +impl, S: SpanMatmul, C: CubeDispatch> + batch::BatchMatmul for OneToManyMatmul { + type Config = Config; + fn execute( - lhs: TensorInput, - rhs: TensorInput, - out: TensorOutput, + lhs: VirtualTensor, + rhs: VirtualTensor, + out: VirtualTensor, #[comptime] config: Self::Config, ) { let rank = out.rank(); @@ -65,69 +140,19 @@ impl, S: SpanMatmul, C: CubeDispatch> ba let gmm_config = config.to_gmm_config(); let acc = GMM::init_accumulator(gmm_config); - S::execute::(lhs, rhs, out, span, acc, k_range, gmm_config); - } -} - -impl, S: SpanMatmul, C: CubeDispatch> MatmulKernel - for Matmul -{ - type Config = Config; - - fn check_config(config: Self::Config) { - GMM::check_config(config.to_gmm_config()) - } - - fn check_availability( - client: &ComputeClient, - ) -> Result<(), MatmulAvailabilityError> { - GMM::check_availability::(client) - } - - fn make_config( - problem: &MatmulProblem, - cube_dim: &CubeDim, - cube_count: &CubeCount, - advanced_config: &AdvancedConfig, - ) -> Self::Config { - let gmm_config = GMM::make_config(problem, cube_dim, cube_count, advanced_config); - let cube_count = if let CubeCount::Static(x, y, z) = cube_count { - (*x, *y, *z) - } else { - panic!("Dynamic cube count unsupported") - }; - - Config::new(gmm_config, cube_count) - } -} - -impl, S: SpanMatmul, C: CubeDispatch> MatmulLaunch - for Matmul -{ - unsafe fn launch_unchecked<'a, R: Runtime>( - client: &ComputeClient<::Server, ::Channel>, - cube_dim: CubeDim, - cube_count: CubeCount, - input: InputRuntimeArg<'a, MS, R>, - output: OutputRuntimeArg<'a, MS, R>, - config: Self::Config, - ) { - Self::check_config(config); - super::batch_matmul::launch_unchecked::( - client, cube_count, cube_dim, input, output, config, - ); + S::execute::(lhs, rhs, out, span, acc, k_range, gmm_config); } } #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)] /// Configuration for the OneToOneBatchMatmul -pub struct Config { +pub struct Config { gmm_config: G, cube_count: (u32, u32, u32), _c: PhantomData, } -impl batch::Config for Config { +impl batch::BatchConfig for Config { type GmmConfig = G; fn to_gmm_config(&self) -> Self::GmmConfig { @@ -151,9 +176,9 @@ impl batch::Config for Config { } } -impl MatmulConfig for Config {} +impl MatmulConfig for Config {} -impl Config { +impl Config { pub fn new(gmm_config: G, cube_count: (u32, u32, u32)) -> Self { Self { gmm_config, diff --git a/crates/cubecl-linalg/src/matmul/components/batch/one_to_one.rs b/crates/cubecl-linalg/src/matmul/components/batch/one_to_one.rs index 2139d7a69..e3ba06ae6 100644 --- a/crates/cubecl-linalg/src/matmul/components/batch/one_to_one.rs +++ b/crates/cubecl-linalg/src/matmul/components/batch/one_to_one.rs @@ -1,81 +1,57 @@ use std::marker::PhantomData; use crate::matmul::components::batch::shared::gmm_execute; -use crate::matmul::components::global::args::{TensorInput, TensorOutput}; +use crate::matmul::components::global::{GlobalMatmul, GlobalMatmulFamily}; use crate::matmul::components::{ - batch, config::MatmulConfig, global, Ident, MatmulKernel, MatmulLaunch, StageDim, + batch, config::MatmulConfig, global, Ident, MatmulConfigFactory, MatmulLaunch, StageDim, +}; +use crate::matmul::components::{ + 
InputRuntimeArg, InvalidConfigError, MatmulPrecision, MatmulProblem, MatmulSpec, + OutputRuntimeArg, }; -use crate::matmul::components::{InputRuntimeArg, MatmulProblem, MatmulSpec, OutputRuntimeArg}; use crate::matmul::kernels::matmul::AdvancedConfig; use crate::matmul::kernels::MatmulAvailabilityError; +use crate::tensor::{ReadWrite, VirtualTensor}; +use batch::{BatchMatmul, BatchMatmulFamily}; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use super::{Config as _, CubeDispatch}; +use super::{BatchConfig as _, CubeDispatch}; -/// Executes matrix multiplication at the batch level, -/// assigning each cube to a single global matmul. -/// -/// Note: This algorithm requires one cube per global matmul; -/// insufficient cubes will result in incomplete computations. -pub struct Matmul, C: CubeDispatch> { - _ms: PhantomData, +pub struct OneToOneMatmulFamily { _gmm: PhantomData, _c: PhantomData, } -#[cube] -impl, C: CubeDispatch> batch::Matmul - for Matmul -{ - fn execute( - lhs: TensorInput, - rhs: TensorInput, - out: TensorOutput, - #[comptime] config: Self::Config, - ) { - let (x_index, y_index) = C::x_y_indices(); - let x_offset = x_index * config.stage_dim(Ident::Lhs).num_elements_x_dim(); - let y_offset = y_index * config.stage_dim(Ident::Rhs).num_elements_y_dim(); - let nth_batch = C::batch_index(); - let rank = lhs.rank(); - let k_range = (0, lhs.shape(rank - 1)); - - let gmm_config = config.to_gmm_config(); - gmm_execute::( - lhs, - rhs, - out, - x_offset, - y_offset, - nth_batch, - &mut GMM::init_accumulator(gmm_config), - k_range, - gmm_config, - ); - } +impl BatchMatmulFamily for OneToOneMatmulFamily { + type Matmul = OneToOneMatmul, C>; } -impl, C: CubeDispatch> MatmulKernel for Matmul { +impl MatmulConfigFactory + for OneToOneMatmulFamily +{ + type Input = GMM::Input; type Config = Config; - fn check_config(config: Self::Config) { - GMM::check_config(config.to_gmm_config()) + fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> { + GMM::check_config(&config.to_gmm_config()) } - fn check_availability( + fn check_availability( client: &ComputeClient, + config: &Self::Config, ) -> Result<(), MatmulAvailabilityError> { - GMM::check_availability::(client) + GMM::check_availability::(client, &config.gmm_config) } fn make_config( + input: Self::Input, problem: &MatmulProblem, cube_dim: &CubeDim, cube_count: &CubeCount, advanced_config: &AdvancedConfig, ) -> Self::Config { - let gmm_config = GMM::make_config(problem, cube_dim, cube_count, advanced_config); + let gmm_config = GMM::make_config(input, problem, cube_dim, cube_count, advanced_config); let cube_count = if let CubeCount::Static(x, y, z) = cube_count { (*x, *y, *z) } else { @@ -86,10 +62,8 @@ impl, C: CubeDispatch> MatmulKernel for } } -impl, C: CubeDispatch> MatmulLaunch - for Matmul -{ - unsafe fn launch_unchecked<'a, R: Runtime>( +impl MatmulLaunch for OneToOneMatmulFamily { + unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>( client: &ComputeClient<::Server, ::Channel>, cube_dim: CubeDim, cube_count: CubeCount, @@ -97,22 +71,66 @@ impl, C: CubeDispatch> MatmulLaunch output: OutputRuntimeArg<'a, MS, R>, config: Self::Config, ) { - Self::check_config(config); - super::batch_matmul::launch_unchecked::( + super::matmul::launch_unchecked::( client, cube_count, cube_dim, input, output, config, ); } } +/// Executes matrix multiplication at the batch level, +/// assigning each cube to a single global matmul. 
+/// +/// Note: This algorithm requires one cube per global matmul; +/// insufficient cubes will result in incomplete computations. +pub struct OneToOneMatmul, C: CubeDispatch> { + _mp: PhantomData, + _gmm: PhantomData, + _c: PhantomData, +} + +#[cube] +impl, C: CubeDispatch> BatchMatmul + for OneToOneMatmul +{ + type Config = Config; + + fn execute( + lhs: VirtualTensor, + rhs: VirtualTensor, + out: VirtualTensor, + #[comptime] config: Self::Config, + ) { + let (x_index, y_index) = C::x_y_indices(); + let x_offset = x_index * config.stage_dim(Ident::Lhs).num_elements_x_dim(); + let y_offset = y_index * config.stage_dim(Ident::Rhs).num_elements_y_dim(); + let nth_batch = C::batch_index(); + let rank = lhs.rank(); + let k_range = (0, lhs.shape(rank - 1)); + + let gmm_config = config.to_gmm_config(); + gmm_execute::( + lhs, + rhs, + out, + x_offset, + y_offset, + nth_batch, + &mut GMM::init_accumulator(gmm_config), + k_range, + gmm_config, + ); + } +} + #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)] /// Configuration for the OneToOneBatchMatmul -pub struct Config { +pub struct Config { gmm_config: G, cube_count: (u32, u32, u32), _c: PhantomData, } -impl batch::Config for Config { +impl batch::BatchConfig for Config { type GmmConfig = G; fn to_gmm_config(&self) -> Self::GmmConfig { @@ -136,9 +154,9 @@ impl batch::Config for Config { } } -impl MatmulConfig for Config {} +impl MatmulConfig for Config {} -impl Config { +impl Config { pub fn new(gmm_config: G, cube_count: (u32, u32, u32)) -> Self { Self { gmm_config, diff --git a/crates/cubecl-linalg/src/matmul/components/batch/shared.rs b/crates/cubecl-linalg/src/matmul/components/batch/shared.rs index 9d2f27274..b33058eb9 100644 --- a/crates/cubecl-linalg/src/matmul/components/batch/shared.rs +++ b/crates/cubecl-linalg/src/matmul/components/batch/shared.rs @@ -1,17 +1,16 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::components::global::args::{TensorInput, TensorOutput}; -use crate::matmul::components::{global, MatmulSpec}; +use crate::matmul::components::{global, MatmulPrecision}; use crate::tensor::{ReadWrite, VirtualTensor}; #[cube] /// Execute global matmul on lhs, rhs, writing in out. 
/// x and y offsets are absolute rows and columns -pub(crate) fn gmm_execute>( - lhs: TensorInput, - rhs: TensorInput, - mut out: TensorOutput, +pub(crate) fn gmm_execute>( + lhs: VirtualTensor, + rhs: VirtualTensor, + out: VirtualTensor, x_offset: u32, y_offset: u32, nth_batch: u32, @@ -29,10 +28,6 @@ pub(crate) fn gmm_execute>( batch_rhs += tmp % rhs.shape(b) * rhs.stride(b); } - let lhs = VirtualTensor::::new::>(&lhs); - let rhs = VirtualTensor::::new::>(&rhs); - let out = VirtualTensor::::new::>(&mut out); - GMM::execute( GMM::init_lhs_loader(lhs, x_offset, k_range.0, batch_lhs, config), GMM::init_rhs_loader(rhs, k_range.0, y_offset, batch_rhs, config), diff --git a/crates/cubecl-linalg/src/matmul/components/batch/span.rs b/crates/cubecl-linalg/src/matmul/components/batch/span.rs index 90ce27796..ea6119222 100644 --- a/crates/cubecl-linalg/src/matmul/components/batch/span.rs +++ b/crates/cubecl-linalg/src/matmul/components/batch/span.rs @@ -1,13 +1,13 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::components::{ - batch::shared::swizzle, - global::{ - self, - args::{TensorInput, TensorOutput}, +use crate::{ + matmul::components::{ + batch::shared::swizzle, + global::{self}, + MatmulPrecision, }, - MatmulSpec, + tensor::{ReadWrite, VirtualTensor}, }; use super::shared::gmm_execute; @@ -32,10 +32,10 @@ pub struct SpanDim { #[cube] /// Iterates on several global matmul across a span pub trait SpanMatmul: 'static + Send + Sync { - fn execute>( - lhs: TensorInput, - rhs: TensorInput, - out: TensorOutput, + fn execute>( + lhs: VirtualTensor, + rhs: VirtualTensor, + out: VirtualTensor, span: Span, acc: GMM::Accumulator, k_range: (u32, u32), @@ -89,10 +89,10 @@ impl SpanDim { #[cube] impl SpanMatmul for RowMajorSpanMatmul { - fn execute>( - lhs: TensorInput, - rhs: TensorInput, - out: TensorOutput, + fn execute>( + lhs: VirtualTensor, + rhs: VirtualTensor, + out: VirtualTensor, span: Span, mut acc: GMM::Accumulator, k_range: (u32, u32), @@ -102,7 +102,7 @@ impl SpanMatmul for RowMajorSpanMatmul { for row_iter in range_stepped(span.row.start, span.row.end, span.row.step) { for col_iter in range_stepped(span.col.start, span.col.end, span.col.step) { GMM::zero_accumulator(&mut acc, config); - gmm_execute::( + gmm_execute::( lhs, rhs, out, row_iter, col_iter, batch_iter, &mut acc, k_range, config, ); } @@ -113,10 +113,10 @@ impl SpanMatmul for RowMajorSpanMatmul { #[cube] impl SpanMatmul for ColMajorSpanMatmul { - fn execute>( - lhs: TensorInput, - rhs: TensorInput, - out: TensorOutput, + fn execute>( + lhs: VirtualTensor, + rhs: VirtualTensor, + out: VirtualTensor, span: Span, mut acc: GMM::Accumulator, k_range: (u32, u32), @@ -126,7 +126,7 @@ impl SpanMatmul for ColMajorSpanMatmul { for col_iter in range_stepped(span.col.start, span.col.end, span.col.step) { for row_iter in range_stepped(span.row.start, span.row.end, span.row.step) { GMM::zero_accumulator(&mut acc, config); - gmm_execute::( + gmm_execute::( lhs, rhs, out, row_iter, col_iter, batch_iter, &mut acc, k_range, config, ); } @@ -137,10 +137,10 @@ impl SpanMatmul for ColMajorSpanMatmul { #[cube] impl SpanMatmul for SwizzleSpanMatmul { - fn execute>( - lhs: TensorInput, - rhs: TensorInput, - out: TensorOutput, + fn execute>( + lhs: VirtualTensor, + rhs: VirtualTensor, + out: VirtualTensor, span: Span, mut acc: GMM::Accumulator, k_range: (u32, u32), @@ -155,7 +155,7 @@ impl SpanMatmul for SwizzleSpanMatmul { let row_iter = span.row.start + row * span.row.step; let col_iter = span.col.start + col * 
span.col.step; - gmm_execute::( + gmm_execute::( lhs, rhs, out, row_iter, col_iter, batch_iter, &mut acc, k_range, config, ); } diff --git a/crates/cubecl-linalg/src/matmul/components/config.rs b/crates/cubecl-linalg/src/matmul/components/config.rs index 997544593..3a6fc1281 100644 --- a/crates/cubecl-linalg/src/matmul/components/config.rs +++ b/crates/cubecl-linalg/src/matmul/components/config.rs @@ -1,8 +1,59 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use std::fmt::Debug; +use std::fmt::{Debug, Display}; use std::hash::Hash; +use crate::matmul::kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}; + +use super::{MatmulPrecision, MatmulProblem}; + +pub type InvalidConfigError = Box; + +pub struct FormattedConfigError { + func: Box String>, +} + +impl FormattedConfigError { + #[allow(clippy::new_ret_no_self)] + pub fn new String + 'static>(func: F) -> Box { + Box::new(Self { + func: Box::new(func), + }) + } +} + +impl Display for FormattedConfigError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let string = (self.func)(); + write!(f, "{string}") + } +} + +/// Provides configuration for a matmul kernel at any level +pub trait MatmulConfigFactory: Send + Sync + 'static { + /// Configuration tailored to the matmul implementation + type Config: MatmulConfig; + type Input; + + /// Asserts that the configuration for this matmul will lead to a valid computation + fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError>; + + /// Checks if the client can handle the features used in this computation + fn check_availability( + _client: &ComputeClient, + _config: &Self::Config, + ) -> Result<(), MatmulAvailabilityError>; + + /// Create config for this matmul, given launch information + fn make_config( + input: Self::Input, + problem: &MatmulProblem, + cube_dim: &CubeDim, + cube_count: &CubeCount, + advanced_config: &AdvancedConfig, + ) -> Self::Config; +} + /// A config for a matmul /// /// Useful to aggregate many trait bounds diff --git a/crates/cubecl-linalg/src/matmul/components/global/accumulator_loader.rs b/crates/cubecl-linalg/src/matmul/components/global/accumulator_loader.rs index ac9d52ff1..2494012b7 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/accumulator_loader.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/accumulator_loader.rs @@ -1,7 +1,7 @@ use cubecl_core as cubecl; use cubecl_core::{prelude::*, CubeType}; -use crate::matmul::components::{stage::Config, tile}; +use crate::matmul::components::{stage::StageConfig, tile}; use super::AccumulatorLoader; @@ -10,10 +10,12 @@ use super::AccumulatorLoader; pub struct ZeroAccumulatorLoader; #[cube] -impl AccumulatorLoader for ZeroAccumulatorLoader { +impl AccumulatorLoader + for ZeroAccumulatorLoader +{ fn fill_stage(_this: &mut Self, #[comptime] _config: G) {} - fn load>( + fn load>( _this: &mut Self, acc: &mut Tile::Accumulator, _n_tile: u32, diff --git a/crates/cubecl-linalg/src/matmul/components/global/args.rs b/crates/cubecl-linalg/src/matmul/components/global/args.rs index cd37ac8b2..a5d8f3db2 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/args.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/args.rs @@ -4,46 +4,49 @@ use cubecl_core::{self as cubecl}; #[cube] /// Arguments for the matrix multiplication algorithm. -pub trait MatmulArgs: Send + Sync + 'static + Clone { +pub trait MatmulArgs: Send + Sync + 'static + Clone { /// Type used for the input. 
- type Input: LaunchArg + CubeType; + type Input<EG: Numeric>: LaunchArg + CubeType; /// Type used for the output. - type Output: LaunchArg + CubeType; + type Output<EG: Numeric>: LaunchArg + CubeType; /// Inner state that is used to create [tensor inputs](TensorInput) and /// [tensor outputs](TensorOutput). - type State: CubeType; + type State<EG: Numeric>: CubeType; /// Init the state. - fn init_state(input: &Self::Input, output: &mut Self::Output) -> Self::State; + fn init_state<EG: Numeric>( + input: &Self::Input<EG>, + output: &mut Self::Output<EG>, + ) -> Self::State<EG>; /// Read the line of the lhs tensor using the state at the given coordinate. - fn read_lhs(state: &Self::State, coordinate: u32) -> Line<EG>; + fn read_lhs<EG: Numeric>(state: &Self::State<EG>, coordinate: u32) -> Line<EG>; /// Read the line of the rhs tensor using the state at the given coordinate. - fn read_rhs(state: &Self::State, coordinate: u32) -> Line<EG>; + fn read_rhs<EG: Numeric>(state: &Self::State<EG>, coordinate: u32) -> Line<EG>; /// Write the line to the output at the given coordinate using the state. - fn write_out(state: &mut Self::State, coordinate: u32, value: Line<EG>); + fn write_out<EG: Numeric>(state: &mut Self::State<EG>, coordinate: u32, value: Line<EG>); /// Get the rank of the lhs tensor using the state. - fn rank_lhs(state: &Self::State) -> u32; + fn rank_lhs<EG: Numeric>(state: &Self::State<EG>) -> u32; /// Get the rank of the rhs tensor using the state. - fn rank_rhs(state: &Self::State) -> u32; + fn rank_rhs<EG: Numeric>(state: &Self::State<EG>) -> u32; /// Get the rank of the out tensor using the state. - fn rank_out(state: &Self::State) -> u32; + fn rank_out<EG: Numeric>(state: &Self::State<EG>) -> u32; /// Get the shape of the lhs tensor using the state. - fn shape_lhs(state: &Self::State, axis: u32) -> u32; + fn shape_lhs<EG: Numeric>(state: &Self::State<EG>, axis: u32) -> u32; /// Get the shape of the rhs tensor using the state. - fn shape_rhs(state: &Self::State, axis: u32) -> u32; + fn shape_rhs<EG: Numeric>(state: &Self::State<EG>, axis: u32) -> u32; /// Get the shape of the out tensor using the state. - fn shape_out(state: &Self::State, axis: u32) -> u32; + fn shape_out<EG: Numeric>(state: &Self::State<EG>, axis: u32) -> u32; /// Get the stride of the lhs tensor using the state. - fn stride_lhs(state: &Self::State, axis: u32) -> u32; + fn stride_lhs<EG: Numeric>(state: &Self::State<EG>, axis: u32) -> u32; /// Get the stride of the rhs tensor using the state. - fn stride_rhs(state: &Self::State, axis: u32) -> u32; + fn stride_rhs<EG: Numeric>(state: &Self::State<EG>, axis: u32) -> u32; /// Get the stride of the out tensor using the state. - fn stride_out(state: &Self::State, axis: u32) -> u32; + fn stride_out<EG: Numeric>(state: &Self::State<EG>, axis: u32) -> u32; } #[derive(Clone, Copy)] @@ -56,17 +59,15 @@ pub enum TensorInputIdent { Lhs, Rhs, } /// Tensor input representation. /// /// You can use the tensor input as if it were a pointer to the actual tensor.
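// For example, batch/base.rs in this patch builds the wrappers from the shared
// state and then erases them behind virtual tensors (usage sketch, signatures
// abridged):
//
//     let lhs = TensorInput::<EG, Args>::new(&state, args::TensorInputIdent::Lhs);
//     let lhs = VirtualTensor::<EG>::new::<TensorInput<EG, Args>>(&lhs);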
-pub struct TensorInput> { - state: *const GA::State, +pub struct TensorInput { + state: *const GA::State, ident: TensorInputIdent, } -impl> VirtualTensorOperations for TensorInput {} -impl> VirtualTensorOperations for TensorOutput {} +impl VirtualTensorOperations for TensorInput {} +impl VirtualTensorOperations for TensorOutput {} -impl> VirtualTensorOperationsExpand - for TensorOutputExpand -{ +impl VirtualTensorOperationsExpand for TensorOutputExpand { fn __expand_read_method( &self, _context: &mut CubeContext, @@ -105,9 +106,7 @@ impl> VirtualTensorOperationsExpand } } -impl> VirtualTensorOperationsExpand - for TensorInputExpand -{ +impl VirtualTensorOperationsExpand for TensorInputExpand { fn __expand_read_method( &self, context: &mut CubeContext, @@ -153,25 +152,25 @@ impl> VirtualTensorOperationsExpand /// # Warning /// /// There is no mutability guarantee. -pub struct TensorOutput> { - state: *mut GA::State, +pub struct TensorOutput { + state: *mut GA::State, } /// Expand type for [tensor input](TensorInput). -pub struct TensorInputExpand> { - state: ::ExpandType, +pub struct TensorInputExpand { + state: as CubeType>::ExpandType, ident: TensorInputIdent, } /// Expand type for [tensor output](TensorOutput). -pub struct TensorOutputExpand> { - state: ::ExpandType, +pub struct TensorOutputExpand { + state: as CubeType>::ExpandType, } #[cube] -impl> TensorInput { +impl TensorInput { /// Create a [tensor input](TensorInput) from the state and the [ident](TensorInputIdent). - pub fn new(state: &MA::State, #[comptime] ident: TensorInputIdent) -> TensorInput { + pub fn new(state: &MA::State, #[comptime] ident: TensorInputIdent) -> TensorInput { TensorInput:: { state, ident } } @@ -217,9 +216,9 @@ impl> TensorInput { } #[cube] -impl> TensorOutput { +impl TensorOutput { /// Create a [tensor output](TensorOutput) from the state. 
- pub fn new(state: &mut GA::State) -> TensorOutput { + pub fn new(state: &mut GA::State) -> TensorOutput { TensorOutput:: { state } } @@ -260,64 +259,67 @@ pub struct TensorInputs { } #[cube] -impl MatmulArgs for TensorArgs { - type Output = Tensor>; - type Input = TensorInputs; - type State = ( +impl MatmulArgs for TensorArgs { + type Output = Tensor>; + type Input = TensorInputs; + type State = ( *const Tensor>, *const Tensor>, *mut Tensor>, ); - fn init_state(input: &Self::Input, output: &mut Self::Output) -> Self::State { + fn init_state( + input: &Self::Input, + output: &mut Self::Output, + ) -> Self::State { (&input.lhs, &input.rhs, output) } - fn read_lhs(state: &Self::State, coordinate: u32) -> Line { + fn read_lhs(state: &Self::State, coordinate: u32) -> Line { unsafe { (*state.0)[coordinate] } } - fn read_rhs(state: &Self::State, coordinate: u32) -> Line { + fn read_rhs(state: &Self::State, coordinate: u32) -> Line { unsafe { (*state.1)[coordinate] } } - fn shape_lhs(state: &Self::State, dim: u32) -> u32 { + fn shape_lhs(state: &Self::State, dim: u32) -> u32 { unsafe { (*state.0).shape(dim) } } - fn shape_rhs(state: &Self::State, dim: u32) -> u32 { + fn shape_rhs(state: &Self::State, dim: u32) -> u32 { unsafe { (*state.1).shape(dim) } } - fn shape_out(state: &Self::State, dim: u32) -> u32 { + fn shape_out(state: &Self::State, dim: u32) -> u32 { unsafe { (*state.2).shape(dim) } } - fn stride_lhs(state: &Self::State, dim: u32) -> u32 { + fn stride_lhs(state: &Self::State, dim: u32) -> u32 { unsafe { (*state.0).stride(dim) } } - fn stride_rhs(state: &Self::State, dim: u32) -> u32 { + fn stride_rhs(state: &Self::State, dim: u32) -> u32 { unsafe { (*state.1).stride(dim) } } - fn stride_out(state: &Self::State, dim: u32) -> u32 { + fn stride_out(state: &Self::State, dim: u32) -> u32 { unsafe { (*state.2).stride(dim) } } - fn write_out(state: &mut Self::State, coordinate: u32, value: Line) { + fn write_out(state: &mut Self::State, coordinate: u32, value: Line) { unsafe { (*state.2)[coordinate] = value } } - fn rank_lhs(state: &Self::State) -> u32 { + fn rank_lhs(state: &Self::State) -> u32 { unsafe { (*state.0).rank() } } - fn rank_rhs(state: &Self::State) -> u32 { + fn rank_rhs(state: &Self::State) -> u32 { unsafe { (*state.1).rank() } } - fn rank_out(state: &Self::State) -> u32 { + fn rank_out(state: &Self::State) -> u32 { unsafe { (*state.2).rank() } } } @@ -325,11 +327,11 @@ impl MatmulArgs for TensorArgs { mod __input { use super::*; - impl> CubeType for TensorInput { + impl CubeType for TensorInput { type ExpandType = TensorInputExpand; } - impl> Clone for TensorInputExpand { + impl Clone for TensorInputExpand { fn clone(&self) -> Self { Self { state: self.state.clone(), @@ -338,20 +340,20 @@ mod __input { } } - impl> Init for TensorInputExpand { + impl Init for TensorInputExpand { fn init(mut self, context: &mut CubeContext) -> Self { self.state = self.state.init(context); self } } - impl> Clone for TensorInput { + impl Clone for TensorInput { fn clone(&self) -> Self { *self } } - impl> Copy for TensorInput {} + impl Copy for TensorInput {} - impl> IntoRuntime for TensorInput { + impl IntoRuntime for TensorInput { fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType { panic!("Can't exist at compile time"); } @@ -361,17 +363,17 @@ mod __input { mod __output { use super::*; - impl> CubeType for TensorOutput { + impl CubeType for TensorOutput { type ExpandType = TensorOutputExpand; } - impl> Clone for TensorOutput { + impl Clone for TensorOutput { fn 
clone(&self) -> Self { *self } } - impl> Clone for TensorOutputExpand { + impl Clone for TensorOutputExpand { fn clone(&self) -> Self { Self { state: self.state.clone(), @@ -379,16 +381,16 @@ mod __output { } } - impl> Init for TensorOutputExpand { + impl Init for TensorOutputExpand { fn init(mut self, context: &mut CubeContext) -> Self { self.state = self.state.init(context); self } } - impl> Copy for TensorOutput {} + impl Copy for TensorOutput {} - impl> IntoRuntime for TensorOutput { + impl IntoRuntime for TensorOutput { fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType { panic!("Can't exist at compile time"); } diff --git a/crates/cubecl-linalg/src/matmul/components/global/base.rs b/crates/cubecl-linalg/src/matmul/components/global/base.rs index b687599ae..d2fe97cec 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/base.rs @@ -2,12 +2,19 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; use crate::matmul::components::stage::{self, StageWriter, TilingOrderConfig}; -use crate::matmul::components::StageDim; use crate::matmul::components::{config::MatmulConfig, tile}; use crate::matmul::components::{Ident, MatrixLayout}; -use crate::matmul::components::{MatmulKernel, MatmulSpec}; +use crate::matmul::components::{InvalidConfigError, MatmulConfigFactory}; +use crate::matmul::components::{MatmulPrecision, StageDim}; use crate::tensor::{ReadWrite, VirtualTensor}; +/// A family of [matmuls](GlobalMatmul) working with any [precision](MatmulPrecision). +pub trait GlobalMatmulFamily: + MatmulConfigFactory + Send + Sync + 'static +{ + type Matmul: GlobalMatmul; +} + #[cube] /// Provides matrix multiplication operations at the global level. /// @@ -27,11 +34,12 @@ use crate::tensor::{ReadWrite, VirtualTensor}; /// It is not assumed that the matmul's dimensions match its inputs' dimensions perfectly. /// It is therefore important that Loaders and Unloaders perform checks to avoid out-of-bounds /// before loading data. -pub trait Matmul: 'static + Send + Sync + MatmulKernel { - type LhsLoader: Loader; - type RhsLoader: Loader; +pub trait GlobalMatmul: 'static + Send + Sync { + type Config: GlobalConfig; + type LhsLoader: InputLoader; + type RhsLoader: InputLoader; type AccumulatorLoader: CubeType; - type Out: Unloader; + type Out: OutputLoader; type Accumulator: CubeType; /// Performs the matrix multiplication over data loaded by the @@ -51,7 +59,7 @@ pub trait Matmul: 'static + Send + Sync + MatmulKernel, + lhs: VirtualTensor, m_offset: u32, k_offset: u32, batch_offset: u32, @@ -60,7 +68,7 @@ pub trait Matmul: 'static + Send + Sync + MatmulKernel, + rhs: VirtualTensor, k_offset: u32, n_offset: u32, batch_offset: u32, @@ -69,7 +77,7 @@ pub trait Matmul: 'static + Send + Sync + MatmulKernel, + out: VirtualTensor, m_offset: u32, n_offset: u32, batch_offset: u32, @@ -85,7 +93,9 @@ pub trait Matmul: 'static + Send + Sync + MatmulKernel: CubeType + 'static + Send + Sync { +pub trait InputLoader: + CubeType + 'static + Send + Sync +{ /// The stage reader which matches the input of the underlying stage matmul. type StageReader: CubeType; @@ -102,14 +112,14 @@ pub trait Loader: CubeType + 'static + Send #[cube] /// Input to the global matmul accumulator, responsible for filling the stage and providing a reader /// for it.
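// The no-op implementation from accumulator_loader.rs (patched above) is the
// simplest reference point: `ZeroAccumulatorLoader` skips the fill entirely
// and only zeroes the tile accumulator on `load`:
//
//     fn fill_stage(_this: &mut Self, #[comptime] _config: G) {}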
-pub trait AccumulatorLoader: +pub trait AccumulatorLoader: CubeType + 'static + Send + Sync { fn fill_stage(this: &mut Self, #[comptime] config: G); /// Load accumulator for `tile_n`. Should call either `zero_accumulator` or `fill_accumulator` /// for the underlying tile. - fn load>( + fn load>( this: &mut Self, acc: &mut Tile::Accumulator, tile_n: u32, @@ -124,16 +134,20 @@ pub trait AccumulatorLoader: /// /// It is only a wrapper over the stage writer because there is no K for the output. /// Could be deleted in favor of having only the StageWriter -pub trait Unloader: CubeType + 'static + Send + Sync { +pub trait OutputLoader: CubeType + 'static + Send + Sync { type StageWriter: StageWriter; - fn as_stage_writer(unloader: Self) -> Self::StageWriter; + fn as_stage_writer(unloader: Self) -> Self::StageWriter; +} + +pub trait LoadingValidation { + fn check<C: GlobalConfig>(config: &C, ident: Ident) -> Result<(), InvalidConfigError>; +} -/// Configuration for the Global matmul (GMM) level -pub trait Config: MatmulConfig { +/// Configuration for the [global matmul](GlobalMatmul) level. +pub trait GlobalConfig: MatmulConfig { /// Underlying Stage matmul config - type SmmConfig: stage::Config; + type SmmConfig: stage::StageConfig; /// Convert itself to the underlying stage matmul config fn to_smm_config(&self) -> Self::SmmConfig; diff --git a/crates/cubecl-linalg/src/matmul/components/global/buffered/buffer_loading.rs b/crates/cubecl-linalg/src/matmul/components/global/buffered/buffer_loading.rs index b58fdd682..f179bc5fe 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/buffered/buffer_loading.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/buffered/buffer_loading.rs @@ -1,7 +1,7 @@ use crate::matmul::components::config::InputIdent; -use crate::matmul::components::global; use crate::matmul::components::global::tensor_view::TensorReader; -use crate::matmul::components::Ident; +use crate::matmul::components::global::{self, GlobalConfig, LoadingValidation}; +use crate::matmul::components::{Ident, InvalidConfigError}; use cubecl_core as cubecl; use cubecl_core::prelude::*; @@ -10,9 +10,35 @@ use cubecl_core::prelude::*; /// iterating with steps determined by the plane's dimension. pub struct BufferLoading {} +impl LoadingValidation for BufferLoading { + fn check<C: GlobalConfig>(config: &C, ident: Ident) -> Result<(), InvalidConfigError> { + let stage_dim = config.stage_dim(ident); + let line_size = config.global_line_size(ident); + + let num_stage_elements = stage_dim.total_elements(); + let total_units = config.num_planes() * config.plane_dim(); + let jump_length = total_units * line_size; + + if num_stage_elements % jump_length != 0 { + return Err(Box::new( + "Too much data will be loaded, resulting in out of bounds. 
+ Try setting line size and number of planes so that jump_length divides num_stage_elements.", + )); + } + + if config.transpose_load(ident) { + return Err(Box::new( + "Transpose load not yet supported in buffered setup", + )); + } + + Ok(()) + } +} + #[cube] impl BufferLoading { - pub fn load_to_slice( + pub fn load_to_slice( read_view: &TensorReader, buffer_slice: &mut SliceMut>, #[comptime] num_producer_planes: u32, @@ -29,9 +55,6 @@ impl BufferLoading { let jump_length = comptime!(total_units * line_size); let num_loads_per_unit = num_buffer_elements / jump_length; - #[allow(clippy::all)] - let _ = comptime!(check_jump_divides_well(num_buffer_elements, jump_length)); - let plane_id = if comptime!(producer_plane_offset > 0) { UNIT_POS_Y - producer_plane_offset } else { @@ -53,15 +76,7 @@ impl BufferLoading { let line_read = read_view.load_coalesced::(tile_x, tile_y, pos_within_tile, ident, config); - match config.transpose_load(ident) { - false => { - buffer_slice[unit_position / line_size] = Line::cast_from(line_read); - } - true => { - #[allow(clippy::all)] - let _ = comptime!(unsupported_transpose_load()); - } - } + buffer_slice[unit_position / line_size] = Line::cast_from(line_read); } } } @@ -79,15 +94,3 @@ fn get_tiles_x_y(nth_buffer_tile: u32, #[comptime] ident: Ident) -> (u32, u32) { } } } - -fn unsupported_transpose_load() { - panic!("Transpose load not yet supported in buffered setup") -} - -fn check_jump_divides_well(num_stage_elements: u32, jump_length: u32) { - assert!( - num_stage_elements % jump_length == 0, - "Too many data will be loaded, resulting in out of bounds. - Try setting line size and number of planes so that jump_length divides num_stage_elements." - ); -} diff --git a/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/base.rs b/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/base.rs index 3decfa92d..309fedb52 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/base.rs @@ -1,15 +1,9 @@ -use crate::matmul::components::global::unloader::Unloader; -use crate::matmul::components::global::{Config as _, Loader}; +use crate::matmul::components::global::output_loader::Unloader; +use crate::matmul::components::global::{self, CommonGlobalConfig, InputLoader}; +use crate::matmul::components::global::{GlobalConfig, ZeroAccumulatorLoader}; use crate::matmul::components::stage::single_buffer::{LhsBufferReader, RhsBufferReader}; -use crate::matmul::components::stage::TilingOrderConfig; -use crate::matmul::components::MatmulKernel; -use crate::matmul::components::StageDim; -use crate::matmul::components::{config::MatmulConfig, global::ZeroAccumulatorLoader}; -use crate::matmul::components::{global, MatmulProblem}; -use crate::matmul::components::{stage, MatmulSpec}; -use crate::matmul::components::{Ident, MatrixLayout}; -use crate::matmul::kernels::matmul::AdvancedConfig; -use crate::matmul::kernels::MatmulAvailabilityError; +use crate::matmul::components::Ident; +use crate::matmul::components::{stage, MatmulPrecision}; use crate::tensor::{ReadWrite, VirtualTensor}; use cubecl_core as cubecl; @@ -21,26 +15,27 @@ use super::loader::{LhsBufferLoader, RhsBufferLoader}; /// Performs matrix multiplication at the global level, with planes pipelining their work using two buffers: /// While they trigger a load event from global memory to shared memory on buffer A, /// they trigger a computation event from tensor cores on buffer B. 
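/// A schematic of one such iteration (hypothetical pseudocode; the concrete
/// orchestration lives in `execute` below):
///
///     Self::LhsLoader::fill_stage(&mut lhs_loader, config); // load into buffer A
///     SMM::execute(..., acc, ...);                          // compute on buffer B
///     sync_units();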
Then buffers are switched. -pub struct Matmul> { - _ms: PhantomData, +pub struct PipelinedMatmul> { + _ms: PhantomData, _stage_matmul: PhantomData, } #[cube] -impl global::Matmul for Matmul +impl global::GlobalMatmul for PipelinedMatmul where - SMM: stage::Matmul< - MS::ES, - MS::EG, - MS::EA, - LhsReader = LhsBufferReader, - RhsReader = RhsBufferReader, + SMM: stage::StageMatmul< + MP::ES, + MP::EG, + MP::EA, + LhsReader = LhsBufferReader, + RhsReader = RhsBufferReader, >, { - type LhsLoader = LhsBufferLoader; - type RhsLoader = RhsBufferLoader; + type Config = CommonGlobalConfig; + type LhsLoader = LhsBufferLoader; + type RhsLoader = RhsBufferLoader; type AccumulatorLoader = ZeroAccumulatorLoader; - type Out = Unloader; + type Out = Unloader; type Accumulator = SMM::Accumulator; fn execute( @@ -135,7 +130,7 @@ where } fn init_lhs_loader( - lhs: VirtualTensor, + lhs: VirtualTensor, x_offset: u32, y_offset: u32, batch_offset: u32, @@ -145,7 +140,7 @@ where } fn init_rhs_loader( - rhs: VirtualTensor, + rhs: VirtualTensor, x_offset: u32, y_offset: u32, batch_offset: u32, @@ -155,7 +150,7 @@ where } fn init_unloader( - out: VirtualTensor, + out: VirtualTensor, x_offset: u32, y_offset: u32, batch_offset: u32, @@ -171,152 +166,3 @@ where SMM::zero_accumulator(acc, config.to_smm_config()); } } - -impl MatmulKernel for Matmul -where - SMM: stage::Matmul, -{ - type Config = Config; - - fn check_config(config: Self::Config) { - assert!( - config.stage_dim(Ident::Lhs).num_tiles_y_dim() == 2, - "Pipelined matmul needs exactly 2 buffers." - ); - SMM::check_config(config.to_smm_config()); - } - - fn check_availability( - client: &ComputeClient, - ) -> Result<(), MatmulAvailabilityError> { - SMM::check_availability::(client) - } - - fn make_config( - problem: &MatmulProblem, - cube_dim: &CubeDim, - cube_count: &CubeCount, - advanced_config: &AdvancedConfig, - ) -> Self::Config { - let smm_config = SMM::make_config(problem, cube_dim, cube_count, advanced_config); - - Config::new( - smm_config, - problem.m as u32 % SMM::M != 0, - problem.n as u32 % SMM::N != 0, - problem.k as u32 % SMM::K != 0, - problem.lhs_layout, - problem.rhs_layout, - problem.lhs_line_size as u32, - problem.rhs_line_size as u32, - problem.out_line_size as u32, - cube_dim.y, - ) - } -} - -#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)] -/// Configuration for the pipelined global matmul -pub struct Config { - smm_config: S, - check_m_bounds: bool, - check_n_bounds: bool, - check_k_bounds: bool, - lhs_layout: MatrixLayout, - rhs_layout: MatrixLayout, - lhs_line_size: u32, - rhs_line_size: u32, - out_line_size: u32, - num_planes: u32, -} - -impl global::Config for Config { - type SmmConfig = S; - - fn to_smm_config(&self) -> Self::SmmConfig { - self.smm_config - } - - fn global_line_size(&self, ident: Ident) -> u32 { - match ident { - Ident::Lhs => self.lhs_line_size, - Ident::Rhs => self.rhs_line_size, - Ident::Out => self.out_line_size, - } - } - - fn stage_line_size(&self, ident: Ident) -> u32 { - self.smm_config.line_size(ident) - } - - fn stage_dim(&self, ident: Ident) -> Box { - self.smm_config.stage_dim(ident) - } - - fn layout(&self, ident: Ident) -> MatrixLayout { - match ident { - Ident::Lhs => self.lhs_layout, - Ident::Rhs => self.rhs_layout, - Ident::Out => self.smm_config.layout(Ident::Out), - } - } - - fn num_planes(&self) -> u32 { - self.num_planes - } - - fn plane_dim(&self) -> u32 { - self.smm_config.plane_dim() - } - - fn tiling_order(&self, ident: Ident) -> TilingOrderConfig { - 
self.smm_config.tiling_order(ident) - } - - fn check_m_bounds(&self) -> bool { - self.check_m_bounds - } - - fn check_n_bounds(&self) -> bool { - self.check_n_bounds - } - - fn check_k_bounds(&self) -> bool { - self.check_k_bounds - } - - fn transpose_load(&self, ident: Ident) -> bool { - self.layout(ident) != self.smm_config.layout(ident) - } -} - -impl MatmulConfig for Config {} - -impl Config { - #[allow(clippy::too_many_arguments)] - pub fn new( - smm_config: S, - check_m_bounds: bool, - check_n_bounds: bool, - check_k_bounds: bool, - lhs_layout: MatrixLayout, - rhs_layout: MatrixLayout, - lhs_line_size: u32, - rhs_line_size: u32, - out_line_size: u32, - num_planes: u32, - ) -> Self { - Self { - smm_config, - check_m_bounds, - check_n_bounds, - check_k_bounds, - lhs_layout, - rhs_layout, - lhs_line_size, - rhs_line_size, - out_line_size, - num_planes, - } - } -} diff --git a/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/family.rs b/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/family.rs new file mode 100644 index 000000000..eb6467c5b --- /dev/null +++ b/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/family.rs @@ -0,0 +1,85 @@ +use crate::matmul::components::global::buffered::buffer_loading::BufferLoading; +use crate::matmul::components::global::{ + CommonGlobalConfig, GlobalConfig, GlobalMatmulFamily, LoadingValidation, +}; +use crate::matmul::components::stage::single_buffer::{ + LhsBufferReaderFamily, RhsBufferReaderFamily, +}; +use crate::matmul::components::MatmulConfigFactory; +use crate::matmul::components::MatmulProblem; +use crate::matmul::components::{stage, MatmulPrecision}; +use crate::matmul::components::{Ident, InvalidConfigError}; +use crate::matmul::kernels::matmul::AdvancedConfig; +use crate::matmul::kernels::MatmulAvailabilityError; +use cubecl_core::prelude::*; +use std::marker::PhantomData; + +use super::loader::check_buffers_contiguous; +use super::PipelinedMatmul; + +pub struct PipelinedMatmulFamily { + _stage_matmul: PhantomData, +} + +impl GlobalMatmulFamily for PipelinedMatmulFamily +where + SMM: stage::StageMatmulFamily< + LhsReader = LhsBufferReaderFamily, + RhsReader = RhsBufferReaderFamily, + >, +{ + type Matmul = PipelinedMatmul>; +} + +impl MatmulConfigFactory for PipelinedMatmulFamily +where + SMM: stage::StageMatmulFamily, +{ + type Input = SMM::Input; + type Config = CommonGlobalConfig; + + fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> { + check_buffers_contiguous::(Ident::Lhs, config)?; + check_buffers_contiguous::(Ident::Rhs, config)?; + + BufferLoading::check::(config, Ident::Lhs)?; + BufferLoading::check::(config, Ident::Rhs)?; + + if config.stage_dim(Ident::Lhs).num_tiles_y_dim() != 2 { + return Err(Box::new("Pipelined matmul needs exactly 2 buffers.")); + } + + SMM::check_config(&config.to_smm_config()) + } + + fn check_availability( + client: &ComputeClient, + config: &Self::Config, + ) -> Result<(), MatmulAvailabilityError> { + SMM::check_availability::(client, &config.smm_config) + } + + fn make_config( + input: Self::Input, + problem: &MatmulProblem, + cube_dim: &CubeDim, + cube_count: &CubeCount, + advanced_config: &AdvancedConfig, + ) -> Self::Config { + let smm_config = SMM::make_config(input, problem, cube_dim, cube_count, advanced_config); + let size = SMM::size(&smm_config); + + CommonGlobalConfig::new( + smm_config, + problem.m as u32 % size.m != 0, + problem.n as u32 % size.n != 0, + problem.k as u32 % size.k != 0, + problem.lhs_layout, + 
problem.rhs_layout, + problem.lhs_line_size as u32, + problem.rhs_line_size as u32, + problem.out_line_size as u32, + cube_dim.y, + ) + } +} diff --git a/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/loader.rs b/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/loader.rs index 2278abe6a..5fbd2d0d4 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/loader.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/loader.rs @@ -1,21 +1,20 @@ use std::marker::PhantomData; use crate::matmul::components::config::InputIdent; -use crate::matmul::components::global::base::Config as _; +use crate::matmul::components::global::base::GlobalConfig as _; use crate::matmul::components::global::buffered::buffer_loading::BufferLoading; -use crate::matmul::components::global::buffered::pipelined; use crate::matmul::components::global::tensor_view::TensorReader; -use crate::matmul::components::global::Loader; +use crate::matmul::components::global::{CommonGlobalConfig, InputLoader}; use crate::matmul::components::stage::single_buffer::{LhsBufferReader, RhsBufferReader}; use crate::matmul::components::stage::TilingOrderConfig; use crate::matmul::components::stage::{self, Stage}; -use crate::matmul::components::{global, Ident}; +use crate::matmul::components::{global, Ident, InvalidConfigError}; use crate::tensor::VirtualTensor; use cubecl_core as cubecl; use cubecl_core::prelude::*; #[derive(CubeType)] -pub struct LhsBufferLoader { +pub struct LhsBufferLoader { pub tensor_view: TensorReader, pub stage: Stage, buffer_iter: u32, @@ -24,7 +23,7 @@ pub struct LhsBufferLoader { } #[derive(CubeType)] -pub struct RhsBufferLoader { +pub struct RhsBufferLoader { pub tensor_view: TensorReader, pub stage: Stage, buffer_iter: u32, @@ -33,12 +32,12 @@ pub struct RhsBufferLoader { } #[cube] -impl Loader> +impl InputLoader> for LhsBufferLoader { type StageReader = LhsBufferReader; - fn fill_stage(this: &mut Self, #[comptime] config: pipelined::Config) { + fn fill_stage(this: &mut Self, #[comptime] config: CommonGlobalConfig) { load_buffer::( this.buffer_iter, &this.tensor_view, @@ -62,13 +61,13 @@ impl Loader LhsBufferLoader { +impl LhsBufferLoader { pub fn new( tensor: VirtualTensor, x_offset: u32, y_offset: u32, batch_offset: u32, - #[comptime] config: pipelined::Config, + #[comptime] config: CommonGlobalConfig, ) -> Self { let stage = Stage::new::(Ident::Lhs, config.to_smm_config()); let tensor_view = TensorReader::new(tensor, x_offset, y_offset, batch_offset); @@ -84,12 +83,12 @@ impl LhsBufferLoader { } #[cube] -impl Loader> +impl InputLoader> for RhsBufferLoader { type StageReader = RhsBufferReader; - fn fill_stage(this: &mut Self, #[comptime] config: pipelined::Config) { + fn fill_stage(this: &mut Self, #[comptime] config: CommonGlobalConfig) { load_buffer::( this.buffer_iter, &this.tensor_view, @@ -113,13 +112,13 @@ impl Loader RhsBufferLoader { +impl RhsBufferLoader { pub fn new( tensor: VirtualTensor, x_offset: u32, y_offset: u32, batch_offset: u32, - #[comptime] config: pipelined::Config, + #[comptime] config: CommonGlobalConfig, ) -> Self { let stage = Stage::new::(Ident::Rhs, config.to_smm_config()); let tensor_view = TensorReader::new(tensor, x_offset, y_offset, batch_offset); @@ -135,25 +134,22 @@ impl RhsBufferLoader { } #[cube] -fn load_buffer( +fn load_buffer( buffer_iter: u32, tensor_view: &TensorReader, stage: &mut Stage, #[comptime] ident: Ident, - #[comptime] config: pipelined::Config, + #[comptime] config: 
CommonGlobalConfig, ) { let buffer_num_elements = config.stage_dim(ident).buffer_num_elements(); let line_size = config.stage_line_size(ident); let buffer_num_lines = buffer_num_elements / line_size; - #[allow(clippy::all)] - let _ = comptime!(check_buffers_contiguous(ident, config)); - let start = buffer_iter * buffer_num_lines; let end = start + buffer_num_lines; let buffer_slice = &mut stage.as_slice_mut().slice_mut(start, end); - BufferLoading::load_to_slice::>( + BufferLoading::load_to_slice::>( tensor_view, buffer_slice, config.num_planes(), @@ -163,17 +159,26 @@ fn load_buffer( ); } -fn check_buffers_contiguous(ident: Ident, config: G) { +pub fn check_buffers_contiguous( + ident: Ident, + config: &G, +) -> Result<(), InvalidConfigError> { match ident.as_input() { InputIdent::Lhs => { if let TilingOrderConfig::RowMajor = config.tiling_order(ident) { - panic!("Lhs must have ColMajor tiling order in pipelined setting") + return Err(Box::new( + "Lhs must have ColMajor tiling order in pipelined setting", + )); } } InputIdent::Rhs => { if let TilingOrderConfig::ColMajor = config.tiling_order(ident) { - panic!("Rhs must have RowMajor tiling order in pipelined setting") + return Err(Box::new( + "Rhs must have RowMajor tiling order in pipelined setting", + )); } } } + + Ok(()) } diff --git a/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/mod.rs b/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/mod.rs index 741e64de6..353aa3060 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/mod.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/mod.rs @@ -1,4 +1,6 @@ mod base; +mod family; mod loader; pub use base::*; +pub use family::*; diff --git a/crates/cubecl-linalg/src/matmul/components/global/buffered/specialized/base.rs b/crates/cubecl-linalg/src/matmul/components/global/buffered/specialized/base.rs index faa67b0ee..18807e568 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/buffered/specialized/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/buffered/specialized/base.rs @@ -1,48 +1,116 @@ -use crate::matmul::components::global::unloader::Unloader; -use crate::matmul::components::global::{Config as _, Loader}; -use crate::matmul::components::stage::single_buffer::{LhsBufferReader, RhsBufferReader}; -use crate::matmul::components::stage::TilingOrderConfig; -use crate::matmul::components::MatmulKernel; +use crate::matmul::components::global::output_loader::Unloader; +use crate::matmul::components::global::{ + GlobalConfig as _, GlobalMatmul, GlobalMatmulFamily, InputLoader, +}; +use crate::matmul::components::stage::single_buffer::{ + LhsBufferReader, LhsBufferReaderFamily, RhsBufferReader, RhsBufferReaderFamily, +}; +use crate::matmul::components::stage::{StageMatmul, TilingOrderConfig}; use crate::matmul::components::StageDim; use crate::matmul::components::{config::MatmulConfig, global::ZeroAccumulatorLoader}; use crate::matmul::components::{global, MatmulProblem}; -use crate::matmul::components::{stage, MatmulSpec}; +use crate::matmul::components::{stage, MatmulPrecision}; use crate::matmul::components::{Ident, MatrixLayout}; +use crate::matmul::components::{InvalidConfigError, MatmulConfigFactory}; use crate::matmul::kernels::matmul::AdvancedConfig; use crate::matmul::kernels::MatmulAvailabilityError; use crate::tensor::{ReadWrite, VirtualTensor}; +use super::loader::{LhsBufferLoader, RhsBufferLoader}; use cubecl_core as cubecl; use cubecl_core::prelude::*; use 
std::marker::PhantomData; -use super::loader::{LhsBufferLoader, RhsBufferLoader}; +pub struct SpecializedMatmulFamily { + _stage_matmul: PhantomData, +} + +impl GlobalMatmulFamily for SpecializedMatmulFamily +where + SMM: stage::StageMatmulFamily< + LhsReader = LhsBufferReaderFamily, + RhsReader = RhsBufferReaderFamily, + >, +{ + type Matmul = SpecializedMatmul>; +} + +impl MatmulConfigFactory for SpecializedMatmulFamily +where + SMM: stage::StageMatmulFamily, +{ + type Input = SMM::Input; + type Config = Config; + + fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> { + if config.num_producers() == 0 { + return Err(Box::new("There are no producer planes. Make sure there are more planes than the underlying stage matmul requires.")); + } + if config.stage_dim(Ident::Lhs).num_tiles_y_dim() <= 1 { + return Err(Box::new("Producer-consumer needs at least 2 buffers.")); + } + + SMM::check_config(&config.to_smm_config()) + } + + fn check_availability( + client: &ComputeClient, + config: &Self::Config, + ) -> Result<(), MatmulAvailabilityError> { + SMM::check_availability::(client, &config.smm_config) + } + + fn make_config( + input: Self::Input, + problem: &MatmulProblem, + cube_dim: &CubeDim, + cube_count: &CubeCount, + advanced_config: &AdvancedConfig, + ) -> Self::Config { + let smm_config = SMM::make_config(input, problem, cube_dim, cube_count, advanced_config); + let size = SMM::size(&smm_config); + + Config::new( + smm_config, + problem.m as u32 % size.m != 0, + problem.n as u32 % size.n != 0, + problem.k as u32 % size.k != 0, + problem.lhs_layout, + problem.rhs_layout, + problem.lhs_line_size as u32, + problem.rhs_line_size as u32, + problem.out_line_size as u32, + cube_dim.y, + ) + } +} /// Performs matrix multiplication at the global level, with planes split between two roles: /// - First n planes are used in the stage matmul computation, with n the number of planes needed by the underlying stage matmul /// - Remaining planes load data to the stage /// /// Both roles alternate the buffer (tile index in dimension k) they are working on -pub struct Matmul> { - _ms: PhantomData, +pub struct SpecializedMatmul> { + _ms: PhantomData, _stage_matmul: PhantomData, } #[cube] -impl global::Matmul for Matmul +impl global::GlobalMatmul for SpecializedMatmul where - SMM: stage::Matmul< - MS::ES, - MS::EG, - MS::EA, - LhsReader = LhsBufferReader, - RhsReader = RhsBufferReader, + SMM: StageMatmul< + MP::ES, + MP::EG, + MP::EA, + LhsReader = LhsBufferReader, + RhsReader = RhsBufferReader, >, { - type LhsLoader = LhsBufferLoader; - type RhsLoader = RhsBufferLoader; + type Config = Config; + type LhsLoader = LhsBufferLoader; + type RhsLoader = RhsBufferLoader; type AccumulatorLoader = ZeroAccumulatorLoader; - type Out = Unloader; + type Out = Unloader; type Accumulator = SMM::Accumulator; fn execute( @@ -102,7 +170,7 @@ where } fn init_lhs_loader( - lhs: VirtualTensor, + lhs: VirtualTensor, x_offset: u32, y_offset: u32, batch_offset: u32, @@ -119,7 +187,7 @@ where } fn init_rhs_loader( - rhs: VirtualTensor, + rhs: VirtualTensor, x_offset: u32, y_offset: u32, batch_offset: u32, @@ -136,7 +204,7 @@ where } fn init_unloader( - out: VirtualTensor, + out: VirtualTensor, x_offset: u32, y_offset: u32, batch_offset: u32, @@ -154,60 +222,25 @@ where } #[cube] -impl> Matmul { - fn is_consumer(#[comptime] config: ::Config) -> bool { - UNIT_POS_Y < config.num_consumers() - } -} - -impl MatmulKernel for Matmul -where - MS: MatmulSpec, - SMM: stage::Matmul, +impl< + MP: MatmulPrecision, + SMM: 
StageMatmul< + MP::ES, + MP::EG, + MP::EA, + LhsReader = LhsBufferReader, + RhsReader = RhsBufferReader, + >, + > SpecializedMatmul { - type Config = Config; - - fn check_config(config: Self::Config) { - assert!(config.num_producers() > 0, "There are no producer planes. Make sure there are more planes than the underlying stage matmul requires."); - assert!( - config.stage_dim(Ident::Lhs).num_tiles_y_dim() > 1, - "Producer-consumer needs at least 2 buffers." - ); - SMM::check_config(config.to_smm_config()); - } - - fn check_availability( - client: &ComputeClient, - ) -> Result<(), MatmulAvailabilityError> { - SMM::check_availability::(client) - } - - fn make_config( - problem: &MatmulProblem, - cube_dim: &CubeDim, - cube_count: &CubeCount, - advanced_config: &AdvancedConfig, - ) -> Self::Config { - let smm_config = SMM::make_config(problem, cube_dim, cube_count, advanced_config); - - Config::new( - smm_config, - problem.m as u32 % SMM::M != 0, - problem.n as u32 % SMM::N != 0, - problem.k as u32 % SMM::K != 0, - problem.lhs_layout, - problem.rhs_layout, - problem.lhs_line_size as u32, - problem.rhs_line_size as u32, - problem.out_line_size as u32, - cube_dim.y, - ) + fn is_consumer(#[comptime] config: >::Config) -> bool { + UNIT_POS_Y < config.num_consumers() } } #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)] /// Configuration for the producer consumer global matmul -pub struct Config { +pub struct Config { smm_config: S, check_m_bounds: bool, check_n_bounds: bool, @@ -220,7 +253,7 @@ pub struct Config { num_planes: u32, } -impl global::Config for Config { +impl global::GlobalConfig for Config { type SmmConfig = S; fn to_smm_config(&self) -> Self::SmmConfig { @@ -280,9 +313,9 @@ impl global::Config for Config { } } -impl MatmulConfig for Config {} +impl MatmulConfig for Config {} -impl Config { +impl Config { #[allow(clippy::too_many_arguments)] pub fn new( smm_config: S, diff --git a/crates/cubecl-linalg/src/matmul/components/global/buffered/specialized/loader.rs b/crates/cubecl-linalg/src/matmul/components/global/buffered/specialized/loader.rs index b6782e408..1f6c8c87f 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/buffered/specialized/loader.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/buffered/specialized/loader.rs @@ -1,11 +1,11 @@ use std::marker::PhantomData; use crate::matmul::components::config::InputIdent; -use crate::matmul::components::global::base::Config as _; +use crate::matmul::components::global::base::GlobalConfig as _; use crate::matmul::components::global::buffered::buffer_loading::BufferLoading; use crate::matmul::components::global::buffered::specialized; use crate::matmul::components::global::tensor_view::TensorReader; -use crate::matmul::components::global::Loader; +use crate::matmul::components::global::InputLoader; use crate::matmul::components::stage::single_buffer::{LhsBufferReader, RhsBufferReader}; use crate::matmul::components::stage::TilingOrderConfig; use crate::matmul::components::stage::{self, Stage}; @@ -15,7 +15,7 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; #[derive(CubeType)] -pub struct LhsBufferLoader { +pub struct LhsBufferLoader { pub tensor_view: TensorReader, pub stage: Stage, buffer_iter: u32, @@ -25,7 +25,7 @@ pub struct LhsBufferLoader { } #[derive(CubeType)] -pub struct RhsBufferLoader { +pub struct RhsBufferLoader { pub tensor_view: TensorReader, pub stage: Stage, buffer_iter: u32, @@ -35,7 +35,7 @@ pub struct RhsBufferLoader { } #[cube] -impl Loader> +impl InputLoader> for 
LhsBufferLoader { type StageReader = LhsBufferReader; @@ -66,7 +66,7 @@ impl Loader LhsBufferLoader { +impl LhsBufferLoader { pub fn new( tensor: VirtualTensor, x_offset: u32, @@ -90,7 +90,7 @@ impl LhsBufferLoader { } #[cube] -impl Loader> +impl InputLoader> for RhsBufferLoader { type StageReader = RhsBufferReader; @@ -121,7 +121,7 @@ impl Loader RhsBufferLoader { +impl RhsBufferLoader { pub fn new( tensor: VirtualTensor, x_offset: u32, @@ -145,7 +145,7 @@ impl RhsBufferLoader { } #[cube] -fn load_buffer( +fn load_buffer( buffer_iter: u32, tensor_view: &TensorReader, stage: &mut Stage, @@ -173,7 +173,7 @@ fn load_buffer( ); } -fn check_buffers_contiguous(ident: Ident, config: G) { +fn check_buffers_contiguous(ident: Ident, config: G) { match ident.as_input() { InputIdent::Lhs => { if let TilingOrderConfig::RowMajor = config.tiling_order(ident) { diff --git a/crates/cubecl-linalg/src/matmul/components/global/full_load/base.rs b/crates/cubecl-linalg/src/matmul/components/global/full_load/base.rs index 2e7ef1e1a..e0a4832e0 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/full_load/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/full_load/base.rs @@ -1,13 +1,17 @@ -use crate::matmul::components::global::unloader::Unloader; -use crate::matmul::components::global::{Config as _, Loader}; -use crate::matmul::components::stage::multi_buffer::{LhsReader, RhsReader}; -use crate::matmul::components::stage::TilingOrderConfig; -use crate::matmul::components::MatmulKernel; +use crate::matmul::components::global::output_loader::Unloader; +use crate::matmul::components::global::{ + GlobalConfig as _, GlobalMatmul, GlobalMatmulFamily, InputLoader, +}; +use crate::matmul::components::stage::multi_buffer::{ + LhsReader, LhsReaderFamily, RhsReader, RhsReaderFamily, +}; +use crate::matmul::components::stage::{StageMatmul, TilingOrderConfig}; use crate::matmul::components::StageDim; use crate::matmul::components::{config::MatmulConfig, global::ZeroAccumulatorLoader}; use crate::matmul::components::{global, MatmulProblem}; -use crate::matmul::components::{stage, MatmulSpec}; +use crate::matmul::components::{stage, InvalidConfigError}; use crate::matmul::components::{Ident, MatrixLayout}; +use crate::matmul::components::{MatmulConfigFactory, MatmulPrecision}; use crate::matmul::kernels::matmul::AdvancedConfig; use crate::matmul::kernels::MatmulAvailabilityError; use crate::tensor::{ReadWrite, VirtualTensor}; @@ -18,38 +22,106 @@ use std::marker::PhantomData; use super::loader::{LhsLoader, LoadingStrategy, RhsLoader}; +pub struct FullLoadMatmulFamily< + SMM: stage::StageMatmulFamily, + LL: LoadingStrategy, + RL: LoadingStrategy, +> { + _stage_matmul: PhantomData, + _lhs_loading: PhantomData, + _rhs_loading: PhantomData, +} + +impl GlobalMatmulFamily for FullLoadMatmulFamily +where + SMM: stage::StageMatmulFamily, + LL: LoadingStrategy, + RL: LoadingStrategy, +{ + type Matmul = + FullLoadMatmul, LL, RL>; +} + +impl MatmulConfigFactory for FullLoadMatmulFamily +where + SMM: stage::StageMatmulFamily, + LL: LoadingStrategy, + RL: LoadingStrategy, +{ + type Input = SMM::Input; + type Config = Config; + + fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> { + LL::check(config, Ident::Lhs)?; + RL::check(config, Ident::Rhs)?; + SMM::check_config(&config.to_smm_config()) + } + + fn check_availability( + client: &ComputeClient, + config: &Self::Config, + ) -> Result<(), MatmulAvailabilityError> { + SMM::check_availability::(client, &config.smm_config) + } + + fn 
make_config( + input: Self::Input, + problem: &MatmulProblem, + cube_dim: &CubeDim, + cube_count: &CubeCount, + advanced_config: &AdvancedConfig, + ) -> Self::Config { + let smm_config = SMM::make_config(input, problem, cube_dim, cube_count, advanced_config); + let size = SMM::size(&smm_config); + + Config::new( + smm_config, + problem.m as u32 % size.m != 0, + problem.n as u32 % size.n != 0, + problem.k as u32 % size.k != 0, + problem.lhs_layout, + problem.rhs_layout, + problem.lhs_line_size as u32, + problem.rhs_line_size as u32, + problem.out_line_size as u32, + size.k, + ) + } +} + /// Performs matrix multiplication at the global level, with each plane sharing the same responsibilities /// - All planes load data to the stage /// - All planes are used in the stage matmul computation -pub struct Matmul< - MS: MatmulSpec, - SMM: stage::Matmul, +pub struct FullLoadMatmul< + MP: MatmulPrecision, + SMM: StageMatmul, LL: LoadingStrategy, RL: LoadingStrategy, > { - _ms: PhantomData, + _ms: PhantomData, _stage_matmul: PhantomData, _lhs_loading: PhantomData, _rhs_loading: PhantomData, } #[cube] -impl global::Matmul for Matmul +impl GlobalMatmul for FullLoadMatmul where - SMM: stage::Matmul< - MS::ES, - MS::EG, - MS::EA, - LhsReader = LhsReader, - RhsReader = RhsReader, + SMM: StageMatmul< + MP::ES, + MP::EG, + MP::EA, + LhsReader = LhsReader, + RhsReader = RhsReader, >, LL: LoadingStrategy, RL: LoadingStrategy, { - type LhsLoader = LhsLoader; - type RhsLoader = RhsLoader; + type Config = Config; + type LhsLoader = LhsLoader; + type RhsLoader = RhsLoader; type AccumulatorLoader = ZeroAccumulatorLoader; - type Out = Unloader; + type Out = Unloader; type Accumulator = SMM::Accumulator; fn execute( @@ -60,7 +132,7 @@ where k_range: (u32, u32), #[comptime] config: Self::Config, ) { - let k_step = SMM::K; + let k_step = config.k_step; let range = k_range.1 - k_range.0; let num_loops = (range + k_step - 1) / k_step; @@ -101,7 +173,7 @@ where } fn init_lhs_loader( - lhs: VirtualTensor, + lhs: VirtualTensor, x_offset: u32, y_offset: u32, batch_offset: u32, @@ -111,7 +183,7 @@ where } fn init_rhs_loader( - rhs: VirtualTensor, + rhs: VirtualTensor, x_offset: u32, y_offset: u32, batch_offset: u32, @@ -121,7 +193,7 @@ where } fn init_unloader( - out: VirtualTensor, + out: VirtualTensor, x_offset: u32, y_offset: u32, batch_offset: u32, @@ -138,49 +210,9 @@ where } } -impl MatmulKernel for Matmul -where - SMM: stage::Matmul, - LL: LoadingStrategy, - RL: LoadingStrategy, -{ - type Config = Config; - - fn check_config(config: Self::Config) { - SMM::check_config(config.to_smm_config()); - } - - fn check_availability( - client: &ComputeClient, - ) -> Result<(), MatmulAvailabilityError> { - SMM::check_availability::(client) - } - - fn make_config( - problem: &MatmulProblem, - cube_dim: &CubeDim, - cube_count: &CubeCount, - advanced_config: &AdvancedConfig, - ) -> Self::Config { - let smm_config = SMM::make_config(problem, cube_dim, cube_count, advanced_config); - - Config::new( - smm_config, - problem.m as u32 % SMM::M != 0, - problem.n as u32 % SMM::N != 0, - problem.k as u32 % SMM::K != 0, - problem.lhs_layout, - problem.rhs_layout, - problem.lhs_line_size as u32, - problem.rhs_line_size as u32, - problem.out_line_size as u32, - ) - } -} - #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)] /// Configuration for the full load matmul -pub struct Config { +pub struct Config { smm_config: S, check_m_bounds: bool, check_n_bounds: bool, @@ -190,9 +222,10 @@ pub struct Config { lhs_line_size: u32, 
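    // Editor's note (illustrative arithmetic, not from this patch): the k_step
    // field added below drives execute()'s ceiling division,
    // num_loops = (range + k_step - 1) / k_step. For a k range of 200 with
    // k_step = 64 that is (200 + 63) / 64 = 4 loops; the final, partial
    // iteration is why make_config sets check_k_bounds whenever
    // problem.k % size.k != 0.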
     rhs_line_size: u32,
     out_line_size: u32,
+    pub k_step: u32,
 }
 
-impl global::Config for Config {
+impl global::GlobalConfig for Config {
     type SmmConfig = S;
 
     fn to_smm_config(&self) -> Self::SmmConfig {
@@ -252,9 +285,9 @@ impl global::Config for Config {
     }
 }
 
-impl MatmulConfig for Config {}
+impl MatmulConfig for Config {}
 
-impl Config {
+impl Config {
     #[allow(clippy::too_many_arguments)]
     pub fn new(
         smm_config: S,
@@ -266,6 +299,7 @@ impl Config {
         lhs_line_size: u32,
         rhs_line_size: u32,
         out_line_size: u32,
+        k_step: u32,
     ) -> Self {
         Self {
             smm_config,
@@ -277,6 +311,7 @@ impl Config {
             lhs_line_size,
             rhs_line_size,
             out_line_size,
+            k_step,
         }
     }
 }
diff --git a/crates/cubecl-linalg/src/matmul/components/global/full_load/cyclic_loading.rs b/crates/cubecl-linalg/src/matmul/components/global/full_load/cyclic_loading.rs
index 418c6b31c..84c16ab5d 100644
--- a/crates/cubecl-linalg/src/matmul/components/global/full_load/cyclic_loading.rs
+++ b/crates/cubecl-linalg/src/matmul/components/global/full_load/cyclic_loading.rs
@@ -1,9 +1,9 @@
 use crate::matmul::components::global::tensor_view::TensorReader;
-use crate::matmul::components::global::Config;
+use crate::matmul::components::global::{GlobalConfig, LoadingValidation};
 use crate::matmul::components::stage::{
     ColMajorTiling, RowMajorTiling, TilingOrder, TilingOrderConfig,
 };
-use crate::matmul::components::{Ident, MatrixLayout};
+use crate::matmul::components::{Ident, InvalidConfigError, MatrixLayout};
 use cubecl_core as cubecl;
 use cubecl_core::prelude::*;
 
@@ -14,9 +14,35 @@ use super::loader::LoadingStrategy;
 /// iterating with steps determined by the plane's dimension.
 pub struct CyclicLoading {}
 
+impl LoadingValidation for CyclicLoading {
+    fn check(config: &C, ident: Ident) -> Result<(), InvalidConfigError> {
+        let stage_dim = config.stage_dim(ident);
+        let line_size = config.global_line_size(ident);
+
+        let num_stage_elements = stage_dim.total_elements();
+        let total_units = config.num_planes() * config.plane_dim();
+        let jump_length = total_units * line_size;
+
+        if num_stage_elements % jump_length != 0 {
+            return Err(Box::new(
+                "Too much data would be loaded, resulting in out-of-bounds accesses.
+            Try setting the line size and number of planes so that jump_length divides num_stage_elements.",
+            ));
+        }
+
+        if config.transpose_load(ident) && config.stage_line_size(ident) != 1 {
+            return Err(Box::new(
+                "A stage line size other than 1 is not supported when transposing",
+            ));
+        }
+
+        Ok(())
+    }
+}
+
 #[cube]
 impl LoadingStrategy for CyclicLoading {
-    fn load_to_slice(
+    fn load_to_slice(
         read_view: &TensorReader,
         slice: &mut SliceMut>,
         #[comptime] ident: Ident,
@@ -30,9 +56,6 @@ impl LoadingStrategy for CyclicLoading {
         let jump_length = comptime!(total_units * line_size);
         let num_loads_per_unit = num_stage_elements / jump_length;
 
-        #[allow(clippy::all)]
-        let _ = comptime!(check_jump_divides_well(num_stage_elements, jump_length));
-
         let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;
         let unit_position_base = unit_id * line_size;
 
@@ -64,55 +87,33 @@ impl LoadingStrategy for CyclicLoading {
                     slice[unit_position / line_size] = Line::cast_from(line_read);
                 }
                 true => {
-                    let slice_line_size = config.stage_line_size(ident);
-
-                    if comptime!(slice_line_size == 1) {
-                        let tile_offset = nth_tile * tile_num_elements;
-
-                        let tile_size_x = config.stage_dim(ident).tile_size_x_dim();
-                        let tile_size_y = config.stage_dim(ident).tile_size_y_dim();
-
-                        let (height, width) = match config.layout(ident) {
-                            MatrixLayout::RowMajor => (tile_size_x, tile_size_y),
-                            MatrixLayout::ColMajor => (tile_size_y, tile_size_x),
-                        };
-
-                        let global_strided_idx = pos_within_tile / width;
-                        let global_contiguous_idx = pos_within_tile % width;
-
-                        let slice_strided_root = global_contiguous_idx;
-                        let slice_contiguous_idx = global_strided_idx;
-                        let slice_stride = height;
-
-                        #[unroll]
-                        for iter in 0..config.global_line_size(ident) {
-                            let slice_strided_idx = slice_strided_root + iter;
-                            let elem = line_read[iter];
-                            slice[tile_offset
-                                + slice_strided_idx * slice_stride
-                                + slice_contiguous_idx] = Line::cast_from(elem);
-                        }
-                    } else {
-                        #[allow(clippy::all)]
-                        let _ = comptime!(unsupported_line_size(slice_line_size));
+                    let tile_offset = nth_tile * tile_num_elements;
+
+                    let tile_size_x = config.stage_dim(ident).tile_size_x_dim();
+                    let tile_size_y = config.stage_dim(ident).tile_size_y_dim();
+
+                    let (height, width) = match config.layout(ident) {
+                        MatrixLayout::RowMajor => (tile_size_x, tile_size_y),
+                        MatrixLayout::ColMajor => (tile_size_y, tile_size_x),
+                    };
+
+                    let global_strided_idx = pos_within_tile / width;
+                    let global_contiguous_idx = pos_within_tile % width;
+
+                    let slice_strided_root = global_contiguous_idx;
+                    let slice_contiguous_idx = global_strided_idx;
+                    let slice_stride = height;
+
+                    #[unroll]
+                    for iter in 0..config.global_line_size(ident) {
+                        let slice_strided_idx = slice_strided_root + iter;
+                        let elem = line_read[iter];
+                        slice[tile_offset
+                            + slice_strided_idx * slice_stride
+                            + slice_contiguous_idx] = Line::cast_from(elem);
                     }
                 }
             }
         }
     }
 }
-
-fn unsupported_line_size(line_size: u32) {
-    panic!(
-        "Line size for stage is not supported when transposing. Got {:?}.",
-        line_size
-    )
-}
-
-fn check_jump_divides_well(num_stage_elements: u32, jump_length: u32) {
-    assert!(
-        num_stage_elements % jump_length == 0,
-        "Too many data will be loaded, resulting in out of bounds.
-    Try setting line size and number of planes so that jump_length divides num_stage_elements."
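// Editor's note: worked numbers for the divisibility check above (assumed
// values, not taken from this patch): with num_planes = 8, plane_dim = 32 and
// line_size = 4, total_units = 8 * 32 = 256 and jump_length = 256 * 4 = 1024.
// A 128 x 64 stage holds 8192 elements; 8192 % 1024 == 0, so each unit performs
// exactly 8192 / 1024 = 8 coalesced loads and none reads past the stage.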
- ); -} diff --git a/crates/cubecl-linalg/src/matmul/components/global/full_load/loader.rs b/crates/cubecl-linalg/src/matmul/components/global/full_load/loader.rs index bc0476a28..4bc489a20 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/full_load/loader.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/full_load/loader.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use crate::matmul::components::global::full_load; use crate::matmul::components::global::tensor_view::TensorReader; -use crate::matmul::components::global::Loader; +use crate::matmul::components::global::{InputLoader, LoadingValidation}; use crate::matmul::components::stage::multi_buffer::{LhsReader, RhsReader}; use crate::matmul::components::stage::{self, Stage}; use crate::matmul::components::{global, Ident}; @@ -11,7 +11,7 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; #[derive(CubeType)] -pub struct LhsLoader { +pub struct LhsLoader { pub tensor_view: TensorReader, pub stage: Stage, _config: PhantomData, @@ -19,7 +19,7 @@ pub struct LhsLoader { +pub struct RhsLoader { pub tensor_view: TensorReader, pub stage: Stage, _config: PhantomData, @@ -27,8 +27,8 @@ pub struct RhsLoader - Loader> for LhsLoader +impl + InputLoader> for LhsLoader { type StageReader = LhsReader; @@ -51,8 +51,8 @@ impl } #[cube] -impl LhsLoader { - pub fn new( +impl LhsLoader { + pub fn new( tensor: VirtualTensor, x_offset: u32, y_offset: u32, @@ -72,8 +72,8 @@ impl LhsLoader - Loader> for RhsLoader +impl + InputLoader> for RhsLoader { type StageReader = RhsReader; @@ -96,8 +96,8 @@ impl } #[cube] -impl RhsLoader { - pub fn new( +impl RhsLoader { + pub fn new( tensor: VirtualTensor, x_offset: u32, y_offset: u32, @@ -117,8 +117,8 @@ impl RhsLoader( +pub trait LoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation { + fn load_to_slice( read_view: &TensorReader, slice: &mut SliceMut>, #[comptime] ident: Ident, diff --git a/crates/cubecl-linalg/src/matmul/components/global/full_load/tilewise_loading.rs b/crates/cubecl-linalg/src/matmul/components/global/full_load/tilewise_loading.rs index 9d7f7530a..f270de32e 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/full_load/tilewise_loading.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/full_load/tilewise_loading.rs @@ -1,9 +1,9 @@ use crate::matmul::components::global::tensor_view::TensorReader; -use crate::matmul::components::global::Config; +use crate::matmul::components::global::{GlobalConfig, LoadingValidation}; use crate::matmul::components::stage::{ ColMajorTiling, RowMajorTiling, TilingOrder, TilingOrderConfig, }; -use crate::matmul::components::Ident; +use crate::matmul::components::{FormattedConfigError, Ident, InvalidConfigError}; use cubecl_core as cubecl; use cubecl_core::prelude::*; @@ -14,9 +14,42 @@ use super::loader::LoadingStrategy; /// one plane per tile. 
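// Editor's note: sizing example for the plane-per-tile contract validated
// below (assumed numbers, not from this patch): a stage of 4 x 2 tiles has 8
// tiles, so the kernel must be launched with num_planes == 8; for a 16 x 16
// tile with line_size = 4, each plane then copies 256 / 4 = 64 lines into its
// assigned tile.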
pub struct TilewiseLoading {} +impl LoadingValidation for TilewiseLoading { + fn check(config: &C, ident: Ident) -> Result<(), InvalidConfigError> { + let stage_dim = config.stage_dim(ident); + let line_size = config.global_line_size(ident); + + let num_planes = config.num_planes(); + let num_tiles = stage_dim.num_tiles(); + + if num_planes != num_tiles { + return Err(FormattedConfigError::new(move || { + format!( + "Number of planes {:?} must equal number of tiles {:?} for tilewise loading.", + num_planes, num_tiles, + ) + })); + } + + if line_size != config.stage_line_size(ident) { + return Err(Box::new( + "Global and stage line sizes must match for tilewise loading.", + )); + } + + if config.transpose_load(ident) { + return Err(Box::new( + "Transpose load not yet supported in tilewise loading setup", + )); + } + + Ok(()) + } +} + #[cube] impl LoadingStrategy for TilewiseLoading { - fn load_to_slice( + fn load_to_slice( read_view: &TensorReader, slice: &mut SliceMut>, #[comptime] ident: Ident, @@ -25,12 +58,6 @@ impl LoadingStrategy for TilewiseLoading { let stage_dim = config.stage_dim(ident); let line_size = config.global_line_size(ident); - #[allow(clippy::all)] - let _ = comptime! { - check_num_planes(config.num_planes(), stage_dim.num_tiles()); - check_line_sizes(line_size, config.stage_line_size(ident)) - }; - let num_lines_per_tile = comptime!(stage_dim.tile_num_elements() / line_size); let nth_tile = UNIT_POS_Y; @@ -63,34 +90,7 @@ impl LoadingStrategy for TilewiseLoading { ); let offset = offset_base + pos_within_tile; - - match config.transpose_load(ident) { - false => slice[offset] = Line::cast_from(line_read), - true => { - #[allow(clippy::all)] - let _ = comptime!(unsupported_transpose_load()); - } - } + slice[offset] = Line::cast_from(line_read); } } } - -fn check_num_planes(num_planes: u32, num_tiles: u32) { - assert!( - num_planes == num_tiles, - "Number of planes {:?} must equal number of tiles {:?} for tilewise loading.", - num_planes, - num_tiles - ); -} - -fn check_line_sizes(global_line_size: u32, stage_line_size: u32) { - assert!( - global_line_size == stage_line_size, - "Global and stage line sizes must match for tilewise loading." 
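// Editor's sketch of the error plumbing used in the check above: the crate's
// FormattedConfigError::new takes a closure so the format! call only runs if
// the error is actually displayed. Its definition lies outside this patch;
// everything below is an assumed shape inferred from the call sites, shown
// for illustration only.
//
//     use std::fmt::{self, Display};
//
//     pub struct FormattedConfigError {
//         func: Box<dyn Fn() -> String>,
//     }
//
//     impl FormattedConfigError {
//         // Boxes the closure and returns it as a displayable error object.
//         pub fn new<F: Fn() -> String + 'static>(func: F) -> Box<dyn Display> {
//             Box::new(Self { func: Box::new(func) })
//         }
//     }
//
//     impl Display for FormattedConfigError {
//         fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
//             write!(f, "{}", (self.func)())
//         }
//     }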
-    );
-}
-
-fn unsupported_transpose_load() {
-    panic!("Transpose load not yet supported in tilewise loading setup")
-}
diff --git a/crates/cubecl-linalg/src/matmul/components/global/mod.rs b/crates/cubecl-linalg/src/matmul/components/global/mod.rs
index d1926861e..1318ad26e 100644
--- a/crates/cubecl-linalg/src/matmul/components/global/mod.rs
+++ b/crates/cubecl-linalg/src/matmul/components/global/mod.rs
@@ -5,8 +5,11 @@ pub mod tensor_view;
 
 mod accumulator_loader;
 mod base;
+mod shared;
 mod tilewise_unloading;
-pub mod unloader;
+
+pub mod output_loader;
 
 pub use accumulator_loader::*;
 pub use base::*;
+pub use shared::*;
diff --git a/crates/cubecl-linalg/src/matmul/components/global/unloader.rs b/crates/cubecl-linalg/src/matmul/components/global/output_loader.rs
similarity index 86%
rename from crates/cubecl-linalg/src/matmul/components/global/unloader.rs
rename to crates/cubecl-linalg/src/matmul/components/global/output_loader.rs
index 477b3b618..3a98efb8a 100644
--- a/crates/cubecl-linalg/src/matmul/components/global/unloader.rs
+++ b/crates/cubecl-linalg/src/matmul/components/global/output_loader.rs
@@ -13,10 +13,10 @@ pub struct Unloader {
 }
 
 #[cube]
-impl global::Unloader for Unloader {
+impl global::OutputLoader for Unloader {
     type StageWriter = Self;
 
-    fn as_stage_writer(this: Self) -> Self::StageWriter {
+    fn as_stage_writer(this: Self) -> Self::StageWriter {
         this
     }
 }
@@ -37,7 +37,7 @@ impl Unloader {
 
 #[cube]
 impl StageWriter for Unloader {
-    fn write(
+    fn write(
         this: &mut Self,
         slice: Slice>,
         compute_plane_offset: u32,
diff --git a/crates/cubecl-linalg/src/matmul/components/global/shared.rs b/crates/cubecl-linalg/src/matmul/components/global/shared.rs
new file mode 100644
index 000000000..530be6cc3
--- /dev/null
+++ b/crates/cubecl-linalg/src/matmul/components/global/shared.rs
@@ -0,0 +1,113 @@
+use cubecl_core as cubecl;
+use cubecl_core::prelude::*;
+
+use crate::matmul::components::{
+    stage::{self, TilingOrderConfig},
+    Ident, MatmulConfig, MatrixLayout, StageDim,
+};
+
+#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
+/// Configuration shared by the global matmul variants
+pub struct CommonGlobalConfig {
+    pub smm_config: S,
+    pub check_m_bounds: bool,
+    pub check_n_bounds: bool,
+    pub check_k_bounds: bool,
+    pub lhs_layout: MatrixLayout,
+    pub rhs_layout: MatrixLayout,
+    pub lhs_line_size: u32,
+    pub rhs_line_size: u32,
+    pub out_line_size: u32,
+    pub num_planes: u32,
+}
+
+impl super::GlobalConfig for CommonGlobalConfig {
+    type SmmConfig = S;
+
+    fn to_smm_config(&self) -> Self::SmmConfig {
+        self.smm_config
+    }
+
+    fn global_line_size(&self, ident: Ident) -> u32 {
+        match ident {
+            Ident::Lhs => self.lhs_line_size,
+            Ident::Rhs => self.rhs_line_size,
+            Ident::Out => self.out_line_size,
+        }
+    }
+
+    fn stage_line_size(&self, ident: Ident) -> u32 {
+        self.smm_config.line_size(ident)
+    }
+
+    fn stage_dim(&self, ident: Ident) -> Box {
+        self.smm_config.stage_dim(ident)
+    }
+
+    fn layout(&self, ident: Ident) -> MatrixLayout {
+        match ident {
+            Ident::Lhs => self.lhs_layout,
+            Ident::Rhs => self.rhs_layout,
+            Ident::Out => self.smm_config.layout(Ident::Out),
+        }
+    }
+
+    fn num_planes(&self) -> u32 {
+        self.num_planes
+    }
+
+    fn plane_dim(&self) -> u32 {
+        self.smm_config.plane_dim()
+    }
+
+    fn tiling_order(&self, ident: Ident) -> TilingOrderConfig {
+        self.smm_config.tiling_order(ident)
+    }
+
+    fn check_m_bounds(&self) -> bool {
+        self.check_m_bounds
+    }
+
+    fn check_n_bounds(&self) -> bool {
+        self.check_n_bounds
+    }
+
+    fn check_k_bounds(&self) -> bool {
self.check_k_bounds + } + + fn transpose_load(&self, ident: Ident) -> bool { + self.layout(ident) != self.smm_config.layout(ident) + } +} + +impl MatmulConfig for CommonGlobalConfig {} + +impl CommonGlobalConfig { + #[allow(clippy::too_many_arguments)] + pub fn new( + smm_config: S, + check_m_bounds: bool, + check_n_bounds: bool, + check_k_bounds: bool, + lhs_layout: MatrixLayout, + rhs_layout: MatrixLayout, + lhs_line_size: u32, + rhs_line_size: u32, + out_line_size: u32, + num_planes: u32, + ) -> Self { + Self { + smm_config, + check_m_bounds, + check_n_bounds, + check_k_bounds, + lhs_layout, + rhs_layout, + lhs_line_size, + rhs_line_size, + out_line_size, + num_planes, + } + } +} diff --git a/crates/cubecl-linalg/src/matmul/components/global/tensor_view.rs b/crates/cubecl-linalg/src/matmul/components/global/tensor_view.rs index 995091a3d..41ea5cc40 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/tensor_view.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/tensor_view.rs @@ -83,7 +83,7 @@ impl TensorReader { /// # Note /// /// Out-of-bounds reads will be translated to zeros. - pub fn load_coalesced( + pub fn load_coalesced( &self, tile_x: u32, tile_y: u32, @@ -169,7 +169,7 @@ impl TensorWriter { /// Writes data into the tensor view at the specified coordinates (write_x, write_y). /// /// Each unit writes one line in a coalesced manner for improved efficiency, assuming row-major layout. - pub fn write_coalesced( + pub fn write_coalesced( &mut self, tile_x: u32, tile_y: u32, diff --git a/crates/cubecl-linalg/src/matmul/components/global/tilewise_unloading.rs b/crates/cubecl-linalg/src/matmul/components/global/tilewise_unloading.rs index e374e4129..9ab3536e0 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/tilewise_unloading.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/tilewise_unloading.rs @@ -1,5 +1,5 @@ use crate::matmul::components::global::tensor_view::TensorWriter; -use crate::matmul::components::global::Config; +use crate::matmul::components::global::GlobalConfig; use crate::matmul::components::Ident; use cubecl_core as cubecl; use cubecl_core::prelude::*; @@ -11,7 +11,7 @@ pub struct TilewiseUnloading {} #[cube] impl TilewiseUnloading { - pub fn unload_from_slice( + pub fn unload_from_slice( write_view: &mut TensorWriter, slice: Slice>, tile_x: u32, diff --git a/crates/cubecl-linalg/src/matmul/components/problem.rs b/crates/cubecl-linalg/src/matmul/components/problem.rs index 943344e81..6da9c95f6 100644 --- a/crates/cubecl-linalg/src/matmul/components/problem.rs +++ b/crates/cubecl-linalg/src/matmul/components/problem.rs @@ -38,7 +38,10 @@ impl MatmulProblem { /// /// - If dimensions of the problem are larger than allowed by the config /// - If line sizes do not divide well the dimension in which they are aligned - pub fn check_config(&self, config: &B) -> Result<(), MatmulInvalidProblem> { + pub fn check_config( + &self, + config: &B, + ) -> Result<(), MatmulInvalidProblem> { if self.m > config.max_m() as usize { return Err(MatmulInvalidProblem::ExceededMSize { m: self.m as u32, diff --git a/crates/cubecl-linalg/src/matmul/components/spec.rs b/crates/cubecl-linalg/src/matmul/components/spec.rs index 71c42149d..a76046923 100644 --- a/crates/cubecl-linalg/src/matmul/components/spec.rs +++ b/crates/cubecl-linalg/src/matmul/components/spec.rs @@ -7,9 +7,6 @@ use super::global::args::{MatmulArgs, TensorArgs}; /// Matrix multiplication spec definiting each element types used in the computation as well as /// how the arguments are 
passed to the kernel. pub trait MatmulSpec: Send + Sync + Clone + 'static { - /// The plane size used by this kernel. - const PLANE_DIM: u32; - /// Element type of each input and output tensor of the kernel. type EG: Numeric; /// Element type of the intermediate representation of the inputs. @@ -17,14 +14,30 @@ pub trait MatmulSpec: Send + Sync + Clone + 'static { /// Element type of the intermediate representation of the output accumulator. type EA: Numeric; /// How the input and output tensors are passed as arguments. - type Args: MatmulArgs; + type Args: MatmulArgs; +} + +/// Matrix multiplication precisions. +pub trait MatmulPrecision: Send + Sync + Clone + 'static { + /// Element type of each input and output tensor of the kernel. + type EG: Numeric; + /// Element type of the intermediate representation of the inputs. + type ES: Numeric; + /// Element type of the intermediate representation of the output accumulator. + type EA: Numeric; +} + +impl MatmulPrecision for (EG, ES, EA) { + type EG = EG; + type ES = ES; + type EA = EA; } /// Input argument -pub type InputArg = as MatmulArgs>>::Input; +pub type InputArg = as MatmulArgs>::Input>; /// Output argument -pub type OutputArg = as MatmulArgs>>::Output; +pub type OutputArg = as MatmulArgs>::Output>; /// Input runtime argument pub type InputRuntimeArg<'a, MS, R> = as LaunchArg>::RuntimeArg<'a, R>; @@ -37,19 +50,18 @@ type Args = ::Args; /// Specification for a simple standard matmul using global tensor as inputs. #[derive(Clone)] -pub struct SingleMatmulSpec { +pub struct SingleMatmulSpec { _eg: PhantomData, _es: PhantomData, _ea: PhantomData, + _args: PhantomData, } -impl MatmulSpec - for SingleMatmulSpec +impl MatmulSpec + for SingleMatmulSpec { - const PLANE_DIM: u32 = PLANE_DIM; - type EG = EG; type ES = ES; type EA = EA; - type Args = TensorArgs; + type Args = Args; } diff --git a/crates/cubecl-linalg/src/matmul/components/stage/base.rs b/crates/cubecl-linalg/src/matmul/components/stage/base.rs index c9411d1f1..7e94cf1b6 100644 --- a/crates/cubecl-linalg/src/matmul/components/stage/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/stage/base.rs @@ -1,13 +1,38 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::components::StageDim; +use crate::matmul::components::tile::TileConfig; use crate::matmul::components::{config::MatmulConfig, global::AccumulatorLoader}; -use crate::matmul::components::{global, tile, MatmulKernel}; +use crate::matmul::components::{global, MatmulConfigFactory}; use crate::matmul::components::{Ident, MatrixLayout}; +use crate::matmul::components::{MatmulSize, StageDim}; use super::tiling_order::TilingOrderConfig; +pub trait ReaderFamily { + type Reader: CubeType; +} + +pub trait StageMatmulFamily: + MatmulConfigFactory + Send + Sync + 'static +{ + type LhsReader: ReaderFamily; + type RhsReader: ReaderFamily; + + fn size(config: &Self::Config) -> MatmulSize; + /// Return the number of matmuls computed by the stage. + fn num(config: &Self::Config) -> MatmulSize; + + type Matmul: StageMatmul< + I, + O, + Acc, + Config = Self::Config, + LhsReader = ::Reader, + RhsReader = ::Reader, + >; +} + #[cube] /// Provides matrix multiplication operations at the stage level. /// @@ -23,15 +48,8 @@ use super::tiling_order::TilingOrderConfig; /// - Data given as inputs by stage readers must always be valid. If the actual matrix multiplication /// should be done on smaller sizes than M, N and K, padding with zeros must be done beforehand. 
/// - Enough planes are launched to perform the whole computation -pub trait Matmul: - 'static + Send + Sync + MatmulKernel -{ - /// Number of rows of LHS - const M: u32; - /// Number of columns of RHS - const N: u32; - /// Common dimension of LHS and RHS - const K: u32; +pub trait StageMatmul: 'static + Send + Sync { + type Config: StageConfig; /// Contains the matrix multiplication output, that can be shared across the different planes of the cube. /// The same Accumulator will be added to across multiple executions of the stage matmul. @@ -56,7 +74,7 @@ pub trait Matmul: fn init_tile_inputs(#[comptime] config: Self::Config) -> (Self::LhsTile, Self::RhsTile); /// Reads the result of the accumulator and hands it to the stage writer - fn read_accumulator, G: global::Config>( + fn read_accumulator, G: global::GlobalConfig>( acc: &Self::Accumulator, out: &mut Out, #[comptime] stage_config: Self::Config, @@ -83,7 +101,7 @@ pub trait Matmul: pub trait StageReader: CubeType { /// Hands a portion of data from the stage, whose location is function of the /// plane, buffer and accumulator indexes. - fn read_tile( + fn read_tile( this: &Self, compute_plane_offset: u32, buffer_offset: u32, @@ -98,7 +116,7 @@ pub trait StageReader: CubeType { pub trait StageWriter: CubeType + 'static + Send + Sync { /// Writes the given slice to global memory, at a position that depends on /// plane and accumulator indexes. - fn write( + fn write( this: &mut Self, slice: Slice>, compute_plane_offset: u32, @@ -108,9 +126,9 @@ pub trait StageWriter: CubeType + 'static + Send + Sync { } /// Configuration for the Stage matmul (SMM) level -pub trait Config: MatmulConfig { +pub trait StageConfig: MatmulConfig { /// Underlying Tile matmul config - type TmmConfig: tile::Config; + type TmmConfig: TileConfig; /// Convert itself to the underlying tile matmul config fn to_tmm_config(self) -> Self::TmmConfig; @@ -132,45 +150,6 @@ pub trait Config: MatmulConfig { /// Returns the order in which tiles should be loaded to the stage fn tiling_order(&self, ident: Ident) -> TilingOrderConfig; -} -pub trait StageSize: 'static + Send + Sync { - const NUM_M: u32; - const NUM_N: u32; - const NUM_K: u32; + fn num_stages(&self) -> &MatmulSize; } - -macro_rules! create_cmma_stage { - ($name:ident, $m:expr, $n:expr, $k:expr) => { - pub struct $name; - - impl StageSize for $name { - const NUM_M: u32 = $m; - const NUM_N: u32 = $n; - const NUM_K: u32 = $k; - } - }; -} - -// This list is not exhaustive. Add what you need. 
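// Editor's note: the type-level stage sizes removed below are superseded by
// the runtime MatmulSize carried in CommonStageInput (see the stage/shared.rs
// file added later in this patch). What used to be spelled S4x4x2 would now be
// built as a value, roughly as follows (sketch; `tile_input` is a placeholder
// name, not an identifier from this patch):
//
//     let input = CommonStageInput {
//         tile: tile_input,
//         num_stages: MatmulSize { m: 4, n: 4, k: 2 },
//     };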
-create_cmma_stage!(S1x1x1, 1, 1, 1); -create_cmma_stage!(S1x1x2, 1, 1, 2); -create_cmma_stage!(S1x1x3, 1, 1, 3); -create_cmma_stage!(S1x2x1, 1, 2, 1); -create_cmma_stage!(S1x2x2, 1, 2, 2); -create_cmma_stage!(S2x1x1, 2, 1, 1); -create_cmma_stage!(S2x1x2, 2, 1, 2); -create_cmma_stage!(S2x2x1, 2, 2, 1); -create_cmma_stage!(S2x2x2, 2, 2, 2); -create_cmma_stage!(S4x4x1, 4, 4, 1); -create_cmma_stage!(S4x4x2, 4, 4, 2); -create_cmma_stage!(S4x4x4, 4, 4, 4); -create_cmma_stage!(S4x2x4, 4, 2, 4); -create_cmma_stage!(S8x1x1, 8, 1, 1); -create_cmma_stage!(S8x2x2, 8, 2, 2); -create_cmma_stage!(S8x4x1, 8, 4, 1); -create_cmma_stage!(S8x4x2, 8, 4, 2); -create_cmma_stage!(S2x2x8, 2, 2, 8); -create_cmma_stage!(S16x4x4, 16, 4, 4); -create_cmma_stage!(S8x8x1, 8, 8, 1); -create_cmma_stage!(S8x8x2, 8, 8, 2); diff --git a/crates/cubecl-linalg/src/matmul/components/stage/mod.rs b/crates/cubecl-linalg/src/matmul/components/stage/mod.rs index a3496f1c4..5d6dceb70 100644 --- a/crates/cubecl-linalg/src/matmul/components/stage/mod.rs +++ b/crates/cubecl-linalg/src/matmul/components/stage/mod.rs @@ -2,9 +2,12 @@ pub mod multi_buffer; pub mod single_buffer; mod base; +pub(super) mod shared; mod staging; mod tiling_order; pub use base::*; pub use staging::Stage; pub use tiling_order::*; + +pub use shared::CommonStageInput; diff --git a/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/base.rs b/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/base.rs index 736dc699f..eb56ca5e0 100644 --- a/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/base.rs @@ -3,22 +3,101 @@ use std::marker::PhantomData; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::components::config::InputIdent; -use crate::matmul::components::stage::{StageSize, TilingOrderConfig}; -use crate::matmul::components::{global::AccumulatorLoader, stage::base::Matmul as _}; -use crate::matmul::components::{LhsStageDim, OutStageDim, RhsStageDim}; +use crate::matmul::components::global::AccumulatorLoader; +use crate::matmul::components::stage::shared::{ + stage_matmul_size, CommonStageConfig, CommonStageInput, +}; +use crate::matmul::components::stage::{StageMatmul, StageMatmulFamily}; +use crate::matmul::components::tile::TileMatmulFamily; +use crate::matmul::components::{ + InvalidConfigError, MatmulConfigFactory, MatmulPrecision, MatmulSize, +}; use crate::matmul::kernels::MatmulAvailabilityError; use crate::matmul::{ components::{ - config::MatmulConfig, global, - stage::{self, Config as _, StageWriter}, - tile, Ident, MatmulKernel, MatmulProblem, MatrixLayout, StageDim, + stage::{StageConfig as _, StageWriter}, + tile, Ident, MatmulProblem, }, kernels::matmul::{create_stage_dim, AdvancedConfig}, }; use super::reader::{LhsReader, RhsReader}; +use super::{LhsReaderFamily, RhsReaderFamily}; + +pub struct MultiBufferMatmulFamily { + _instruction: PhantomData, +} + +impl StageMatmulFamily for MultiBufferMatmulFamily { + fn size(config: &Self::Config) -> MatmulSize { + let tmm_config = config.to_tmm_config(); + stage_matmul_size::(&tmm_config, &config.num_stage) + } + + fn num(config: &Self::Config) -> MatmulSize { + config.num_stage + } + + type LhsReader = LhsReaderFamily; + type RhsReader = RhsReaderFamily; + type Matmul = + MultiBufferMatmul>; +} + +impl MatmulConfigFactory for MultiBufferMatmulFamily { + type Input = CommonStageInput; + type Config = CommonStageConfig; + + fn check_config(config: &Self::Config) -> Result<(), 
InvalidConfigError> { + check_num_planes( + config.stage_dim(Ident::Lhs).num_tiles_x_dim(), + config.num_planes(), + )?; + TMM::check_config(&config.to_tmm_config()) + } + + fn check_availability( + client: &ComputeClient, + config: &Self::Config, + ) -> Result<(), MatmulAvailabilityError> { + TMM::check_availability::(client, &config.tmm_config) + } + + fn make_config( + input: Self::Input, + problem: &MatmulProblem, + cube_dim: &CubeDim, + cube_count: &CubeCount, + advanced_config: &AdvancedConfig, + ) -> Self::Config { + let tile = input.tile; + let tmm_config = TMM::make_config(tile, problem, cube_dim, cube_count, advanced_config); + let tmm_size = TMM::size(&tmm_config); + let stage_size = stage_matmul_size::(&tmm_config, &input.num_stages); + + let (tile_m, tile_n, tile_k) = (tmm_size.m, tmm_size.n, tmm_size.k); + let (lhs_stage_dim, rhs_stage_dim, out_stage_dim) = create_stage_dim( + stage_size.m, + stage_size.n, + stage_size.k, + tile_m, + tile_n, + tile_k, + ); + + CommonStageConfig::new( + tmm_config, + input.num_stages, + lhs_stage_dim, + rhs_stage_dim, + out_stage_dim, + cube_dim.y, + advanced_config.lhs_tiling_order, + advanced_config.rhs_tiling_order, + ) + } +} /// Performs matrix multiplication at the stage level, where each plane is responsible for a row of tiles: /// - One plane per tile in m dimension, @@ -26,26 +105,23 @@ use super::reader::{LhsReader, RhsReader}; /// /// # Assumptions /// - There are as many planes as the stage size in m -pub struct Matmul, SS: StageSize> { +pub struct MultiBufferMatmul> { _input_precision: PhantomData, _output_precision: PhantomData, _accumulator_precision: PhantomData, _instruction: PhantomData, - _block_size: PhantomData, } #[cube] -impl stage::Matmul for Matmul +impl StageMatmul for MultiBufferMatmul where I: Numeric, O: Numeric, EA: Numeric, - TMM: tile::Matmul, - SS: StageSize, + TMM: tile::TileMatmul, { - const M: u32 = SS::NUM_M * TMM::M; - const N: u32 = SS::NUM_N * TMM::N; - const K: u32 = SS::NUM_K * TMM::K; + type Config = CommonStageConfig; + type LhsReader = LhsReader; type RhsReader = RhsReader; type Accumulator = Sequence; @@ -61,7 +137,7 @@ where #[comptime] config: Self::Config, ) { #[unroll] - for buffer_iter in 0..SS::NUM_K { + for buffer_iter in 0..config.num_stage.k { let tile_lhs = LhsReader::read_tile::(lhs_reader, UNIT_POS_Y, buffer_iter, config); TMM::fill_lhs(&tile_lhs, lhs_tile, config.to_tmm_config()); @@ -89,7 +165,7 @@ where ) } - fn read_accumulator, G: global::Config>( + fn read_accumulator, G: global::GlobalConfig>( acc: &Self::Accumulator, out: &mut SW, #[comptime] stage_config: Self::Config, @@ -124,7 +200,7 @@ where let mut acc = Sequence::::new(); #[unroll] - for _ in 0..SS::NUM_N { + for _ in 0..config.num_stage.n { acc.push(TMM::init_accumulator(config.to_tmm_config())); } @@ -133,7 +209,7 @@ where fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] config: Self::Config) { #[unroll] - for i in 0..SS::NUM_N { + for i in 0..config.num_stage.n { TMM::zero_accumulator(acc.index_mut(i), config.to_tmm_config()); } } @@ -144,141 +220,21 @@ where #[comptime] config: Self::Config, ) { #[unroll] - for i in 0..SS::NUM_N { + for i in 0..config.num_stage.n { let acc = acc.index_mut(i); L::load::(loader, acc, i, config.to_tmm_config()); } } } -impl MatmulKernel for Matmul -where - I: Numeric, - O: Numeric, - Acc: Numeric, - TMM: tile::Matmul, - SS: StageSize, -{ - type Config = Config; - - fn check_config(config: Self::Config) { - comptime!(check_num_planes( - 
config.stage_dim(Ident::Lhs).num_tiles_x_dim(),
-            config.num_planes()
-        ));
-        TMM::check_config(config.to_tmm_config());
-    }
-
-    fn check_availability(
-        client: &ComputeClient,
-    ) -> Result<(), MatmulAvailabilityError> {
-        TMM::check_availability::(client)
+fn check_num_planes(
+    expected_num_planes: u32,
+    actual_num_planes: u32,
+) -> Result<(), InvalidConfigError> {
+    if expected_num_planes != actual_num_planes {
+        return Err(Box::new(format!(
+            "Error: Expected {expected_num_planes} planes, but found {actual_num_planes}.
+        The number of planes is equal to cube dimension y which should be set to {expected_num_planes}."
+        )));
     }
 
-    fn make_config(
-        problem: &MatmulProblem,
-        cube_dim: &CubeDim,
-        cube_count: &CubeCount,
-        advanced_config: &AdvancedConfig,
-    ) -> Self::Config {
-        let tmm_config = TMM::make_config(problem, cube_dim, cube_count, advanced_config);
-
-        let (stage_m, stage_n, stage_k) = (Self::M, Self::N, Self::K);
-        let (tile_m, tile_n, tile_k) = (TMM::M, TMM::N, TMM::K);
-        let (lhs_stage_dim, rhs_stage_dim, out_stage_dim) =
-            create_stage_dim(stage_m, stage_n, stage_k, tile_m, tile_n, tile_k);
-
-        Config::new(
-            tmm_config,
-            lhs_stage_dim,
-            rhs_stage_dim,
-            out_stage_dim,
-            cube_dim.y,
-            advanced_config.lhs_tiling_order,
-            advanced_config.rhs_tiling_order,
-        )
-    }
-}
-
-fn check_num_planes(expected_num_planes: u32, actual_num_planes: u32) {
-    assert_eq!(
-        expected_num_planes, actual_num_planes,
-        "Error: Expected {expected_num_planes} planes, but found {actual_num_planes}.
-    The number of planes is equal to cube dimension y which should be set to {expected_num_planes}.",
-    );
-}
-
-#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
-/// Configuration for the multi buffer matmul
-pub struct Config {
-    tmm_config: T,
-    lhs_stage_dim: LhsStageDim,
-    rhs_stage_dim: RhsStageDim,
-    out_stage_dim: OutStageDim,
-    num_planes: u32,
-    lhs_tiling_order: TilingOrderConfig,
-    rhs_tiling_order: TilingOrderConfig,
-}
-
-impl stage::Config for Config {
-    type TmmConfig = T;
-
-    fn to_tmm_config(self) -> Self::TmmConfig {
-        self.tmm_config
-    }
-
-    fn line_size(&self, ident: Ident) -> u32 {
-        self.tmm_config.line_size(ident)
-    }
-
-    fn stage_dim(&self, ident: Ident) -> Box {
-        match ident {
-            Ident::Lhs => Box::new(self.lhs_stage_dim),
-            Ident::Rhs => Box::new(self.rhs_stage_dim),
-            Ident::Out => Box::new(self.out_stage_dim),
-        }
-    }
-
-    fn layout(&self, ident: Ident) -> MatrixLayout {
-        self.tmm_config.layout(ident)
-    }
-
-    fn num_planes(&self) -> u32 {
-        self.num_planes
-    }
-
-    fn plane_dim(&self) -> u32 {
-        self.tmm_config.plane_dim()
-    }
-
-    fn tiling_order(&self, ident: Ident) -> TilingOrderConfig {
-        match ident.as_input() {
-            InputIdent::Lhs => self.lhs_tiling_order,
-            InputIdent::Rhs => self.rhs_tiling_order,
-        }
-    }
-}
-
-impl MatmulConfig for Config {}
-
-impl Config {
-    pub fn new(
-        tmm_config: T,
-        lhs_stage_dim: LhsStageDim,
-        rhs_stage_dim: RhsStageDim,
-        out_stage_dim: OutStageDim,
-        num_planes: u32,
-        lhs_tiling_order: TilingOrderConfig,
-        rhs_tiling_order: TilingOrderConfig,
-    ) -> Self {
-        Self {
-            tmm_config,
-            lhs_stage_dim,
-            rhs_stage_dim,
-            out_stage_dim,
-            num_planes,
-            lhs_tiling_order,
-            rhs_tiling_order,
-        }
-    }
+    Ok(())
 }
diff --git a/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/reader.rs b/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/reader.rs
index 4e18df56f..8e845142e 100644
--- a/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/reader.rs
+++ b/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/reader.rs
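Editor's note on check_num_planes above: the multi-buffer stage matmul assigns
one plane per row of Lhs tiles, so cube_dim.y must equal num_tiles_x_dim(), and
the check now reports a Result instead of panicking. Because InvalidConfigError
behaves as a boxed Display elsewhere in this patch, interpolated messages must
go through format!; inside a plain string literal, "{expected_num_planes}"
would print the braces verbatim. A minimal, self-contained illustration (the
alias below is an assumption mirroring the crate's usage, not code from this
patch):

    type InvalidConfigError = Box<dyn std::fmt::Display>;

    fn check_num_planes(expected: u32, actual: u32) -> Result<(), InvalidConfigError> {
        if expected != actual {
            // format! captures `expected` and `actual` from the enclosing
            // scope; a bare &str literal would not.
            return Err(Box::new(format!(
                "Expected {expected} planes, but found {actual}."
            )));
        }
        Ok(())
    }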
@@ -1,6 +1,8 @@
-use crate::matmul::components::stage::multi_buffer;
+use crate::matmul::components::stage::shared::CommonStageConfig;
+use crate::matmul::components::stage::ReaderFamily;
 use crate::matmul::components::stage::Stage;
-use crate::matmul::components::{tile, Ident};
+use crate::matmul::components::tile::TileConfig;
+use crate::matmul::components::Ident;
 use cubecl_core as cubecl;
 use cubecl_core::prelude::*;
 
@@ -16,15 +18,26 @@ pub struct RhsReader {
     pub stage: Stage,
 }
 
+pub struct LhsReaderFamily;
+pub struct RhsReaderFamily;
+
+impl ReaderFamily for LhsReaderFamily {
+    type Reader = LhsReader;
+}
+
+impl ReaderFamily for RhsReaderFamily {
+    type Reader = RhsReader;
+}
+
 #[cube]
 impl LhsReader {
-    pub fn read_tile(
+    pub fn read_tile(
         this: &Self,
         compute_plane_offset: u32,
         buffer_offset: u32,
-        #[comptime] config: multi_buffer::Config,
+        #[comptime] config: CommonStageConfig,
     ) -> Slice> {
-        this.stage.get_tile::>(
+        this.stage.get_tile::>(
             compute_plane_offset,
             buffer_offset,
             Ident::Lhs,
@@ -35,13 +48,13 @@ impl LhsReader {
 
 #[cube]
 impl RhsReader {
-    pub fn read_tile(
+    pub fn read_tile(
         this: &Self,
         buffer_offset: u32,
         accumulator_offset: u32,
-        #[comptime] config: multi_buffer::Config,
+        #[comptime] config: CommonStageConfig,
     ) -> Slice> {
-        this.stage.get_tile::>(
+        this.stage.get_tile::>(
             buffer_offset,
             accumulator_offset,
             Ident::Rhs,
diff --git a/crates/cubecl-linalg/src/matmul/components/stage/shared.rs b/crates/cubecl-linalg/src/matmul/components/stage/shared.rs
new file mode 100644
index 000000000..2e4277394
--- /dev/null
+++ b/crates/cubecl-linalg/src/matmul/components/stage/shared.rs
@@ -0,0 +1,111 @@
+use cubecl::prelude::*;
+use cubecl_core as cubecl;
+
+use crate::matmul::components::{
+    tile::{TileConfig, TileMatmulFamily},
+    Ident, InputIdent, LhsStageDim, MatmulConfig, MatmulSize, MatrixLayout, OutStageDim,
+    RhsStageDim, StageDim,
+};
+
+use super::{StageConfig, TilingOrderConfig};
+
+pub struct CommonStageInput {
+    pub tile: TMM::Input,
+    pub num_stages: MatmulSize,
+}
+
+pub(super) fn stage_matmul_size(
+    config: &TMM::Config,
+    num_stage: &MatmulSize,
+) -> MatmulSize {
+    let size = TMM::size(config);
+
+    MatmulSize {
+        m: num_stage.m * size.m,
+        n: num_stage.n * size.n,
+        k: num_stage.k * size.k,
+    }
+}
+
+#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
+/// Configuration common to all stage matmuls
+pub struct CommonStageConfig {
+    pub tmm_config: T,
+    pub num_stage: MatmulSize,
+    pub lhs_stage_dim: LhsStageDim,
+    pub rhs_stage_dim: RhsStageDim,
+    pub out_stage_dim: OutStageDim,
+    pub num_planes: u32,
+    pub lhs_tiling_order: TilingOrderConfig,
+    pub rhs_tiling_order: TilingOrderConfig,
+}
+
+impl StageConfig for CommonStageConfig {
+    type TmmConfig = T;
+
+    fn to_tmm_config(self) -> Self::TmmConfig {
+        self.tmm_config
+    }
+
+    fn line_size(&self, ident: Ident) -> u32 {
+        self.tmm_config.line_size(ident)
+    }
+
+    fn stage_dim(&self, ident: Ident) -> Box {
+        match ident {
+            Ident::Lhs => Box::new(self.lhs_stage_dim),
+            Ident::Rhs => Box::new(self.rhs_stage_dim),
+            Ident::Out => Box::new(self.out_stage_dim),
+        }
+    }
+
+    fn layout(&self, ident: Ident) -> MatrixLayout {
+        self.tmm_config.layout(ident)
+    }
+
+    fn num_planes(&self) -> u32 {
+        self.num_planes
+    }
+
+    fn plane_dim(&self) -> u32 {
+        self.tmm_config.plane_dim()
+    }
+
+    fn tiling_order(&self, ident: Ident) -> TilingOrderConfig {
+        match ident.as_input() {
+            InputIdent::Lhs => self.lhs_tiling_order,
+            InputIdent::Rhs => self.rhs_tiling_order,
+        }
+    }
+
+    fn num_stages(&self) ->
&MatmulSize { + &self.num_stage + } +} + +impl MatmulConfig for CommonStageConfig {} + +impl CommonStageConfig { + #[allow(clippy::too_many_arguments)] + pub fn new( + tmm_config: T, + num_stage: MatmulSize, + lhs_stage_dim: LhsStageDim, + rhs_stage_dim: RhsStageDim, + out_stage_dim: OutStageDim, + num_planes: u32, + lhs_tiling_order: TilingOrderConfig, + rhs_tiling_order: TilingOrderConfig, + ) -> Self { + Self { + num_stage, + tmm_config, + lhs_stage_dim, + rhs_stage_dim, + out_stage_dim, + num_planes, + lhs_tiling_order, + rhs_tiling_order, + } + } +} diff --git a/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/base.rs b/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/base.rs index 7fa2ab100..22ba4343f 100644 --- a/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/base.rs @@ -3,22 +3,97 @@ use std::marker::PhantomData; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::components::config::InputIdent; -use crate::matmul::components::stage::base::Matmul as _; -use crate::matmul::components::stage::{StageSize, TilingOrderConfig}; -use crate::matmul::components::{LhsStageDim, OutStageDim, RhsStageDim}; +use crate::matmul::components::stage::shared::{ + stage_matmul_size, CommonStageConfig, CommonStageInput, +}; +use crate::matmul::components::stage::StageMatmulFamily; +use crate::matmul::components::tile::{TileMatmul, TileMatmulFamily}; +use crate::matmul::components::{InvalidConfigError, MatmulPrecision, MatmulSize}; use crate::matmul::kernels::MatmulAvailabilityError; use crate::matmul::{ components::{ - config::MatmulConfig, global::{self, AccumulatorLoader}, - stage::{self, Config as _, StageWriter}, - tile, Ident, MatmulKernel, MatmulProblem, MatrixLayout, StageDim, + stage::{self, StageConfig as _, StageWriter}, + Ident, MatmulConfigFactory, MatmulProblem, StageDim, }, kernels::matmul::{create_stage_dim, AdvancedConfig}, }; -use super::{LhsBufferReader, RhsBufferReader}; +use super::{LhsBufferReader, LhsBufferReaderFamily, RhsBufferReader, RhsBufferReaderFamily}; + +pub struct SingleBufferMatmulFamily { + _instruction: PhantomData, +} + +impl StageMatmulFamily for SingleBufferMatmulFamily { + fn size(config: &Self::Config) -> MatmulSize { + let tmm_config = config.to_tmm_config(); + stage_matmul_size::(&tmm_config, &config.num_stage) + } + + fn num(config: &Self::Config) -> MatmulSize { + config.num_stage + } + + type LhsReader = LhsBufferReaderFamily; + type RhsReader = RhsBufferReaderFamily; + type Matmul = + SingleBufferMatmul>; +} + +impl MatmulConfigFactory for SingleBufferMatmulFamily +where + TMM: TileMatmulFamily, +{ + type Input = CommonStageInput; + type Config = CommonStageConfig; + + fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> { + TMM::check_config(&config.to_tmm_config()) + } + + fn check_availability( + client: &ComputeClient, + config: &Self::Config, + ) -> Result<(), MatmulAvailabilityError> { + TMM::check_availability::(client, &config.tmm_config) + } + + fn make_config( + input: Self::Input, + problem: &MatmulProblem, + cube_dim: &CubeDim, + cube_count: &CubeCount, + advanced_config: &AdvancedConfig, + ) -> Self::Config { + let tile = input.tile; + + let tmm_config = TMM::make_config(tile, problem, cube_dim, cube_count, advanced_config); + let tmm_size = TMM::size(&tmm_config); + let stage_size = stage_matmul_size::(&tmm_config, &input.num_stages); + + let (tile_m, tile_n, tile_k) = (tmm_size.m, 
tmm_size.n, tmm_size.k); + let (lhs_stage_dim, rhs_stage_dim, out_stage_dim) = create_stage_dim( + stage_size.m, + stage_size.n, + stage_size.k, + tile_m, + tile_n, + tile_k, + ); + + CommonStageConfig::new( + tmm_config, + input.num_stages, + lhs_stage_dim, + rhs_stage_dim, + out_stage_dim, + lhs_stage_dim.num_tiles_x_dim(), + advanced_config.lhs_tiling_order, + advanced_config.rhs_tiling_order, + ) + } +} /// Performs matrix multiplication at the stage level, where each plane is responsible for a row of tiles: /// - One plane per tile in m dimension, @@ -28,26 +103,22 @@ use super::{LhsBufferReader, RhsBufferReader}; /// /// # Assumptions /// - There are at least as many planes as the stage size in m -pub struct Matmul, SS: StageSize> { +pub struct SingleBufferMatmul> { _input_precision: PhantomData, _output_precision: PhantomData, _accumulator_precision: PhantomData, _instruction: PhantomData, - _block_size: PhantomData, } #[cube] -impl stage::Matmul for Matmul +impl stage::StageMatmul for SingleBufferMatmul where I: Numeric, O: Numeric, EA: Numeric, - TMM: tile::Matmul, - SS: StageSize, + TMM: TileMatmul, { - const M: u32 = SS::NUM_M * TMM::M; - const N: u32 = SS::NUM_N * TMM::N; - const K: u32 = SS::NUM_K * TMM::K; + type Config = CommonStageConfig; type LhsReader = LhsBufferReader; type RhsReader = RhsBufferReader; type Accumulator = Sequence; @@ -87,7 +158,7 @@ where let mut accumulators = Sequence::::new(); #[unroll] - for _ in 0..SS::NUM_N { + for _ in 0..config.num_stage.n { accumulators.push(TMM::init_accumulator(config.to_tmm_config())); } @@ -96,7 +167,7 @@ where fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] config: Self::Config) { #[unroll] - for i in 0..SS::NUM_N { + for i in 0..config.num_stage.n { TMM::zero_accumulator(acc.index_mut(i), config.to_tmm_config()); } } @@ -107,13 +178,13 @@ where #[comptime] config: Self::Config, ) { #[unroll] - for i in 0..SS::NUM_N { + for i in 0..config.num_stage.n { let acc = acc.index_mut(i); L::load::(loader, acc, i, config.to_tmm_config()); } } - fn read_accumulator, G: global::Config>( + fn read_accumulator, G: global::GlobalConfig>( acc: &Self::Accumulator, out: &mut SW, #[comptime] stage_config: Self::Config, @@ -144,123 +215,3 @@ where } } } - -impl MatmulKernel for Matmul -where - I: Numeric, - O: Numeric, - Acc: Numeric, - TMM: tile::Matmul, - SS: StageSize, -{ - type Config = Config; - - fn check_config(config: Self::Config) { - TMM::check_config(config.to_tmm_config()); - } - - fn check_availability( - client: &ComputeClient, - ) -> Result<(), MatmulAvailabilityError> { - TMM::check_availability::(client) - } - - fn make_config( - problem: &MatmulProblem, - cube_dim: &CubeDim, - cube_count: &CubeCount, - advanced_config: &AdvancedConfig, - ) -> Self::Config { - let tmm_config = TMM::make_config(problem, cube_dim, cube_count, advanced_config); - - let (stage_m, stage_n, stage_k) = (Self::M, Self::N, Self::K); - let (tile_m, tile_n, tile_k) = (TMM::M, TMM::N, TMM::K); - let (lhs_stage_dim, rhs_stage_dim, out_stage_dim) = - create_stage_dim(stage_m, stage_n, stage_k, tile_m, tile_n, tile_k); - - Config::new( - tmm_config, - lhs_stage_dim, - rhs_stage_dim, - out_stage_dim, - lhs_stage_dim.num_tiles_x_dim(), - advanced_config.lhs_tiling_order, - advanced_config.rhs_tiling_order, - ) - } -} - -#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)] -/// Configuration for the single buffer matmul -pub struct Config { - tmm_config: T, - lhs_stage_dim: LhsStageDim, - rhs_stage_dim: RhsStageDim, - out_stage_dim: 
OutStageDim, - num_planes: u32, - lhs_tiling_order: TilingOrderConfig, - rhs_tiling_order: TilingOrderConfig, -} - -impl stage::Config for Config { - type TmmConfig = T; - - fn to_tmm_config(self) -> Self::TmmConfig { - self.tmm_config - } - - fn line_size(&self, ident: Ident) -> u32 { - self.tmm_config.line_size(ident) - } - - fn stage_dim(&self, ident: Ident) -> Box { - match ident { - Ident::Lhs => Box::new(self.lhs_stage_dim), - Ident::Rhs => Box::new(self.rhs_stage_dim), - Ident::Out => Box::new(self.out_stage_dim), - } - } - - fn layout(&self, ident: Ident) -> MatrixLayout { - self.tmm_config.layout(ident) - } - - fn num_planes(&self) -> u32 { - self.num_planes - } - - fn plane_dim(&self) -> u32 { - self.tmm_config.plane_dim() - } - - fn tiling_order(&self, ident: Ident) -> TilingOrderConfig { - match ident.as_input() { - InputIdent::Lhs => self.lhs_tiling_order, - InputIdent::Rhs => self.rhs_tiling_order, - } - } -} - -impl MatmulConfig for Config {} - -impl Config { - pub fn new( - tmm_config: T, - lhs_stage_dim: LhsStageDim, - rhs_stage_dim: RhsStageDim, - out_stage_dim: OutStageDim, - num_planes: u32, - lhs_tiling_order: TilingOrderConfig, - rhs_tiling_order: TilingOrderConfig, - ) -> Self { - Self { - tmm_config, - lhs_stage_dim, - rhs_stage_dim, - out_stage_dim, - num_planes, - lhs_tiling_order, - rhs_tiling_order, - } - } -} diff --git a/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/reader.rs b/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/reader.rs index 8796030c9..d9a58b64e 100644 --- a/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/reader.rs +++ b/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/reader.rs @@ -1,8 +1,11 @@ -use crate::matmul::components::stage::single_buffer; +use crate::matmul::components::{ + stage::{shared::CommonStageConfig, ReaderFamily}, + tile::TileConfig, +}; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::components::{stage::Stage, tile, Ident}; +use crate::matmul::components::{stage::Stage, Ident}; #[derive(CubeType)] pub struct LhsBufferReader { @@ -16,14 +19,25 @@ pub struct RhsBufferReader { pub buffer: u32, } +pub struct LhsBufferReaderFamily; +pub struct RhsBufferReaderFamily; + +impl ReaderFamily for LhsBufferReaderFamily { + type Reader = LhsBufferReader; +} + +impl ReaderFamily for RhsBufferReaderFamily { + type Reader = RhsBufferReader; +} + #[cube] impl LhsBufferReader { - pub fn read_tile( + pub fn read_tile( this: &Self, compute_plane_offset: u32, - #[comptime] config: single_buffer::Config, + #[comptime] config: CommonStageConfig, ) -> Slice> { - this.stage.get_tile::>( + this.stage.get_tile::>( compute_plane_offset, this.buffer, Ident::Lhs, @@ -34,12 +48,12 @@ impl LhsBufferReader { #[cube] impl RhsBufferReader { - pub fn read_tile( + pub fn read_tile( this: &Self, accumulator_offset: u32, - #[comptime] config: single_buffer::Config, + #[comptime] config: CommonStageConfig, ) -> Slice> { - this.stage.get_tile::>( + this.stage.get_tile::>( this.buffer, accumulator_offset, Ident::Rhs, diff --git a/crates/cubecl-linalg/src/matmul/components/stage/staging.rs b/crates/cubecl-linalg/src/matmul/components/stage/staging.rs index 447da298d..783ae682d 100644 --- a/crates/cubecl-linalg/src/matmul/components/stage/staging.rs +++ b/crates/cubecl-linalg/src/matmul/components/stage/staging.rs @@ -1,7 +1,7 @@ use crate::matmul::components::stage::tiling_order::{ ColMajorTiling, RowMajorTiling, TilingOrderConfig, }; -use 
crate::matmul::components::stage::{Config, TilingOrder}; +use crate::matmul::components::stage::{StageConfig, TilingOrder}; use crate::matmul::components::Ident; use cubecl_core as cubecl; use cubecl_core::prelude::*; @@ -16,7 +16,7 @@ pub struct Stage { #[cube] impl Stage { /// Instantiate a new stage for the given identifier - pub fn new(#[comptime] ident: Ident, #[comptime] config: S) -> Stage { + pub fn new(#[comptime] ident: Ident, #[comptime] config: S) -> Stage { let line_size = config.line_size(ident); let smem = SharedMemory::new_lined( @@ -28,7 +28,7 @@ impl Stage { } /// Get the tile at position (x,y) regardless of layout, as a contiguous slice - pub fn get_tile( + pub fn get_tile( &self, x: u32, y: u32, diff --git a/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs b/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs index 8790ed0ee..b069a294a 100644 --- a/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs +++ b/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs @@ -1,112 +1,129 @@ use crate::matmul::components::config::MatmulConfig; -use crate::matmul::components::tile::base::Matmul as _; -use crate::matmul::components::tile::Config as TileConfig; +use crate::matmul::components::tile::{TileConfig, TileMatmul, TileMatmulFamily}; use crate::matmul::components::{ - as_cmma_layout, tile, Ident, MatmulKernel, MatmulProblem, MatrixLayout, + as_cmma_layout, Ident, InvalidConfigError, MatmulConfigFactory, MatmulPrecision, MatmulProblem, + MatmulSize, MatrixLayout, }; use crate::matmul::kernels::matmul::AdvancedConfig; use crate::matmul::kernels::MatmulAvailabilityError; +use cubecl_core::ir::{Elem, FloatKind}; use cubecl_core::{self as cubecl, Feature}; use cubecl_core::{cmma, prelude::*}; -use std::marker::PhantomData; -macro_rules! 
instruction { - ($name:ident, $m:expr, $n:expr, $k:expr) => { - pub struct $name { - _input: PhantomData, - _output: PhantomData, - } +pub struct Accelerated; - #[cube] - impl tile::Matmul for $name { - const M: u32 = $m; - const N: u32 = $n; - const K: u32 = $k; - - type Lhs = cmma::Matrix; - type Rhs = cmma::Matrix; - type Accumulator = cmma::Matrix; - - fn execute( - lhs: &Self::Lhs, - rhs: &Self::Rhs, - out: &mut Self::Accumulator, - #[comptime] _config: Config, - ) { - execute::(lhs, rhs, out); - } - - fn init_lhs(#[comptime] config: Config) -> Self::Lhs { - init_lhs(config.layout(Ident::Lhs), Self::M, Self::N, Self::K) - } - - fn init_rhs(#[comptime] config: Config) -> Self::Rhs { - init_rhs(config.layout(Ident::Rhs), Self::M, Self::N, Self::K) - } - - fn fill_lhs(slice: &Slice>, lhs: &mut Self::Lhs, #[comptime] config: Config) { - fill_lhs(slice, lhs, config, Self::M, Self::K); - } - - fn fill_rhs(slice: &Slice>, rhs: &mut Self::Rhs, #[comptime] config: Config) { - fill_rhs(slice, rhs, config, Self::N, Self::K); - } - - fn fill_accumulator( - slice: &Slice>, - acc: &mut Self::Accumulator, - stride: u32, - #[comptime] config: Config, - ) { - fill_accumulator(slice, acc, stride, config); - } - - fn read_accumulator( - out: &Self::Accumulator, - slice: &mut SliceMut>, - #[comptime] _config: Config, - ) { - read_accumulator::(out, slice, Self::N); - } - - fn init_accumulator(#[comptime] _config: Self::Config) -> Self::Accumulator { - init_output(Self::M, Self::N, Self::K) - } - - fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] _config: Self::Config) { - cmma::fill(&acc, O::from_int(0)); - } - } +impl TileMatmulFamily for Accelerated { + type Matmul = Accelerated; - impl MatmulKernel for $name { - type Config = Config; - - fn check_config(config: Self::Config) { - comptime!(check_plane_dim(config.plane_dim())); - } - - fn check_availability( - client: &ComputeClient, - ) -> Result<(), MatmulAvailabilityError> { - check_availability::(Self::M, Self::N, Self::K, client) - } - - fn make_config( - problem: &MatmulProblem, - cube_dim: &CubeDim, - cube_count: &CubeCount, - advanced_config: &AdvancedConfig, - ) -> Self::Config { - make_config(problem, cube_dim, cube_count, advanced_config) - } - } - }; + fn size(config: &Self::Config) -> MatmulSize { + config.size + } + + fn input(tile_size: MatmulSize) -> Self::Input { + tile_size + } + + fn requires_tensor_cores() -> bool { + true + } } -instruction!(Accelerated16x16x8, 16, 16, 8); -instruction!(Accelerated16x16x16, 16, 16, 16); -instruction!(Accelerated32x8x16, 32, 8, 16); -instruction!(Accelerated8x32x16, 8, 32, 16); +#[cube] +impl TileMatmul for Accelerated { + type Config = Config; + type Lhs = cmma::Matrix; + type Rhs = cmma::Matrix; + type Accumulator = cmma::Matrix; + + fn execute( + lhs: &Self::Lhs, + rhs: &Self::Rhs, + out: &mut Self::Accumulator, + #[comptime] _config: Config, + ) { + execute::(lhs, rhs, out); + } + + fn init_lhs(#[comptime] config: Config) -> Self::Lhs { + init_lhs( + config.layout(Ident::Lhs), + config.size.m, + config.size.n, + config.size.k, + ) + } + + fn init_rhs(#[comptime] config: Config) -> Self::Rhs { + init_rhs( + config.layout(Ident::Rhs), + config.size.m, + config.size.n, + config.size.k, + ) + } + + fn fill_lhs(slice: &Slice>, lhs: &mut Self::Lhs, #[comptime] config: Config) { + fill_lhs(slice, lhs, config, config.size.m, config.size.k); + } + + fn fill_rhs(slice: &Slice>, rhs: &mut Self::Rhs, #[comptime] config: Config) { + fill_rhs(slice, rhs, config, config.size.n, config.size.k); + } + 
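+    // Note: all fragment shapes above are resolved from `config.size` at
+    // comptime rather than from per-instruction constants; e.g. a config whose
+    // size is `MatmulSize { m: 16, n: 16, k: 16 }` yields the same cmma
+    // fragments that the removed `Accelerated16x16x16` type hard-coded through
+    // the `instruction!` macro.
+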
+    fn fill_accumulator(
+        slice: &Slice>,
+        acc: &mut Self::Accumulator,
+        stride: u32,
+        #[comptime] config: Config,
+    ) {
+        fill_accumulator(slice, acc, stride, config);
+    }
+
+    fn read_accumulator(
+        out: &Self::Accumulator,
+        slice: &mut SliceMut>,
+        #[comptime] config: Config,
+    ) {
+        read_accumulator::(out, slice, config.size.n);
+    }
+
+    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator {
+        init_output(config.size.m, config.size.n, config.size.k)
+    }
+
+    fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] _config: Self::Config) {
+        cmma::fill(acc, O::from_int(0));
+    }
+}
+
+impl MatmulConfigFactory for Accelerated {
+    type Input = MatmulSize;
+    type Config = Config;
+
+    fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> {
+        if config.plane_dim != 32 {
+            return Err(Box::new(format!(
+                "Error: Expected plane dimension to be 32, but found {}. Please ensure that cube dimension x is set correctly.",
+                config.plane_dim
+            )));
+        }
+        Ok(())
+    }
+
+    fn check_availability(
+        client: &ComputeClient,
+        config: &Self::Config,
+    ) -> Result<(), MatmulAvailabilityError> {
+        check_availability::(config.size.m, config.size.n, config.size.k, client)
+    }
+
+    fn make_config(
+        input: Self::Input,
+        problem: &MatmulProblem,
+        cube_dim: &CubeDim,
+        cube_count: &CubeCount,
+        advanced_config: &AdvancedConfig,
+    ) -> Self::Config {
+        make_config(input, problem, cube_dim, cube_count, advanced_config)
+    }
+}
 
 #[cube]
 fn execute(
@@ -221,46 +238,50 @@ fn check_availability(
     k: u32,
     client: &ComputeClient,
 ) -> Result<(), MatmulAvailabilityError> {
+    let i_elem = I::as_elem_native().expect("to be a native type");
+    let o_elem = O::as_elem_native().expect("to be a native type");
+
+    let i_elem = match i_elem {
+        Elem::Float(FloatKind::Flex32) => Elem::Float(FloatKind::F32),
+        _ => i_elem,
+    };
+
+    let o_elem = match o_elem {
+        Elem::Float(FloatKind::Flex32) => Elem::Float(FloatKind::F32),
+        _ => o_elem,
+    };
+
     if !client.properties().feature_enabled(Feature::Cmma {
-        a: I::as_elem(),
-        b: I::as_elem(),
-        c: O::as_elem(),
+        a: i_elem,
+        b: i_elem,
+        c: o_elem,
         m: m as u8,
         k: k as u8,
         n: n as u8,
     }) {
         return Err(MatmulAvailabilityError::CmmaInstructionUnavailable {
-            input: I::as_elem(),
-            output: O::as_elem(),
+            input: i_elem,
+            output: o_elem,
             m,
             n,
             k,
         });
     }
 
-    if !(client
-        .properties()
-        .feature_enabled(Feature::Type(I::as_elem()))
-        && client
-            .properties()
-            .feature_enabled(Feature::Type(O::as_elem())))
+    if !(client.properties().feature_enabled(Feature::Type(i_elem))
+        && client.properties().feature_enabled(Feature::Type(o_elem)))
     {
         return Err(MatmulAvailabilityError::TypesUnavailable {
-            input: I::as_elem(),
-            output: O::as_elem(),
+            input: i_elem,
+            output: o_elem,
         });
     }
 
     Ok(())
 }
 
-fn check_plane_dim(actual_plane_dim: u32) {
-    assert_eq!(32, actual_plane_dim, "Error: Expected plane dimension to be 32, but found {}. 
Please ensure that cube dimension x is set correctly.", - actual_plane_dim - ); -} - fn make_config( + input: MatmulSize, problem: &MatmulProblem, cube_dim: &CubeDim, _cube_count: &CubeCount, @@ -277,6 +298,7 @@ fn make_config( }; Config::new( + input, cube_dim.x, lhs_tile_layout, rhs_tile_layout, @@ -289,6 +311,7 @@ fn make_config( #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)] /// Configuration for Accelerated instruction pub struct Config { + size: MatmulSize, plane_dim: u32, lhs_layout: MatrixLayout, rhs_layout: MatrixLayout, @@ -297,7 +320,7 @@ pub struct Config { out_line_size: u32, } -impl tile::Config for Config { +impl TileConfig for Config { fn plane_dim(&self) -> u32 { self.plane_dim } @@ -317,12 +340,17 @@ impl tile::Config for Config { Ident::Out => self.out_line_size, } } + + fn size(&self) -> &MatmulSize { + &self.size + } } impl MatmulConfig for Config {} impl Config { pub fn new( + size: MatmulSize, plane_dim: u32, lhs_layout: MatrixLayout, rhs_layout: MatrixLayout, @@ -331,6 +359,7 @@ impl Config { out_line_size: u32, ) -> Self { Self { + size, plane_dim, lhs_layout, rhs_layout, diff --git a/crates/cubecl-linalg/src/matmul/components/tile/base.rs b/crates/cubecl-linalg/src/matmul/components/tile/base.rs index f502cd661..67032a9f1 100644 --- a/crates/cubecl-linalg/src/matmul/components/tile/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/tile/base.rs @@ -1,7 +1,17 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::components::{config::MatmulConfig, Ident, MatmulKernel, MatrixLayout}; +use crate::matmul::components::{ + config::MatmulConfig, Ident, MatmulConfigFactory, MatmulSize, MatrixLayout, +}; + +pub trait TileMatmulFamily: MatmulConfigFactory { + fn size(config: &Self::Config) -> MatmulSize; + fn input(tile_size: MatmulSize) -> Self::Input; + fn requires_tensor_cores() -> bool; + + type Matmul: TileMatmul; +} #[cube] /// Provides matrix multiplication operations at the tile level. @@ -16,16 +26,8 @@ use crate::matmul::components::{config::MatmulConfig, Ident, MatmulKernel, Matri /// - Slices given as inputs must always be valid. If the actual matrix multiplication /// should be done on smaller sizes than M, N and K, padding with zeros must be done beforehand. 
 /// - Enough units are present to perform the whole computation
-pub trait Matmul:
-    'static + Send + Sync + MatmulKernel
-{
-    /// Number of rows of LHS
-    const M: u32;
-    /// Number of columns of RHS
-    const N: u32;
-    /// Common dimension of LHS and RHS
-    const K: u32;
-
+pub trait TileMatmul: 'static + Send + Sync {
+    type Config: TileConfig;
     /// Contains LHS data that can be split across the units
     type Lhs: CubeType;
     /// Contains RHS data that can be split across the units
@@ -91,7 +93,7 @@ pub trait Matmul:
 }
 
 /// Configuration for the Tile matmul (TMM) level
-pub trait Config: MatmulConfig {
+pub trait TileConfig: MatmulConfig {
     /// Returns the size of the plane dimension
    fn plane_dim(&self) -> u32;
 
@@ -100,4 +102,7 @@ pub trait Config: MatmulConfig {
     /// Returns the line size for the given ident
     fn line_size(&self, ident: Ident) -> u32;
+
+    /// Returns the (m, n, k) shape of the tile
+    fn size(&self) -> &MatmulSize;
 }
diff --git a/crates/cubecl-linalg/src/matmul/components/tile/mod.rs b/crates/cubecl-linalg/src/matmul/components/tile/mod.rs
index b190cb3f6..95b3da5a6 100644
--- a/crates/cubecl-linalg/src/matmul/components/tile/mod.rs
+++ b/crates/cubecl-linalg/src/matmul/components/tile/mod.rs
@@ -1,4 +1,8 @@
 pub mod accelerated;
+#[cfg(any(test, feature = "export_tests"))]
+/// Use plane operations to simulate tensor cores.
+///
+/// Only use in testing, since it is very slow.
 pub mod plane;
 
 mod base;
diff --git a/crates/cubecl-linalg/src/matmul/components/tile/plane.rs b/crates/cubecl-linalg/src/matmul/components/tile/plane.rs
index 7a64b8c6d..ac18d17f7 100644
--- a/crates/cubecl-linalg/src/matmul/components/tile/plane.rs
+++ b/crates/cubecl-linalg/src/matmul/components/tile/plane.rs
@@ -1,19 +1,14 @@
 use crate::matmul::components::config::MatmulConfig;
-use crate::matmul::components::tile::Config as TileConfig;
-use crate::matmul::components::MatmulProblem;
-use crate::matmul::components::{tile, Ident, MatmulKernel, MatrixLayout};
+use crate::matmul::components::{
+    tile, Ident, InvalidConfigError, MatmulConfigFactory, MatmulPrecision, MatrixLayout,
+};
+use crate::matmul::components::{MatmulProblem, MatmulSize};
 use crate::matmul::kernels::matmul::AdvancedConfig;
 use crate::matmul::kernels::MatmulAvailabilityError;
+
 use cubecl_core::prelude::*;
 use cubecl_core::{self as cubecl, Feature};
-use std::marker::PhantomData;
-
-pub type PlaneMma16x16x16 = PlaneMma;
-pub type PlaneMma16x16x8 = PlaneMma;
-pub type PlaneMma16x16x32 = PlaneMma;
-pub type PlaneMma32x8x16 = PlaneMma;
-pub type PlaneMma8x32x16 = PlaneMma;
-pub type PlaneMma32x32x32 = PlaneMma;
+use tile::{TileConfig, TileMatmul, TileMatmulFamily};
 
 /// PlaneMMA instruction uses plane cooperation but does not rely on tensor cores
 ///
@@ -24,19 +19,27 @@ pub type PlaneMma32x32x32 = PlaneMma;
 /// - When loading perpendicular to the lines, too much data is loaded from in comparison to what is used.
/// A solution could be to always load the stage with lhs in row-major and rhs in col-major, using only parallel fill /// - If not vec4, there are patches in read_output that may harm performance -pub struct PlaneMma { - _input: PhantomData, - _output: PhantomData, +pub struct PlaneMma; + +impl TileMatmulFamily for PlaneMma { + type Matmul = Self; + + fn size(config: &Self::Config) -> MatmulSize { + config.size + } + + fn input(tile_size: MatmulSize) -> Self::Input { + tile_size + } + + fn requires_tensor_cores() -> bool { + false + } } #[cube] -impl tile::Matmul - for PlaneMma -{ - const M: u32 = M; - const N: u32 = N; - const K: u32 = K; - +impl TileMatmul for PlaneMma { + type Config = Config; type Lhs = Array; type Rhs = Array; type Accumulator = Array; @@ -47,11 +50,11 @@ impl tile::Mat out: &mut Self::Accumulator, #[comptime] config: Config, ) { - let k_jump = config.plane_dim() / Self::N; - let row_division = config.plane_dim() / Self::M; + let k_jump = config.plane_dim() / config.size.n; + let row_division = config.plane_dim() / config.size.m; - let num_jumps = Self::K / k_jump; - let compute_width = Self::N / row_division; + let num_jumps = config.size.k / k_jump; + let compute_width = config.size.n / row_division; let unit_offset = UNIT_POS_X % row_division * compute_width; @@ -65,7 +68,7 @@ impl tile::Mat #[unroll] for n_iter in 0..compute_width { - let unit_to_read = k_inner * Self::N + n_iter + unit_offset; + let unit_to_read = k_inner * config.size.n + n_iter + unit_offset; let b_kn = plane_broadcast::(b_kp, unit_to_read); out[n_iter] += O::cast_from(a_pk * b_kn); } @@ -73,12 +76,12 @@ impl tile::Mat } } - fn init_lhs(#[comptime] _config: Config) -> Self::Lhs { - Array::new(Self::K) + fn init_lhs(#[comptime] config: Config) -> Self::Lhs { + Array::new(config.size.k) } fn init_rhs(#[comptime] config: Config) -> Self::Rhs { - Array::new(Self::K * Self::N / config.plane_dim()) + Array::new(config.size.k * config.size.n / config.plane_dim()) } fn fill_lhs(slice: &Slice>, lhs: &mut Self::Lhs, #[comptime] config: Config) { @@ -87,8 +90,8 @@ impl tile::Mat slice, &mut lhs.to_slice_mut(), UNIT_POS_X, - Self::M, - Self::K, + config.size.m, + config.size.k, config.line_size(Ident::Lhs), config.plane_dim(), ), @@ -96,8 +99,8 @@ impl tile::Mat slice, &mut lhs.to_slice_mut(), UNIT_POS_X, - Self::M, - Self::K, + config.size.m, + config.size.k, config.line_size(Ident::Lhs), config.plane_dim(), ), @@ -110,8 +113,8 @@ impl tile::Mat slice, &mut rhs.to_slice_mut(), UNIT_POS_X, - Self::N, - Self::K, + config.size.n, + config.size.k, config.line_size(Ident::Rhs), config.plane_dim(), ), @@ -119,8 +122,8 @@ impl tile::Mat slice, &mut rhs.to_slice_mut(), UNIT_POS_X, - Self::N, - Self::K, + config.size.n, + config.size.k, config.line_size(Ident::Rhs), config.plane_dim(), ), @@ -134,7 +137,7 @@ impl tile::Mat #[comptime] config: Config, ) { let unit = UNIT_POS_X; - let n = Self::N; + let n = config.size.n; let line_size = config.line_size(Ident::Out); let plane_dim = config.plane_dim(); @@ -148,7 +151,7 @@ impl tile::Mat let row_jump = plane_dim / n; #[unroll] - for m_iter in 0..Self::M / row_jump { + for m_iter in 0..config.size.m / row_jump { let m_row = row_jump * m_iter + m_row_alt; let offset = m_row * num_lines + col_idx; let line = slice[offset]; @@ -169,14 +172,14 @@ impl tile::Mat let line_size = config.line_size(Ident::Out); let plane_dim = config.plane_dim(); - let row_division = plane_dim / Self::M; - let compute_width = Self::N / row_division; + let row_division = plane_dim / 
config.size.m;
+        let compute_width = config.size.n / row_division;
 
         let num_lines = compute_width / line_size;
 
         let unit = UNIT_POS_X;
 
         let row = unit / row_division;
-        let row_offset = row * Self::N / line_size;
+        let row_offset = row * config.size.n / line_size;
 
         let offset = row_offset + unit % row_division * num_lines;
 
@@ -218,12 +221,12 @@ impl tile::Mat
     }
 
     fn init_accumulator(#[comptime] config: Config) -> Self::Accumulator {
-        let len = Self::M * Self::N / (config.plane_dim());
+        let len = config.size.m * config.size.n / (config.plane_dim());
 
         Array::new(len)
     }
 
     fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] config: Self::Config) {
-        let len = Self::M * Self::N / (config.plane_dim());
+        let len = config.size.m * config.size.n / (config.plane_dim());
 
         #[unroll]
         for i in 0..len {
@@ -349,41 +352,24 @@ fn fill_parallel_rhs(
     }
 }
 
-impl MatmulKernel
-    for PlaneMma
-{
+impl MatmulConfigFactory for PlaneMma {
     type Config = Config;
+    type Input = MatmulSize;
 
-    fn check_config(config: Self::Config) {
-        let plane_dim = config.plane_dim();
-        assert!(M * N % plane_dim == 0);
-        assert!(K * N % plane_dim == 0);
-    }
-
-    fn check_availability(
-        client: &ComputeClient,
-    ) -> Result<(), MatmulAvailabilityError> {
-        if !client.properties().feature_enabled(Feature::Plane) {
-            return Err(MatmulAvailabilityError::PlaneOperationsUnavailable);
+    fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> {
+        if config.size.m * config.size.n % config.plane_dim != 0 {
+            return Err(Box::new(format!(
+                "Error: Tile size m * n ({}) should be divisible by plane dimension ({}).",
+                config.size.m * config.size.n,
+                config.plane_dim
+            )));
         }
 
-        if !(client
-            .properties()
-            .feature_enabled(Feature::Type(I::as_elem()))
-            && client
-                .properties()
-                .feature_enabled(Feature::Type(O::as_elem())))
-        {
-            return Err(MatmulAvailabilityError::TypesUnavailable {
-                input: I::as_elem(),
-                output: O::as_elem(),
-            });
+        if config.size.k * config.size.n % config.plane_dim != 0 {
+            return Err(Box::new(format!(
+                "Error: Tile size k * n ({}) should be divisible by plane dimension ({}).",
+                config.size.k * config.size.n,
+                config.plane_dim
+            )));
         }
 
         Ok(())
     }
 
     fn make_config(
+        input: Self::Input,
         problem: &MatmulProblem,
         cube_dim: &CubeDim,
         _cube_count: &CubeCount,
@@ -400,6 +386,7 @@ impl MatmulKer
     };
 
     Config::new(
+        input,
         cube_dim.x,
         lhs_tile_layout,
         rhs_tile_layout,
@@ -408,11 +395,35 @@ impl MatmulKer
         problem.out_line_size as u32,
     )
 }
+
+    fn check_availability(
+        client: &ComputeClient,
+        _config: &Self::Config,
+    ) -> Result<(), MatmulAvailabilityError> {
+        let i_elem = MP::EG::as_elem_native_unchecked();
+        let o_elem = MP::EG::as_elem_native_unchecked();
+
+        if !client.properties().feature_enabled(Feature::Plane) {
+            return Err(MatmulAvailabilityError::PlaneOperationsUnavailable);
+        }
+
+        if !(client.properties().feature_enabled(Feature::Type(i_elem))
+            && client.properties().feature_enabled(Feature::Type(o_elem)))
+        {
+            return Err(MatmulAvailabilityError::TypesUnavailable {
+                input: i_elem,
+                output: o_elem,
+            });
+        }
+
+        Ok(())
+    }
 }
 
 #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
 /// Configuration for PlaneMma instruction
 pub struct Config {
+    size: MatmulSize,
     plane_dim: u32,
     lhs_layout: MatrixLayout,
     rhs_layout: MatrixLayout,
@@ -421,7 +432,7 @@ pub struct Config {
     out_line_size: u32,
 }
 
-impl tile::Config for Config {
+impl TileConfig for Config {
     fn plane_dim(&self) -> u32 {
         self.plane_dim
     }
@@ -441,12 +452,17 @@ impl tile::Config for Config {
             Ident::Out => self.out_line_size,
         }
     }
+
+    fn size(&self) -> &MatmulSize {
+        &self.size
+    }
 }
 
 impl MatmulConfig for Config {}
 
 impl Config {
     pub fn new(
+        size: MatmulSize,
         plane_dim: u32,
         lhs_layout: MatrixLayout,
         rhs_layout: MatrixLayout,
@@ -455,6 +471,7 @@ impl Config {
         out_line_size: u32,
     ) -> Self {
         Self {
+
size, plane_dim, lhs_layout, rhs_layout, diff --git a/crates/cubecl-linalg/src/matmul/kernels/error.rs b/crates/cubecl-linalg/src/matmul/kernels/error.rs index 2611b77ee..54b90236b 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/error.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/error.rs @@ -1,9 +1,12 @@ use cubecl_core::ir::Elem; use std::fmt::Debug; +use crate::matmul::components::InvalidConfigError; + pub enum MatmulLaunchError { Unavailable(MatmulAvailabilityError), InvalidProblem(MatmulInvalidProblem), + InvalidConfig(InvalidConfigError), } pub enum MatmulAvailabilityError { @@ -46,6 +49,12 @@ impl From for MatmulLaunchError { } } +impl From for MatmulLaunchError { + fn from(value: InvalidConfigError) -> Self { + Self::InvalidConfig(value) + } +} + impl Debug for MatmulLaunchError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -63,6 +72,13 @@ impl Debug for MatmulLaunchError { err ) } + MatmulLaunchError::InvalidConfig(err) => { + writeln!( + f, + "Unable to launch matmul because the config is invalid: {:?}", + err.to_string() + ) + } } } } diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/base.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/base.rs index f8aa7a914..3d792ec6a 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/base.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/base.rs @@ -1,59 +1,42 @@ use cubecl_core::prelude::*; -use crate::matmul::components::stage::{self}; -use crate::matmul::components::{batch, global, tile, MatmulSpec}; -use crate::matmul::components::{MatmulKernel, MatmulProblem}; +use crate::matmul::components::stage::{self, CommonStageInput}; +use crate::matmul::components::{batch, global, tile, MatmulPrecision}; +use crate::matmul::components::{MatmulConfigFactory, MatmulProblem}; use crate::matmul::kernels::matmul::AdvancedConfig; -use crate::matmul::kernels::{MatmulAvailabilityError, MatmulInvalidProblem}; - -type LhsStageReader = <>::LhsLoader as global::Loader< - EG, - ES, - ::Config, ->>::StageReader; -type RhsStageReader = <>::RhsLoader as global::Loader< - EG, - ES, - ::Config, ->>::StageReader; - -type EG = ::EG; -type ES = ::ES; +use crate::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}; /// Specifications for a matmul algorithm -pub trait Algorithm { - type TileMatmul: tile::Matmul + MatmulKernel; - - type StageMatmul: stage::Matmul< - MS::ES, - MS::EG, - MS::EA, - LhsReader = LhsStageReader, - RhsReader = RhsStageReader, - > + MatmulKernel; - - type GlobalMatmul: global::Matmul; +pub trait Algorithm { + type TileMatmul: tile::TileMatmulFamily; + type StageMatmul: stage::StageMatmulFamily>; + type GlobalMatmul: global::GlobalMatmulFamily; + type BatchMatmul: batch::BatchMatmulFamily>; + type Selection; - type BatchMatmul: batch::Matmul + MatmulKernel; + fn cube_dim(selection: &Self::Selection) -> CubeDim; + fn cube_count(selection: &Self::Selection, problem: &MatmulProblem) -> CubeCount; - fn cube_dim() -> CubeDim; - fn cube_count(problem: &MatmulProblem) -> CubeCount; #[allow(clippy::type_complexity)] fn make_config( + input: ::Input, problem: &MatmulProblem, cube_dim: &CubeDim, cube_count: &CubeCount, advanced_config: &AdvancedConfig, - ) -> Result<::Config, MatmulInvalidProblem> { - let config = Self::BatchMatmul::make_config(problem, cube_dim, cube_count, advanced_config); + ) -> Result<::Config, MatmulLaunchError> { + let config = + Self::BatchMatmul::make_config(input, problem, cube_dim, cube_count, 
advanced_config); problem.check_config(&config)?; + Self::BatchMatmul::check_config(&config)?; Ok(config) } - fn check_availability( + fn check_availability( client: &ComputeClient, + config: &::Config, ) -> Result<(), MatmulAvailabilityError> { - Self::BatchMatmul::check_availability::(client) + Self::BatchMatmul::check_availability::(client, config) } fn advanced_config() -> AdvancedConfig { diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/mod.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/mod.rs index 63d22d707..0aa8be46c 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/mod.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/mod.rs @@ -1,6 +1,8 @@ mod base; mod selector; +pub mod pipelined; +pub mod specialized; pub mod standard; pub use base::Algorithm; diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/pipelined.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/pipelined.rs new file mode 100644 index 000000000..614a8f7ff --- /dev/null +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/pipelined.rs @@ -0,0 +1,49 @@ +use cubecl_core::prelude::*; +use std::marker::PhantomData; + +use crate::matmul::components::batch::{CubeCountDispatch, CubeDispatch}; +use crate::matmul::components::stage::{self}; +use crate::matmul::components::MatmulProblem; +use crate::matmul::components::{batch, global}; +use crate::matmul::components::{tile, MatmulSelection}; + +use super::base; + +pub struct PipelinedAlgorithm { + pub _tmm: PhantomData, + pub _dispatch: PhantomData, +} + +impl base::Algorithm for PipelinedAlgorithm +where + TMM: tile::TileMatmulFamily, + Dispatch: CubeDispatch + CubeCountDispatch, +{ + type TileMatmul = TMM; + type StageMatmul = stage::single_buffer::SingleBufferMatmulFamily; + type GlobalMatmul = global::buffered::pipelined::PipelinedMatmulFamily; + + type BatchMatmul = batch::one_to_one::OneToOneMatmulFamily; + type Selection = MatmulSelection; + + fn cube_dim(selection: &MatmulSelection) -> CubeDim { + CubeDim::new(selection.plane_dim, selection.num_stagess.m, 1) + } + + fn cube_count(selection: &MatmulSelection, problem: &MatmulProblem) -> CubeCount { + let m_stage = selection.num_stagess.m * selection.tile.m; + let n_stage = selection.num_stagess.n * selection.tile.n; + let cubes_for_m = (problem.m as u32 + m_stage - 1) / m_stage; + let cubes_for_n = (problem.n as u32 + n_stage - 1) / n_stage; + + Dispatch::cube_count(cubes_for_m, cubes_for_n, problem.num_batches() as u32) + } + + fn advanced_config() -> crate::matmul::kernels::matmul::AdvancedConfig { + crate::matmul::kernels::matmul::AdvancedConfig { + lhs_tiling_order: stage::TilingOrderConfig::ColMajor, + rhs_tiling_order: stage::TilingOrderConfig::RowMajor, + enforced_tile_layout: (None, None), + } + } +} diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/selector.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/selector.rs index 80186bc9b..e1d45981a 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/selector.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/selector.rs @@ -1,148 +1,131 @@ +use std::marker::PhantomData; + use cubecl_core::{client::ComputeClient, ir::Elem, prelude::CubePrimitive, Feature, Runtime}; use cubecl_runtime::DeviceProperties; use crate::matmul::{ components::{ - stage::*, - tile::{ - accelerated::{ - Accelerated16x16x16, Accelerated16x16x8, Accelerated32x8x16, Accelerated8x32x16, - }, - plane::{PlaneMma16x16x16, 
PlaneMma32x8x16, PlaneMma8x32x16}, - }, - InputRuntimeArg, MatmulProblem, MatmulSpec, OutputRuntimeArg, + batch::TransposedDispatch, stage::*, tile::TileMatmulFamily, InputRuntimeArg, + MatmulProblem, MatmulSelection, MatmulSize, MatmulSpec, OutputRuntimeArg, }, kernels::{matmul::base::matmul_cube_preparation, MatmulLaunchError}, }; -use super::standard::StandardAlgorithm; +use super::{ + pipelined::PipelinedAlgorithm, specialized::SpecializedAlgorithm, standard::StandardAlgorithm, +}; const NUM_SM_APPROX: usize = 50; const NUM_TENSOR_CORES_APPROX: usize = 8; -pub struct CmmaSelector; +pub trait MatmulSelector { + fn select_kernel<'a, MS: MatmulSpec, R: Runtime>( + client: &ComputeClient, + input: InputRuntimeArg<'a, MS, R>, + output: OutputRuntimeArg<'a, MS, R>, + problem: MatmulProblem, + plane_dim: u32, + ) -> Result<(), MatmulLaunchError>; + fn stage_tf32_supported() -> bool; +} + +pub struct StandardSelector { + _tmm: PhantomData, + _dispatch: PhantomData, +} -impl CmmaSelector { - pub fn select_kernel<'a, MS: MatmulSpec, R: Runtime>( +pub struct PipelinedSelector { + _tmm: PhantomData, +} + +pub struct SpecializedSelector { + _tmm: PhantomData, +} + +impl MatmulSelector for StandardSelector { + fn select_kernel<'a, MS: MatmulSpec, R: Runtime>( client: &ComputeClient, input: InputRuntimeArg<'a, MS, R>, output: OutputRuntimeArg<'a, MS, R>, problem: MatmulProblem, + plane_dim: u32, ) -> Result<(), MatmulLaunchError> { - let (instruction_m, instruction_n, instruction_k) = find_instruction_shape( - Some(( - client.properties(), - (MS::ES::as_elem(), MS::ES::as_elem(), MS::EA::as_elem()), - )), - problem.m, - problem.n, - ); + let selection = matmul_selection::(client, &problem, plane_dim); + let config_input = CommonStageInput { + tile: TMM::input(selection.tile), + num_stages: selection.num_stagess, + }; + + matmul_cube_preparation::>( + client, + input, + output, + problem, + config_input, + selection, + ) + } - let stage_size_m_n = find_stage_size_m_n( - problem.m, - problem.n, - problem.num_batches(), - NUM_SM_APPROX, - NUM_TENSOR_CORES_APPROX, - instruction_m, - instruction_n, - ); + fn stage_tf32_supported() -> bool { + TMM::requires_tensor_cores() + } +} - match (instruction_m, instruction_n, instruction_k) { - (16, 16, 8) => match stage_size_m_n { - 1 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 2 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 4 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 8 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - _ => panic!("No configuration found for this stage size. "), - }, - (16, 16, 16) => match stage_size_m_n { - 1 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 2 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 4 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 8 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - _ => panic!("No configuration found for this stage size. 
"), - }, - (32, 8, 16) => match stage_size_m_n { - 1 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 2 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 4 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 8 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - _ => panic!("No configuration found for this stage size. "), - }, - (8, 32, 16) => match stage_size_m_n { - 1 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 2 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 4 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 8 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - _ => panic!("No configuration found for this stage size. "), - }, - _ => panic!("No configuration found for instruction shapes."), - } +impl MatmulSelector for PipelinedSelector { + fn select_kernel<'a, MS: MatmulSpec, R: Runtime>( + client: &ComputeClient, + input: InputRuntimeArg<'a, MS, R>, + output: OutputRuntimeArg<'a, MS, R>, + problem: MatmulProblem, + plane_dim: u32, + ) -> Result<(), MatmulLaunchError> { + let selection = matmul_selection::(client, &problem, plane_dim); + let config_input = CommonStageInput { + tile: TMM::input(selection.tile), + num_stages: selection.num_stagess, + }; + + matmul_cube_preparation::>( + client, + input, + output, + problem, + config_input, + selection, + ) + } + + fn stage_tf32_supported() -> bool { + TMM::requires_tensor_cores() + } +} + +impl MatmulSelector for SpecializedSelector { + fn select_kernel<'a, MS: MatmulSpec, R: Runtime>( + client: &ComputeClient, + input: InputRuntimeArg<'a, MS, R>, + output: OutputRuntimeArg<'a, MS, R>, + problem: MatmulProblem, + plane_dim: u32, + ) -> Result<(), MatmulLaunchError> { + let selection = matmul_selection::(client, &problem, plane_dim); + let config_input = CommonStageInput { + tile: TMM::input(selection.tile), + num_stages: selection.num_stagess, + }; + + matmul_cube_preparation::>( + client, + input, + output, + problem, + config_input, + selection, + ) + } + + fn stage_tf32_supported() -> bool { + TMM::requires_tensor_cores() } } @@ -223,99 +206,49 @@ fn find_stage_size_m_n( } } -pub struct PlaneMmaSelector; - -impl PlaneMmaSelector { - pub fn select_kernel<'a, MS: MatmulSpec, R: Runtime>( - client: &ComputeClient, - input: InputRuntimeArg<'a, MS, R>, - output: OutputRuntimeArg<'a, MS, R>, - problem: MatmulProblem, - ) -> Result<(), MatmulLaunchError> { - let (instruction_m, instruction_n, instruction_k) = - find_instruction_shape(None, problem.m, problem.n); - - let stage_size_m_n = find_stage_size_m_n( - problem.m, - problem.n, - problem.num_batches(), - NUM_SM_APPROX, - NUM_TENSOR_CORES_APPROX, - instruction_m, - instruction_n, - ); - - match (instruction_m, instruction_n, instruction_k) { - (16, 16, 16) => match stage_size_m_n { - 1 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 2 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 4 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 8 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, 
output, problem), - _ => panic!("No configuration found for this stage size. "), - }, - (32, 8, 16) => match stage_size_m_n { - 1 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 2 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 4 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 8 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - _ => panic!("No configuration found for this stage size. "), - }, - (8, 32, 16) => match stage_size_m_n { - 1 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 2 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 4 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - 8 => matmul_cube_preparation::< - MS, - R, - StandardAlgorithm>, - >(client, input, output, problem), - _ => panic!("No configuration found for this stage size. "), - }, - _ => panic!("No configuration found for instruction shapes."), - } +fn matmul_selection( + client: &ComputeClient, + problem: &MatmulProblem, + plane_dim: u32, +) -> MatmulSelection { + let (instruction_m, instruction_n, instruction_k) = find_instruction_shape( + if TMM::requires_tensor_cores() { + Some(( + client.properties(), + ( + MS::ES::as_elem_native_unchecked(), + MS::ES::as_elem_native_unchecked(), + MS::EA::as_elem_native_unchecked(), + ), + )) + } else { + None + }, + problem.m, + problem.n, + ); + + let stage_size_m_n = find_stage_size_m_n( + problem.m, + problem.n, + problem.num_batches(), + NUM_SM_APPROX, + NUM_TENSOR_CORES_APPROX, + instruction_m, + instruction_n, + ); + + MatmulSelection { + tile: MatmulSize { + m: instruction_m as u32, + n: instruction_n as u32, + k: instruction_k as u32, + }, + num_stagess: MatmulSize { + m: stage_size_m_n as u32, + n: stage_size_m_n as u32, + k: 2, + }, + plane_dim, } } diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/specialized.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/specialized.rs new file mode 100644 index 000000000..c3ecce509 --- /dev/null +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/specialized.rs @@ -0,0 +1,53 @@ +use cubecl_core::prelude::*; +use std::marker::PhantomData; + +use crate::matmul::components::batch::{CubeCountDispatch, CubeDispatch}; +use crate::matmul::components::stage::{self}; +use crate::matmul::components::MatmulProblem; +use crate::matmul::components::{batch, global}; +use crate::matmul::components::{tile, MatmulSelection}; + +use super::base; + +pub struct SpecializedAlgorithm { + pub _tmm: PhantomData, + pub _dispatch: PhantomData, +} + +impl base::Algorithm for SpecializedAlgorithm +where + TMM: tile::TileMatmulFamily, + Dispatch: CubeDispatch + CubeCountDispatch, +{ + type TileMatmul = TMM; + type StageMatmul = stage::single_buffer::SingleBufferMatmulFamily; + type GlobalMatmul = global::buffered::specialized::SpecializedMatmulFamily; + + type BatchMatmul = batch::one_to_one::OneToOneMatmulFamily; + type Selection = MatmulSelection; + + fn cube_dim(selection: &MatmulSelection) -> CubeDim { + CubeDim::new( + selection.plane_dim, + selection.num_stagess.m + core::cmp::max(1u32, selection.num_stagess.m / 2), + 1, + ) + } + + fn cube_count(selection: &MatmulSelection, problem: &MatmulProblem) -> CubeCount { + let m_stage = 
selection.num_stagess.m * selection.tile.m; + let n_stage = selection.num_stagess.n * selection.tile.n; + let cubes_for_m = (problem.m as u32 + m_stage - 1) / m_stage; + let cubes_for_n = (problem.n as u32 + n_stage - 1) / n_stage; + + Dispatch::cube_count(cubes_for_m, cubes_for_n, problem.num_batches() as u32) + } + + fn advanced_config() -> crate::matmul::kernels::matmul::AdvancedConfig { + crate::matmul::kernels::matmul::AdvancedConfig { + lhs_tiling_order: stage::TilingOrderConfig::ColMajor, + rhs_tiling_order: stage::TilingOrderConfig::RowMajor, + enforced_tile_layout: (None, None), + } + } +} diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/standard.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/standard.rs index f59f70ade..4f32c2f8b 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/standard.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/standard.rs @@ -1,41 +1,39 @@ -use std::marker::PhantomData; - +use super::base; use cubecl_core::prelude::*; +use std::marker::PhantomData; -use crate::matmul::components::batch::CubeCountDispatch; +use crate::matmul::components::batch::{CubeCountDispatch, CubeDispatch}; use crate::matmul::components::global::full_load::CyclicLoading; -use crate::matmul::components::stage::{self, StageSize}; +use crate::matmul::components::stage::{self}; use crate::matmul::components::MatmulProblem; use crate::matmul::components::{batch, global}; -use crate::matmul::components::{tile, MatmulSpec}; - -use super::base; - -type Dispatch = batch::SwizzleTransposedDispatch<2>; +use crate::matmul::components::{tile, MatmulSelection}; -pub struct StandardAlgorithm { - pub _ms: PhantomData, - pub _stage: PhantomData, +pub struct StandardAlgorithm { pub _tmm: PhantomData, + pub _dispatch: PhantomData, } -impl> base::Algorithm - for StandardAlgorithm +impl base::Algorithm for StandardAlgorithm +where + TMM: tile::TileMatmulFamily, + Dispatch: CubeDispatch + CubeCountDispatch, { type TileMatmul = TMM; - type StageMatmul = stage::multi_buffer::Matmul; + type StageMatmul = stage::multi_buffer::MultiBufferMatmulFamily; type GlobalMatmul = - global::full_load::Matmul; + global::full_load::FullLoadMatmulFamily; - type BatchMatmul = batch::one_to_one::Matmul; + type BatchMatmul = batch::one_to_one::OneToOneMatmulFamily; + type Selection = MatmulSelection; - fn cube_dim() -> CubeDim { - CubeDim::new(MS::PLANE_DIM, Stage::NUM_M, 1) + fn cube_dim(selection: &MatmulSelection) -> CubeDim { + CubeDim::new(selection.plane_dim, selection.num_stagess.m, 1) } - fn cube_count(problem: &MatmulProblem) -> CubeCount { - let m_stage = Stage::NUM_M * Self::TileMatmul::M; - let n_stage = Stage::NUM_N * Self::TileMatmul::N; + fn cube_count(selection: &MatmulSelection, problem: &MatmulProblem) -> CubeCount { + let m_stage = selection.num_stagess.m * selection.tile.m; + let n_stage = selection.num_stagess.n * selection.tile.n; let cubes_for_m = (problem.m as u32 + m_stage - 1) / m_stage; let cubes_for_n = (problem.n as u32 + n_stage - 1) / n_stage; diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs index 2c3900957..fc0172e22 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs @@ -8,12 +8,13 @@ use cubecl_core::{ use crate::matmul; use crate::matmul::components::global::args::TensorInputsLaunch; use crate::matmul::components::{ - InputRuntimeArg, MatmulLaunch, MatmulProblem, 
MatmulSpec, OutputRuntimeArg, SingleMatmulSpec, + InputRuntimeArg, MatmulConfigFactory, MatmulLaunch, MatmulProblem, MatmulSpec, + OutputRuntimeArg, SingleMatmulSpec, }; use crate::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}; use crate::tensor::{into_contiguous, matrix_layout, MatrixLayout, TensorHandle}; -use super::algorithm::{CmmaSelector, PlaneMmaSelector}; +use super::algorithm::MatmulSelector; use super::config::AdvancedConfig; use super::Algorithm; @@ -21,20 +22,13 @@ use super::Algorithm; /// /// Cmma will be used if enabled /// Will fail if unavailable -pub fn launch( +pub fn launch( client: &ComputeClient, lhs: TensorHandle, rhs: TensorHandle, out: TensorHandle, - disable_cmma: bool, ) -> Result, MatmulLaunchError> { - let result = launch_ref::( - client, - &lhs.as_ref(), - &rhs.as_ref(), - &out.as_ref(), - disable_cmma, - ); + let result = launch_ref::(client, &lhs.as_ref(), &rhs.as_ref(), &out.as_ref()); match result { Ok(_) => Ok(out), @@ -46,12 +40,11 @@ pub fn launch( /// /// Cmma will be used if available and enabled, /// otherwise it will fall back on a non-cmma implementation -pub fn launch_ref( +pub fn launch_ref( client: &ComputeClient, lhs: &TensorHandleRef<'_, R>, rhs: &TensorHandleRef<'_, R>, out: &TensorHandleRef<'_, R>, - disable_cmma: bool, ) -> Result<(), MatmulLaunchError> { let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_layout(tensor.strides) { MatrixLayout::Contiguous => (false, false), @@ -66,69 +59,65 @@ pub fn launch_ref( let (rhs_make_contiguous, rhs_transposed) = check_layout(rhs); match (lhs_make_contiguous, rhs_make_contiguous) { - (false, false) => matmul_cmma_ref_no_check::( + (false, false) => matmul_cmma_ref_no_check::( client, lhs, rhs, out, (lhs_transposed, rhs_transposed), - disable_cmma, ), - (false, true) => matmul_cmma_ref_no_check::( + (false, true) => matmul_cmma_ref_no_check::( client, lhs, &into_contiguous::(client, rhs).as_ref(), out, (lhs_transposed, rhs_transposed), - disable_cmma, ), - (true, false) => matmul_cmma_ref_no_check::( + (true, false) => matmul_cmma_ref_no_check::( client, &into_contiguous::(client, lhs).as_ref(), rhs, out, (lhs_transposed, rhs_transposed), - disable_cmma, ), - (true, true) => matmul_cmma_ref_no_check::( + (true, true) => matmul_cmma_ref_no_check::( client, &into_contiguous::(client, lhs).as_ref(), &into_contiguous::(client, rhs).as_ref(), out, (lhs_transposed, rhs_transposed), - disable_cmma, ), } } -fn matmul_cmma_ref_no_check( +fn matmul_cmma_ref_no_check( client: &ComputeClient, lhs: &TensorHandleRef<'_, R>, rhs: &TensorHandleRef<'_, R>, out: &TensorHandleRef<'_, R>, transposed: (bool, bool), - disable_cmma: bool, ) -> Result<(), MatmulLaunchError> { let rank = lhs.strides.len(); + let eg_elem = EG::as_elem_native().expect("To be a native type"); let m = lhs.shape[rank - 2] as u32; let k = lhs.shape[rank - 1] as u32; let n = rhs.shape[rank - 1] as u32; let lhs_line_size = tensor_line_size_parallel( - R::line_size_elem(&EG::as_elem()), + R::line_size_elem(&eg_elem), lhs.shape, lhs.strides, rank - 1, ); let rhs_line_size = tensor_line_size_parallel( - R::line_size_elem(&EG::as_elem()), + R::line_size_elem(&eg_elem), rhs.shape, rhs.strides, rank - 1, ); let out_line_size = tensor_line_size_parallel( - R::line_size_elem(&EG::as_elem()), + R::line_size_elem(&eg_elem), out.shape, out.strides, rank - 1, @@ -160,45 +149,44 @@ fn matmul_cmma_ref_no_check( .hardware_properties() .defined_plane_size(); - match plane_size { - Some(32) => matmul_launch_kernel::<32, R, EG>( - 
client, - lhs, - rhs, - out, - disable_cmma, - (lhs_line_size, rhs_line_size, out_line_size), - problem, - ), - Some(64) => matmul_launch_kernel::<64, R, EG>( - client, - lhs, - rhs, - out, - disable_cmma, - (lhs_line_size, rhs_line_size, out_line_size), - problem, - ), - Some(plane_dim) => Err(MatmulLaunchError::Unavailable( - MatmulAvailabilityError::PlaneDimUnsupported { plane_dim }, - )), - None => Err(MatmulLaunchError::Unavailable( - MatmulAvailabilityError::PlaneDimUnknown, - )), - } + let plane_dim = match plane_size { + Some(32) | Some(64) => plane_size.expect("32 or 64"), + Some(plane_dim) => { + return Err(MatmulLaunchError::Unavailable( + MatmulAvailabilityError::PlaneDimUnsupported { plane_dim }, + )) + } + None => { + return Err(MatmulLaunchError::Unavailable( + MatmulAvailabilityError::PlaneDimUnknown, + )) + } + }; + + matmul_launch_kernel::( + client, + lhs, + rhs, + out, + (lhs_line_size, rhs_line_size, out_line_size), + problem, + plane_dim, + ) } -fn matmul_launch_kernel( +fn matmul_launch_kernel( client: &ComputeClient, lhs: &TensorHandleRef<'_, R>, rhs: &TensorHandleRef<'_, R>, out: &TensorHandleRef<'_, R>, - disable_cmma: bool, (lhs_line_size, rhs_line_size, out_line_size): (u8, u8, u8), problem: MatmulProblem, + plane_dim: u32, ) -> Result<(), MatmulLaunchError> { - if disable_cmma { - PlaneMmaSelector::select_kernel::, R>( + if TypeId::of::() == TypeId::of::() + || TypeId::of::() == TypeId::of::() + { + S::select_kernel::, R>( client, TensorInputsLaunch::new( lhs.as_tensor_arg(lhs_line_size), @@ -206,11 +194,10 @@ fn matmul_launch_kernel( ), out.as_tensor_arg(out_line_size), problem, + plane_dim, ) - } else if TypeId::of::() == TypeId::of::() - || TypeId::of::() == TypeId::of::() - { - CmmaSelector::select_kernel::, R>( + } else if TypeId::of::() == TypeId::of::() { + S::select_kernel::, R>( client, TensorInputsLaunch::new( lhs.as_tensor_arg(lhs_line_size), @@ -218,9 +205,10 @@ fn matmul_launch_kernel( ), out.as_tensor_arg(out_line_size), problem, + plane_dim, ) - } else if TypeId::of::() == TypeId::of::() { - CmmaSelector::select_kernel::, R>( + } else if S::stage_tf32_supported() { + S::select_kernel::, R>( client, TensorInputsLaunch::new( lhs.as_tensor_arg(lhs_line_size), @@ -228,9 +216,10 @@ fn matmul_launch_kernel( ), out.as_tensor_arg(out_line_size), problem, + plane_dim, ) } else { - CmmaSelector::select_kernel::, R>( + S::select_kernel::, R>( client, TensorInputsLaunch::new( lhs.as_tensor_arg(lhs_line_size), @@ -238,20 +227,21 @@ fn matmul_launch_kernel( ), out.as_tensor_arg(out_line_size), problem, + plane_dim, ) } } -pub(crate) fn matmul_cube_preparation<'a, MS: MatmulSpec, R: Runtime, D: Algorithm>( +pub(crate) fn matmul_cube_preparation<'a, MS: MatmulSpec, R: Runtime, D: Algorithm>( client: &ComputeClient, input: InputRuntimeArg<'a, MS, R>, output: OutputRuntimeArg<'a, MS, R>, problem: MatmulProblem, + config_input: ::Input, + selection: D::Selection, ) -> Result<(), MatmulLaunchError> { - D::check_availability::(client)?; - - let cube_dim = D::cube_dim(); - let cube_count = D::cube_count(&problem); + let cube_dim = D::cube_dim(&selection); + let cube_count = D::cube_count(&selection, &problem); let advanced_config = D::advanced_config(); launch_matmul::( @@ -262,11 +252,12 @@ pub(crate) fn matmul_cube_preparation<'a, MS: MatmulSpec, R: Runtime, D: Algorit cube_dim, cube_count, advanced_config, + config_input, ) } #[allow(clippy::too_many_arguments)] -fn launch_matmul<'a, MS: MatmulSpec, R: Runtime, D: Algorithm>( +fn launch_matmul<'a, MS: MatmulSpec, 
R: Runtime, D: Algorithm>(
+fn launch_matmul<'a, MS: MatmulSpec, R: Runtime, D: Algorithm>(
     client: &ComputeClient,
     input: InputRuntimeArg<'a, MS, R>,
     output: OutputRuntimeArg<'a, MS, R>,
@@ -274,11 +265,21 @@ fn launch_matmul<'a, MS: MatmulSpec, R: Runtime, D: Algorithm>(
     cube_dim: CubeDim,
     cube_count: CubeCount,
     advanced_config: AdvancedConfig,
+    config_input: ::Input,
 ) -> Result<(), MatmulLaunchError> {
-    let config = D::make_config(&problem, &cube_dim, &cube_count, &advanced_config)?;
+    let config = D::make_config(
+        config_input,
+        &problem,
+        &cube_dim,
+        &cube_count,
+        &advanced_config,
+    )?;
+    D::check_availability::(client, &config)?;

     unsafe {
-        D::BatchMatmul::launch_unchecked::(client, cube_dim, cube_count, input, output, config);
+        D::BatchMatmul::launch_unchecked::(
+            client, cube_dim, cube_count, input, output, config,
+        );
     };

     Ok(())
diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/mod.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/mod.rs
index f646cfbb8..5cb20fb3c 100644
--- a/crates/cubecl-linalg/src/matmul/kernels/matmul/mod.rs
+++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/mod.rs
@@ -3,6 +3,6 @@ mod config;

 mod algorithm;

-pub use algorithm::{standard, Algorithm, CmmaSelector, PlaneMmaSelector};
+pub use algorithm::*;
 pub use base::{launch, launch_ref};
 pub use config::{create_stage_dim, AdvancedConfig};
diff --git a/crates/cubecl-linalg/src/matmul/kernels/tiling2d/base.rs b/crates/cubecl-linalg/src/matmul/kernels/tiling2d/base.rs
index f3a6dc7e0..569819e61 100644
--- a/crates/cubecl-linalg/src/matmul/kernels/tiling2d/base.rs
+++ b/crates/cubecl-linalg/src/matmul/kernels/tiling2d/base.rs
@@ -18,6 +18,7 @@ pub fn tiling2d_cube_kernel(
     let coordinates = calculate_coordinates(CUBE_POS_X, CUBE_POS_Y, UNIT_POS, config);
     let offsets = calculate_batch_offsets::(lhs, rhs, out, CUBE_POS_Z);
     let shared_memories = make_shared_memories::(config);
+
     block_loop::(
         lhs,
         rhs,
diff --git a/crates/cubecl-linalg/src/matmul/kernels/tiling2d/launch.rs b/crates/cubecl-linalg/src/matmul/kernels/tiling2d/launch.rs
index d16a98dc9..d3df15fcd 100644
--- a/crates/cubecl-linalg/src/matmul/kernels/tiling2d/launch.rs
+++ b/crates/cubecl-linalg/src/matmul/kernels/tiling2d/launch.rs
@@ -34,7 +34,7 @@ pub fn matmul_tiling_2d_ref(
     config: Tiling2dConfig,
 ) {
     assert!(
-        F::as_elem().size() * config.block_size_k * max(config.block_size_m, config.block_size_n)
+        F::size().unwrap() * config.block_size_k * max(config.block_size_m, config.block_size_n)
             <= ::max_shared_memory_size(),
         "Shared memory limit will be busted."
     );
diff --git a/crates/cubecl-linalg/src/matmul/kernels/tiling2d/tile/loader.rs b/crates/cubecl-linalg/src/matmul/kernels/tiling2d/tile/loader.rs
index 5562ad46d..647ec6a41 100644
--- a/crates/cubecl-linalg/src/matmul/kernels/tiling2d/tile/loader.rs
+++ b/crates/cubecl-linalg/src/matmul/kernels/tiling2d/tile/loader.rs
@@ -175,7 +175,7 @@ pub(crate) fn load_plain>(
     let mut sm = load_info.shared_memory;

     if write_row < sm_dim_vertical {
-        if line_size == tile_size {
+        if comptime![line_size == tile_size] {
             L::load_tile_plain::(
                 tensor,
                 &mut sm,
diff --git a/crates/cubecl-linalg/src/matmul/kernels/tiling2d/tile/writer.rs b/crates/cubecl-linalg/src/matmul/kernels/tiling2d/tile/writer.rs
index 25373d8d5..d27e2eea8 100644
--- a/crates/cubecl-linalg/src/matmul/kernels/tiling2d/tile/writer.rs
+++ b/crates/cubecl-linalg/src/matmul/kernels/tiling2d/tile/writer.rs
@@ -39,7 +39,7 @@ impl OutputWriter for TileWriter {
             skip_col: coordinates.skip_col,
         };

-        if line_size == tile_size {
+        if comptime![line_size == tile_size] {
             B::write_output::(
                 out,
                 results,
diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs b/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs
index 4b97d79d4..0519d967b 100644
--- a/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs
+++ b/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs
@@ -2,17 +2,22 @@ use std::fmt::Display;

 use cubecl_core::prelude::*;
 use cubecl_core::server::Handle;
+use cubecl_core::tensor_line_size_parallel;
 use cubecl_core::CubeElement;
 use cubecl_core::Feature;

 use crate::matmul::components::global::args::TensorInputsLaunch;
+use crate::matmul::components::tile::accelerated::Accelerated;
+use crate::matmul::components::tile::plane::PlaneMma;
 use crate::matmul::components::Ident;
+use crate::matmul::components::MatmulConfigFactory;
 use crate::matmul::components::MatmulLaunch;
 use crate::matmul::components::MatmulProblem;
 use crate::matmul::components::MatrixLayout;
 use crate::matmul::components::SingleMatmulSpec;
 use crate::matmul::kernels::matmul;
 use crate::matmul::kernels::matmul::Algorithm;
+use crate::matmul::kernels::matmul::StandardSelector;
 use crate::matmul::tests::test_utils::CastInto;
 use crate::tensor::TensorHandle;

@@ -27,35 +32,83 @@ struct TensorRawParts {
     original_data: Option>,
 }

-type Spec = SingleMatmulSpec<32, EG, ES, f32>;
+type Spec = SingleMatmulSpec;

 /// Test the correctness of the specified Matmul on the given device,
 /// against a naive CPU implementation over the given problem
-pub fn test_matmul_algorithm(problem: MatmulProblem, device: &R::Device)
-where
-    A: Algorithm>,
+pub fn test_matmul_algorithm(
+    client: ComputeClient,
+    mut problem: MatmulProblem,
+    input: ::Input,
+    selection: A::Selection,
+) where
+    A: Algorithm,
     EG: Float + CubeElement + Display + CastInto,
     ES: Float + CubeElement + Display + CastInto,
     R: Runtime,
 {
-    let client: ComputeClient<::Server, ::Channel> = R::client(device);
-
-    if A::check_availability::(&client).is_err() {
-        // Can't execute the test.
- println!("Skipped - not supported!"); - return; - } + let env = std::env::var("MATMUL_TEST_MODE"); + let panic_on_launch_err = match env { + Ok(val) => match val.as_str() { + "panic" => true, + "skip" => false, + _ => false, + }, + Err(_) => false, + }; let lhs = tensor_raw_parts::(&client, &problem, Ident::Lhs); let rhs = tensor_raw_parts::(&client, &problem, Ident::Rhs); let out = tensor_raw_parts::(&client, &problem, Ident::Out); - let cube_dim = A::cube_dim(); - let cube_count = A::cube_count(&problem); - let config = A::make_config(&problem, &cube_dim, &cube_count, &A::advanced_config()).unwrap(); + problem.lhs_line_size = tensor_line_size_parallel( + R::line_size_elem(&EG::as_elem_native_unchecked()), + &lhs.shape, + &lhs.strides, + lhs.strides.len() - 1, + ); + problem.rhs_line_size = tensor_line_size_parallel( + R::line_size_elem(&EG::as_elem_native_unchecked()), + &rhs.shape, + &rhs.strides, + rhs.strides.len() - 1, + ); + problem.out_line_size = tensor_line_size_parallel( + R::line_size_elem(&EG::as_elem_native_unchecked()), + &out.shape, + &out.strides, + out.strides.len() - 1, + ); + + let cube_dim = A::cube_dim(&selection); + let cube_count = A::cube_count(&selection, &problem); + let config = match A::make_config( + input, + &problem, + &cube_dim, + &cube_count, + &A::advanced_config(), + ) { + Ok(config) => config, + Err(err) => { + let msg = format!("Can't launch the test: {:?}", err); + if panic_on_launch_err { + panic!("{msg}"); + } else { + println!("{msg}"); + return; + } + } + }; + + if A::check_availability::(&client, &config).is_err() { + // Can't execute the test. + println!("Skipped - not supported!"); + return; + } unsafe { - A::BatchMatmul::launch_unchecked( + A::BatchMatmul::launch_unchecked::, R>( &client, cube_dim, cube_count, @@ -102,9 +155,9 @@ pub fn test_matmul_launch, R: R let client: ComputeClient<::Server, ::Channel> = R::client(device); if !(client.properties().feature_enabled(Feature::Plane) - && client - .properties() - .feature_enabled(Feature::Type(EG::as_elem()))) + && client.properties().feature_enabled(Feature::Type( + EG::as_elem_native().expect("To be a native type"), + ))) { // Can't execute the test. 
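The MATMUL_TEST_MODE handling above returns true only for the literal value "panic"; "skip" and any unrecognized or unset value all map to false. A behaviour-equivalent condensed form (the helper name is mine, not from the diff):

/// True only when MATMUL_TEST_MODE=panic; "skip" is the default, and any
/// other value likewise leaves failing configurations to be skipped.
fn panic_on_launch_err() -> bool {
    matches!(std::env::var("MATMUL_TEST_MODE").as_deref(), Ok("panic"))
}

The explicit "skip" arm kept in the diff is redundant with the catch-all, but it does document the second recognized mode.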
return; @@ -118,15 +171,17 @@ pub fn test_matmul_launch, R: R let rhs_handle = TensorHandle::new(rhs.shape, rhs.strides, rhs.handle); let out_handle = TensorHandle::new(out.shape, out.strides, out.handle); - let out = matmul::launch::( + let out = matmul::launch::>( &client, lhs_handle.clone(), rhs_handle.clone(), out_handle.clone(), - false, ) .unwrap_or_else(|_| { - matmul::launch::(&client, lhs_handle, rhs_handle, out_handle, true).unwrap() + matmul::launch::>( + &client, lhs_handle, rhs_handle, out_handle, + ) + .unwrap() }); assert_result::( @@ -181,7 +236,10 @@ fn tensor_raw_parts( } } Ident::Out => { - let handle = client.empty(tensor_size(problem, Ident::Out) * EG::as_elem().size()); + let handle = client.empty( + tensor_size(problem, Ident::Out) + * EG::as_elem_native().expect("To be a native type").size(), + ); let shape = shape(problem, Ident::Out); let strides = strides(problem, Ident::Out); @@ -223,9 +281,9 @@ fn assert_result< Some(epsilon) => epsilon, None => { let maybe_cmma = client.properties().feature_enabled(Feature::Cmma { - a: ES::as_elem(), - b: ES::as_elem(), - c: EG::as_elem(), + a: ES::as_elem_native().expect("To be a native type"), + b: ES::as_elem_native().expect("To be a native type"), + c: EG::as_elem_native().expect("To be a native type"), m: 16, k: 16, n: 16, diff --git a/crates/cubecl-linalg/src/matmul/tests/mod.rs b/crates/cubecl-linalg/src/matmul/tests/mod.rs index 6bc2f174e..f23a0004f 100644 --- a/crates/cubecl-linalg/src/matmul/tests/mod.rs +++ b/crates/cubecl-linalg/src/matmul/tests/mod.rs @@ -5,3 +5,5 @@ pub mod simple; mod test_macros; mod test_utils; pub mod tiling2d; + +pub use test_macros::cmma::suite::*; diff --git a/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/matmul_algorithm.rs b/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/matmul_algorithm.rs deleted file mode 100644 index 55569169e..000000000 --- a/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/matmul_algorithm.rs +++ /dev/null @@ -1,3841 +0,0 @@ -// Tests nomenclature: -// batch: b[o=one_to_one, m=one_to_many][batch dims, optional] -// global: g[fl=full_load, bp=buffer/pipelined, bs=buffer/specialized][m]x[n]x[k], with m,n,k the whole matrix dimensions -// stage: s[m]x[n]x[k], with m,n,k the number of tiles along those dims -// tile: t[m]x[n]x[k], with m,n,k the tile dimensions. tile algorithm is given by macro arguments -// layouts: [r/c][r/c], r=row, c=col, respectively for lhs and rhs -// line size: ln[v], with v the line size of all tensors. if different then ln[v_lhs]x[v_rhs]x[v_out] -// Other specifications may be appended at the end - -#[allow(missing_docs)] -#[macro_export] -macro_rules! 
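As a worked example of the nomenclature described above, one of the deleted test names decodes as follows (the decoding is mine, derived from that comment):

// bo3x4_gbs300x300x300_s4x4x2_t16x16x16_cc_ln4
//   bo3x4           one_to_one batch matmul over batch dims 3x4
//   gbs300x300x300  buffer/specialized global matmul, whole problem m = n = k = 300
//   s4x4x2          stage of 4x4x2 tiles
//   t16x16x16       16x16x16 tiles (tile algorithm supplied by the macro arguments)
//   cc              lhs col-major, rhs col-major
//   ln4             line size 4 for lhs, rhs and out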
matmul_test_define { - ( - $t_16x16x16:ident, - $t_32x8x16:ident, - $t_8x32x16:ident - ) => { - #[test] - pub fn bo4_gbp256x256x256_s4x4x2_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 256, - n: 256, - k: 256, - batches: (vec![4], vec![4]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::single_buffer::Matmul; - type GlobalMatmul = global::buffered::pipelined::Matmul; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 4, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(4, 4, 4) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - lhs_tiling_order: TilingOrderConfig::ColMajor, - rhs_tiling_order: TilingOrderConfig::RowMajor, - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo1_gbp16x16x256_s1x1x2_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 256, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::single_buffer::Matmul; - type GlobalMatmul = global::buffered::pipelined::Matmul; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - lhs_tiling_order: TilingOrderConfig::ColMajor, - rhs_tiling_order: TilingOrderConfig::RowMajor, - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo1_gbp16x16x32_s1x1x2_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::single_buffer::Matmul; - type GlobalMatmul = global::buffered::pipelined::Matmul; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - lhs_tiling_order: TilingOrderConfig::ColMajor, - rhs_tiling_order: TilingOrderConfig::RowMajor, - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo1_gbs128x256x256_s4x4x2_t16x16x16_cc_ln4_transposed_cube_count() { - let problem = MatmulProblem { - m: 128, - n: 256, - k: 256, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = 
global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 4, 1) - } - fn cube_count(problem: &MatmulProblem) -> CubeCount { - let m_stage = S4x4x2::NUM_M * 16; - let n_stage = S4x4x2::NUM_N * 16; - let cubes_needed_m = (problem.m as u32).div_ceil(m_stage); - let cubes_needed_n = (problem.n as u32).div_ceil(n_stage); - - use cubecl_linalg::matmul::components::batch::CubeCountDispatch; - batch::TransposedDispatch::cube_count(cubes_needed_m, cubes_needed_n, 1u32) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo4_4x3_gfl16x16x16_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![1, 4], vec![4, 3]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 16) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo1_gbs16x16x480_s1x1x3_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 480, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::single_buffer::Matmul; - type GlobalMatmul = global::buffered::specialized::Matmul; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - lhs_tiling_order: TilingOrderConfig::ColMajor, - rhs_tiling_order: TilingOrderConfig::RowMajor, - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo1_gfl1000x16x16_s1x1x1_t16x16x16_rr_ln4_transposed_dispatch() { - let problem = MatmulProblem { - m: 1024, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 64, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn 
bo3x4_gbs300x300x300_s4x4x2_t16x16x16_cc_ln4() { - let problem = MatmulProblem { - m: 300, - n: 300, - k: 300, - batches: (vec![3, 4], vec![3, 4]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::single_buffer::Matmul; - type GlobalMatmul = global::buffered::specialized::Matmul; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 8, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(5, 5, 12) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - lhs_tiling_order: TilingOrderConfig::ColMajor, - rhs_tiling_order: TilingOrderConfig::RowMajor, - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo1_gbs16x32x32_s1x2x2_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 32, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::single_buffer::Matmul; - type GlobalMatmul = global::buffered::specialized::Matmul; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - lhs_tiling_order: TilingOrderConfig::ColMajor, - rhs_tiling_order: TilingOrderConfig::RowMajor, - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo1_gbs32x16x32_s2x1x2_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 32, - n: 16, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::single_buffer::Matmul; - type GlobalMatmul = global::buffered::specialized::Matmul; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 3, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - lhs_tiling_order: TilingOrderConfig::ColMajor, - rhs_tiling_order: TilingOrderConfig::RowMajor, - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo1_gbs16x16x128_s1x1x2_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 128, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::single_buffer::Matmul; - type GlobalMatmul = global::buffered::specialized::Matmul; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 
2, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - lhs_tiling_order: TilingOrderConfig::ColMajor, - rhs_tiling_order: TilingOrderConfig::RowMajor, - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo1_gbs16x16x32_s1x1x2_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::single_buffer::Matmul; - type GlobalMatmul = global::buffered::specialized::Matmul; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - lhs_tiling_order: TilingOrderConfig::ColMajor, - rhs_tiling_order: TilingOrderConfig::RowMajor, - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bm1_gfl16x16x16_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = batch::one_to_many::Matmul< - Spec, - Self::GlobalMatmul, - batch::RowMajorSpanMatmul, - batch::NaturalDispatch, - >; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bm1_gfl32x16x16_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 32, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = batch::one_to_many::Matmul< - Spec, - Self::GlobalMatmul, - batch::RowMajorSpanMatmul, - batch::NaturalDispatch, - >; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bm1_gfl16x32x16_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 32, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - 
rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = batch::one_to_many::Matmul< - Spec, - Self::GlobalMatmul, - batch::RowMajorSpanMatmul, - batch::NaturalDispatch, - >; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bm1_gfl16x16x32_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = batch::one_to_many::Matmul< - Spec, - Self::GlobalMatmul, - batch::RowMajorSpanMatmul, - batch::NaturalDispatch, - >; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bm6_gfl16x16x16_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![2, 3], vec![2, 3]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = batch::one_to_many::Matmul< - Spec, - Self::GlobalMatmul, - batch::RowMajorSpanMatmul, - batch::NaturalDispatch, - >; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bm2_gfl32x32x32_s1x1x1_t16x16x16_rr_ln4_colspan() { - let problem = MatmulProblem { - m: 32, - n: 32, - k: 16, - batches: (vec![2], vec![2]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = batch::one_to_many::Matmul< - Spec, - Self::GlobalMatmul, - batch::ColMajorSpanMatmul, - batch::NaturalDispatch, - >; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - 
CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bm2_gfl32x32x32_s1x1x1_t16x16x16_rr_ln4_swizzlespan() { - let problem = MatmulProblem { - m: 32, - n: 32, - k: 16, - batches: (vec![2], vec![2]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = batch::one_to_many::Matmul< - Spec, - Self::GlobalMatmul, - batch::SwizzleSpanMatmul<2>, - batch::NaturalDispatch, - >; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(2, 2, 2) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bm2_gfl32x32x32_s1x1x1_t16x16x16_rr_ln4_transposed_dispatch() { - let problem = MatmulProblem { - m: 32, - n: 32, - k: 16, - batches: (vec![2], vec![2]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = batch::one_to_many::Matmul< - Spec, - Self::GlobalMatmul, - batch::SwizzleSpanMatmul<2>, - batch::TransposedDispatch, - >; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(2, 2, 2) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bm2_gfl160x256x16_s1x1x1_t16x16x16_rr_ln4_swizzle_x_dispatch() { - let problem = MatmulProblem { - m: 160, - n: 256, - k: 16, - batches: (vec![2], vec![2]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = batch::one_to_many::Matmul< - Spec, - Self::GlobalMatmul, - batch::SwizzleSpanMatmul<2>, - batch::SwizzleNaturalDispatch<2>, - >; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(10, 16, 2) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bm2_gfl160x256x16_s1x1x1_t16x16x16_rr_ln4_swizzle_y_dispatch() { - let problem = MatmulProblem { - m: 160, - n: 256, - k: 16, - batches: (vec![2], vec![2]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - 
stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = batch::one_to_many::Matmul< - Spec, - Self::GlobalMatmul, - batch::SwizzleSpanMatmul<2>, - batch::SwizzleTransposedDispatch<2>, - >; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(16, 10, 2) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bm5_gfl16x16x16_s1x1x1_t16x16x16_rr_ln4_cubez2() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![5], vec![5]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = batch::one_to_many::Matmul< - Spec, - Self::GlobalMatmul, - batch::RowMajorSpanMatmul, - batch::NaturalDispatch, - >; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 2) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo3x4_gfl300x300x300_s4x4x2_t16x16x16_cc_ln4() { - let problem = MatmulProblem { - m: 300, - n: 300, - k: 300, - batches: (vec![3, 4], vec![3, 4]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - struct Test {} - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 4, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(5, 5, 12) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo3x4_gfl108x108x243_s4x4x2_t16x16x16_cr_ln4() { - let problem = MatmulProblem { - m: 108, - n: 108, - k: 243, - batches: (vec![3, 4], vec![3, 4]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 4, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(2, 2, 12) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo3x4_gfl256x256x256_s4x4x2_t16x16x16_cr_ln2x2x4() { - let problem = MatmulProblem { - m: 256, - n: 256, - k: 256, - batches: (vec![3, 4], vec![3, 4]), - lhs_layout: 
MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 2, - rhs_line_size: 2, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 4, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(4, 4, 12) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo3_gfl256x256x256_s4x4x2_t16x16x16_rc_ln4() { - let problem = MatmulProblem { - m: 256, - n: 256, - k: 256, - batches: (vec![3], vec![3]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 4, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(4, 4, 3) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo3_gfl16x16x16_s1x1x1_t16x16x16_cc_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![3], vec![3]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 3) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo3_gfl16x16x16_s1x1x1_t16x16x16_R() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![3], vec![3]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 3) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl256x256x256_s4x4x2_t16x16x16_rc_ln4_col_major() { - let problem = MatmulProblem { - m: 256, - n: 256, - k: 256, - batches: (vec![], vec![]), - lhs_layout: 
MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 4, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(4, 4, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - lhs_tiling_order: TilingOrderConfig::ColMajor, - rhs_tiling_order: TilingOrderConfig::ColMajor, - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x32x32_s1x1x1_t16x16x16_cc_ln4_col_major() { - let problem = MatmulProblem { - m: 32, - n: 32, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(2, 2, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - lhs_tiling_order: TilingOrderConfig::ColMajor, - rhs_tiling_order: TilingOrderConfig::ColMajor, - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x32x32_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 32, - n: 32, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(2, 2, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x14x16_s1x1x1_t16x16x16_rc_ln4x4x2() { - let problem = MatmulProblem { - m: 16, - n: 14, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 2, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - 
CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x12x16_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 12, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x12_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 12, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl60x60x120_s4x4x2_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 60, - n: 60, - k: 120, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 4, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x36_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 36, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 
1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl12x12x16_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 12, - n: 12, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x16_s1x1x1_t16x16x16_rr_ln1() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 1, - rhs_line_size: 1, - out_line_size: 1, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x16_s1x1x1_t16x16x16_rc_ln1() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 1, - rhs_line_size: 1, - out_line_size: 1, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x16_s1x1x1_t16x16x16_cc_ln1() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 1, - rhs_line_size: 1, - out_line_size: 1, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: 
&MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x16_s1x1x1_t16x16x16_cr_ln1() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 1, - rhs_line_size: 1, - out_line_size: 1, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x16_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x16_s1x1x1_t16x16x16_rc_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x16_s1x1x1_t16x16x16_rr_ln4_lhs_col_enforced() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) 
-> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - enforced_tile_layout: (Some(MatrixLayout::ColMajor), None), - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x16_s1x1x1_t16x16x16_rr_ln4_rhs_col_enforced() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - enforced_tile_layout: (None, Some(MatrixLayout::ColMajor)), - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x16_s1x1x1_t16x16x16_rc_ln4_rhs_row_enforced() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - enforced_tile_layout: (None, Some(MatrixLayout::RowMajor)), - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x8x16_s1x1x1_t32x8x16_rr_ln4_lhs_col_enforced() { - let problem = MatmulProblem { - m: 32, - n: 8, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_32x8x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - enforced_tile_layout: (Some(MatrixLayout::ColMajor), None), - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x8x16_s1x1x1_t32x8x16_cr_ln4_lhs_row_enforced() { - let problem = MatmulProblem { - 
m: 32, - n: 8, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_32x8x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - enforced_tile_layout: (Some(MatrixLayout::RowMajor), None), - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x8x16_s1x1x1_t32x8x16_rr_ln4_rhs_col_enforced() { - let problem = MatmulProblem { - m: 32, - n: 8, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_32x8x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - enforced_tile_layout: (None, Some(MatrixLayout::ColMajor)), - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x8x16_s1x1x1_t32x8x16_rc_ln4_rhs_row_enforced() { - let problem = MatmulProblem { - m: 32, - n: 8, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_32x8x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - enforced_tile_layout: (None, Some(MatrixLayout::RowMajor)), - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl8x32x16_s1x1x1_t8x32x16_rr_ln4_lhs_col_enforced() { - let problem = MatmulProblem { - m: 8, - n: 32, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_8x32x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - 
Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - enforced_tile_layout: (Some(MatrixLayout::ColMajor), None), - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl8x32x16_s1x1x1_t8x32x16_cr_ln4_lhs_row_enforced() { - let problem = MatmulProblem { - m: 8, - n: 32, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_8x32x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - enforced_tile_layout: (Some(MatrixLayout::RowMajor), None), - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl8x32x16_s1x1x1_t8x32x16_rr_ln4_rhs_col_enforced() { - let problem = MatmulProblem { - m: 8, - n: 32, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_8x32x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - enforced_tile_layout: (None, Some(MatrixLayout::ColMajor)), - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl8x32x16_s1x1x1_t8x32x16_rc_ln4_rhs_row_enforced() { - let problem = MatmulProblem { - m: 8, - n: 32, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_8x32x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - enforced_tile_layout: (None, 
Some(MatrixLayout::RowMajor)), - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo3x4_gfl256x256x256_s4x4x2_t16x16x16_cr_ln2x2x4_lhs_row_rhs_col_enforced() { - let problem = MatmulProblem { - m: 256, - n: 256, - k: 256, - batches: (vec![3, 4], vec![3, 4]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 2, - rhs_line_size: 2, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 4, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(4, 4, 12) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - enforced_tile_layout: ( - Some(MatrixLayout::RowMajor), - Some(MatrixLayout::ColMajor), - ), - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x16_s1x1x1_t16x16x16_cr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x8x16_s1x1x1_t32x8x16_rr_ln4() { - let problem = MatmulProblem { - m: 32, - n: 8, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_32x8x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x8x16_s1x1x1_t32x8x16_rc_ln1() { - let problem = MatmulProblem { - m: 32, - n: 8, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 1, - rhs_line_size: 1, - out_line_size: 1, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_32x8x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - 
global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x8x16_s1x1x1_t32x8x16_cr_ln4() { - let problem = MatmulProblem { - m: 32, - n: 8, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_32x8x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x8x16_s1x1x1_t32x8x16_cc_ln4() { - let problem = MatmulProblem { - m: 32, - n: 8, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_32x8x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl8x32x16_s1x1x1_t8x32x16_rr_ln4() { - let problem = MatmulProblem { - m: 8, - n: 32, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_8x32x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl8x32x16_s1x1x1_t8x32x16_rc_ln4() { - let problem = MatmulProblem { - m: 8, - n: 32, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_8x32x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - 
global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl8x32x16_s1x1x1_t8x32x16_cr_ln4() { - let problem = MatmulProblem { - m: 8, - n: 32, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_8x32x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl8x32x16_s1x1x1_t8x32x16_cc_ln4() { - let problem = MatmulProblem { - m: 8, - n: 32, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_8x32x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x16_s1x1x1_t16x16x16_rr_ln2() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 2, - rhs_line_size: 2, - out_line_size: 2, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x32x32_s2x2x2_t16x16x16_rr_ln4_col_major() { - let problem = MatmulProblem { - m: 32, - n: 32, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - 
global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - - fn advanced_config() -> AdvancedConfig { - AdvancedConfig { - lhs_tiling_order: TilingOrderConfig::ColMajor, - rhs_tiling_order: TilingOrderConfig::ColMajor, - ..Default::default() - } - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x32_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x16_s1x1x1_t16x16x16_cc_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x16x128_s1x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 16, - k: 128, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x16x128_s2x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 32, - n: 16, - k: 128, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - 
type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x32x224_s2x2x2_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 32, - n: 32, - k: 224, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl16x32x16_s1x2x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 16, - n: 32, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x32x32_s2x2x2_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 32, - n: 32, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x32x16_s2x2x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 32, - n: 32, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - 
stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x32x32_s2x2x2_t16x16x16_rc_ln4() { - let problem = MatmulProblem { - m: 32, - n: 32, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x32x32_s2x2x2_t16x16x16_cr_ln4() { - let problem = MatmulProblem { - m: 32, - n: 32, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x32x32_s2x2x2_t16x16x16_cc_ln4() { - let problem = MatmulProblem { - m: 32, - n: 32, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x16x16_s2x1x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 32, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - 
type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x8x16_s1x1x1_t32x8x16_cc_ln1() { - let problem = MatmulProblem { - m: 32, - n: 8, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::ColMajor, - lhs_line_size: 1, - rhs_line_size: 1, - out_line_size: 1, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_32x8x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 1, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl128x16x16_s8x1x1_t16x16x16_rr_ln1() { - let problem = MatmulProblem { - m: 128, - n: 16, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 1, - rhs_line_size: 1, - out_line_size: 1, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 8, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl64x64x16_s4x4x1_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 64, - n: 64, - k: 16, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 4, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl64x64x32_s4x4x2_t16x16x16_rr_ln4() { - let problem = MatmulProblem { - m: 64, - n: 64, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = 
global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 4, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x16x32_s2x1x2_t16x16x16_rr_ln4_tilewise_load_lhs() { - let problem = MatmulProblem { - m: 32, - n: 16, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::TilewiseLoading, - global::full_load::CyclicLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl32x16x32_s2x1x2_t16x16x16_rr_ln4_tilewise_load_rhs() { - let problem = MatmulProblem { - m: 32, - n: 16, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::TilewiseLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 2, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - - #[test] - pub fn bo_gfl64x64x32_s4x4x1_t16x16x16_rr_ln4_tilewise_load_both() { - let problem = MatmulProblem { - m: 64, - n: 64, - k: 32, - batches: (vec![], vec![]), - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - struct Test {} - - impl matmul::Algorithm for Test { - type TileMatmul = $t_16x16x16; - type StageMatmul = - stage::multi_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Spec, - Self::StageMatmul, - global::full_load::TilewiseLoading, - global::full_load::TilewiseLoading, - >; - type BatchMatmul = - batch::one_to_one::Matmul; - - fn cube_dim() -> CubeDim { - CubeDim::new(Spec::PLANE_DIM, 4, 1) - } - fn cube_count(_problem: &MatmulProblem) -> CubeCount { - CubeCount::Static(1, 1, 1) - } - } - - test_matmul_algorithm::( - problem, - &<::Device>::default(), - ); - } - }; -} diff --git a/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/matmul_launch.rs b/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/matmul_launch.rs deleted file mode 100644 index 21cb7dc09..000000000 --- a/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/matmul_launch.rs +++ /dev/null @@ -1,25 +0,0 @@ -#[allow(missing_docs)] -#[macro_export] -macro_rules! 
testgen_matmul_launch { - ($eg:ty) => { - use cubecl_linalg::matmul::tests::cmma_matmul::matmul_test_launcher::test_matmul_launch; - - #[test] - pub fn test_launch_matmul() { - type EG = $eg; - let problem = MatmulProblem { - m: 300, - n: 200, - k: 250, - batches: (vec![3, 4], vec![3, 4]), - lhs_layout: MatrixLayout::ColMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size: 4, - rhs_line_size: 4, - out_line_size: 4, - }; - - test_matmul_launch::(problem, &Default::default()); - } - }; -} diff --git a/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/mod.rs b/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/mod.rs index 573403dc6..68a7c1e33 100644 --- a/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/mod.rs +++ b/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/mod.rs @@ -1,123 +1,52 @@ #![allow(missing_docs)] -pub mod matmul_algorithm; -pub mod matmul_launch; +pub mod suite; #[macro_export] -macro_rules! testgen_matmul_cmma { - () => { - #[allow(non_snake_case)] - mod cmma_matmul { - $crate::testgen_matmul_cmma!(f32); - } - }; - - ($float:ident) => { - use super::*; - use cubecl_linalg::matmul::tests; - use cubecl_linalg::matmul::components::tile::accelerated::*; - use cubecl_core::prelude::*; - use cubecl_linalg::matmul::{ - components::{ - SingleMatmulSpec, - MatmulSpec, - batch, global, - stage::{self, *}, - tile::plane::PlaneMma16x16x16, - MatmulProblem, MatrixLayout, - }, - kernels::matmul::{self, AdvancedConfig}, - tests::cmma_matmul::matmul_test_launcher::test_matmul_algorithm, - }; - use cubecl_core::prelude::*; - - pub type FloatT = $float; - pub type Spec = SingleMatmulSpec<32, FloatT, half::f16, f32>; - pub type EG = FloatT; - pub type ES = half::f16; - pub type EA = f32; +macro_rules! testgen_matmul_accelerated { + ($eg:ty, $es:ty) => { + type EG = $eg; + type ES = $es; - cubecl_linalg::matmul_test_define!( - Accelerated16x16x16, - Accelerated32x8x16, - Accelerated8x32x16 - ); - - cubecl_linalg::testgen_matmul_launch!( - FloatT - ); + $crate::matmul_standard_tests!(); }; ([$($float:ident),*]) => { #[allow(non_snake_case)] - mod matmul_cmma { + mod matmul_accelerated { use super::*; + type TMM = $crate::matmul::components::tile::accelerated::Accelerated; + ::paste::paste! { $(mod [<$float _ty>] { use super::*; - - $crate::testgen_matmul_cmma!($float); + $crate::testgen_matmul_accelerated!($float, half::f16); })* } } }; } - #[macro_export] -macro_rules! testgen_matmul_plane_mma { - () => { - #[allow(non_snake_case)] - mod matmul_plane_mma{ - $crate::testgen_matmul_plane_mma!(f32, f32); - } +macro_rules! 
testgen_matmul_plane {
+    ($float:ident) => {
+        type EG = $float;
+        type ES = $float;
+
+        $crate::matmul_standard_tests!();
+    };
+
-    ($float:ident, $float_stage:ident) => {
+    ([$($float:ident),*]) => {
+        #[allow(non_snake_case)]
+        mod matmul_plane {
             use super::*;
-            use cubecl_linalg::matmul::tests;
-            use cubecl_linalg::matmul::components::tile::plane::*;
-            use cubecl_linalg::matmul::{
-                components::{
-                    SingleMatmulSpec,
-                    MatmulSpec,
-                    batch, global,
-                    stage::{self, *},
-                    tile::plane::PlaneMma16x16x16,
-                    MatmulProblem, MatrixLayout,
-                },
-                kernels::matmul::{self, AdvancedConfig},
-                tests::cmma_matmul::matmul_test_launcher::test_matmul_algorithm,
-            };
-            use cubecl_core::prelude::*;
-
-            pub type FloatGlobal = $float;
-            pub type FloatStage = $float_stage;
-            pub type Spec = SingleMatmulSpec<32, FloatGlobal, FloatStage, f32>;
-            pub type EG = FloatGlobal;
-            pub type ES = FloatStage;
-            pub type EA = f32;
-
-            cubecl_linalg::matmul_test_define!(
-                PlaneMma16x16x16,
-                PlaneMma32x8x16,
-                PlaneMma8x32x16
-            );
+            type TMM = $crate::matmul::components::tile::plane::PlaneMma;
-
-            cubecl_linalg::testgen_matmul_launch!(
-                FloatGlobal
-            );
-    };
-
-    ([$($float:ident),*], $float_stage:ident) => {
-        ::paste::paste! {
-            $(
-                // Generate a unique module for each `float` type with the single `float_stage`
-                #[allow(non_snake_case)]
-                mod [] {
+            ::paste::paste! {
+                $(mod [<$float _ty>] {
                     use super::*;
-                    $crate::testgen_matmul_plane_mma!($float, $float_stage);
-                }
-            )*
+                    $crate::testgen_matmul_accelerated!($float, $float);
+                })*
+            }
+        }
     };
 }
diff --git a/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/suite.rs b/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/suite.rs
new file mode 100644
index 000000000..03983801a
--- /dev/null
+++ b/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/suite.rs
@@ -0,0 +1,204 @@
+use std::fmt::Display;
+
+use crate::matmul::components::stage::CommonStageInput;
+use crate::matmul::components::tile::TileMatmulFamily;
+use crate::matmul::components::{MatmulProblem, MatrixLayout};
+use crate::matmul::components::{MatmulSelection, MatmulSize};
+use crate::matmul::kernels::matmul::Algorithm;
+use crate::matmul::tests::cmma_matmul::matmul_test_launcher::test_matmul_algorithm;
+use crate::matmul::tests::test_utils::CastInto;
+
+use cubecl_core::prelude::Float;
+use cubecl_core::{CubeElement, Runtime};
+
+pub fn test_algo<A: Algorithm, P: TestPrecision, R: Runtime>(
+    layouts: (MatrixLayout, MatrixLayout),
+    tile: MatmulSize,
+    stage: MatmulSize,
+    problem: MatmulSize,
+) {
+    let client = R::client(&Default::default());
+    let plane_dim = match client
+        .properties()
+        .hardware_properties()
+        .defined_plane_size()
+    {
+        Some(val) => val,
+        None => {
+            println!("Can't run test without a fixed plane size.");
+            return;
+        }
+    };
+
+    let problem = MatmulProblem {
+        m: problem.m as usize,
+        n: problem.n as usize,
+        k: problem.k as usize,
+        batches: (vec![2], vec![2]),
+        lhs_layout: layouts.0,
+        rhs_layout: layouts.1,
+        lhs_line_size: 1, // Will be changed
+        rhs_line_size: 1, // Will be changed
+        out_line_size: 1, // Will be changed
+    };
+
+    let selection = MatmulSelection {
+        tile,
+        num_stagess: stage,
+        plane_dim,
+    };
+    let config_input = CommonStageInput {
+        tile: A::TileMatmul::input(selection.tile),
+        num_stages: selection.num_stagess,
+    };
+
+    test_matmul_algorithm::(client, problem, config_input, selection);
+}
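Before the macro layers below, it may help to see this new entry point driven directly. A hypothetical stand-alone test follows; the `StandardAlgorithm<TMM>` turbofish, the `Accelerated` path, and the `TestRuntime` alias mirror the macro arms defined next, but this exact snippet is an assumption, not a line from the diff. The `// Will be changed` comments above indicate the three line sizes are placeholders that the test launcher overwrites before dispatch.

    use cubecl_linalg::matmul::components::{MatmulSize, MatrixLayout};
    use cubecl_linalg::matmul::kernels::matmul::standard::StandardAlgorithm;
    use cubecl_linalg::matmul::tests::test_algo;

    // Assumed tile family, per the paths used in the macros below.
    type TMM = cubecl_linalg::matmul::components::tile::accelerated::Accelerated;

    #[test]
    fn standard_16x16x16() {
        // Tile, stage, and problem sizes mirror one leaf of the generated suite.
        test_algo::<StandardAlgorithm<TMM>, (f32, half::f16), TestRuntime>(
            (MatrixLayout::RowMajor, MatrixLayout::ColMajor),
            MatmulSize { m: 16, n: 16, k: 16 }, // tile shape
            MatmulSize { m: 2, n: 2, k: 2 },    // stage shape, in tiles
            MatmulSize { m: 32, n: 32, k: 32 }, // total problem shape
        );
    }

+
+#[allow(missing_docs)]
+#[macro_export]
+macro_rules! 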
matmul_standard_tests {
+    () => {
+        use $crate::matmul::components::{MatmulSize, MatrixLayout};
+
+        $crate::matmul_standard_tests!([RowMajor, ColMajor], [RowMajor, ColMajor]);
+    };
+
+    ([$($lhs_layout:ident),*], [$($rhs_layout:ident),*]) => {
+        $(
+            mod $lhs_layout {
+                use super::*;
+                mod $rhs_layout {
+                    use super::*;
+                    $crate::matmul_standard_tests!($lhs_layout, $rhs_layout);
+                }
+            }
+        )*
+    };
+
+    ($lhs_layout:ident, $rhs_layout:ident) => {
+        mod t16x16x16 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, MatmulSize { m: 16, n: 16, k: 16 });
+        }
+
+        mod t32x8x16 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, MatmulSize { m: 32, n: 8, k: 16 });
+        }
+
+        mod t8x32x16 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, MatmulSize { m: 8, n: 32, k: 16 });
+        }
+
+        mod t16x16x8 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, MatmulSize { m: 16, n: 16, k: 8 });
+        }
+    };
+
+    ($lhs_layout:ident, $rhs_layout:ident, $tile:expr) => {
+        mod s1x1x1 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, $tile, MatmulSize { m: 1, n: 1, k: 1 });
+        }
+
+        mod s8x8x1 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, $tile, MatmulSize { m: 8, n: 8, k: 1 });
+        }
+
+        mod s2x2x2 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, $tile, MatmulSize { m: 2, n: 2, k: 2 });
+        }
+
+        mod s4x4x2 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, $tile, MatmulSize { m: 4, n: 4, k: 2 });
+        }
+    };
+
+    ($lhs_layout:ident, $rhs_layout:ident, $tile:expr, $stage:expr) => {
+        mod p32x32x32 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, $tile, $stage, MatmulSize { m: 32, n: 32, k: 32 });
+        }
+
+        mod p64x32x32 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, $tile, $stage, MatmulSize { m: 64, n: 32, k: 32 });
+        }
+
+        mod p32x32x64 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, $tile, $stage, MatmulSize { m: 32, n: 32, k: 64 });
+        }
+
+        mod p100x100x100 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, $tile, $stage, MatmulSize { m: 100, n: 100, k: 100 });
+        }
+
+        mod p23x1x17 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, $tile, $stage, MatmulSize { m: 23, n: 1, k: 17 });
+        }
+
+        mod p256x256x256 {
+            use super::*;
+            $crate::matmul_standard_tests!($lhs_layout, $rhs_layout, $tile, $stage, MatmulSize { m: 256, n: 256, k: 256 });
+        }
+    };
+
+    ($lhs_layout:ident, $rhs_layout:ident, $tile:expr, $stage:expr, $problem:expr) => {
+        use $crate::matmul::kernels::matmul::standard::StandardAlgorithm;
+        use $crate::matmul::kernels::matmul::specialized::SpecializedAlgorithm;
+        use $crate::matmul::kernels::matmul::pipelined::PipelinedAlgorithm;
+
+        #[test]
+        pub fn standard() {
+            cubecl_linalg::matmul::tests::test_algo::<StandardAlgorithm<TMM>, (EG, ES), TestRuntime>(
+                (MatrixLayout::$lhs_layout, MatrixLayout::$rhs_layout),
+                $tile,
+                $stage,
+                $problem,
+            );
+        }
+
+        #[test]
+        pub fn specialized() {
+            cubecl_linalg::matmul::tests::test_algo::<SpecializedAlgorithm<TMM>, (EG, ES), TestRuntime>(
+                (MatrixLayout::$lhs_layout, MatrixLayout::$rhs_layout),
+                $tile,
+                $stage,
+                $problem,
+            );
+        }
+
+        #[test]
+        pub fn pipelined() {
+            cubecl_linalg::matmul::tests::test_algo::<PipelinedAlgorithm<TMM>, (EG, ES), TestRuntime>(
+                (MatrixLayout::$lhs_layout, MatrixLayout::$rhs_layout),
+                $tile,
+                $stage,
+                $problem,
+            );
+        }
+    };
+}
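Taken together, the arms above fan out combinatorially: layouts, then tile shapes, then stage shapes, then problem shapes, then the three algorithm families, so a single `matmul_standard_tests!()` invocation yields several hundred `#[test]` functions per element type, replacing the hand-written cases deleted earlier. A sketch of one generated leaf (illustrative expansion, not emitted verbatim; module names follow the arms):

    mod RowMajor {           // lhs layout
        mod RowMajor {       // rhs layout
            mod t16x16x16 {  // tile shape
                mod s2x2x2 { // stage shape
                    mod p32x32x32 { // problem shape
                        #[test]
                        pub fn standard() {
                            cubecl_linalg::matmul::tests::test_algo::<
                                StandardAlgorithm<TMM>,
                                (EG, ES),
                                TestRuntime,
                            >(
                                (MatrixLayout::RowMajor, MatrixLayout::RowMajor),
                                MatmulSize { m: 16, n: 16, k: 16 },
                                MatmulSize { m: 2, n: 2, k: 2 },
                                MatmulSize { m: 32, n: 32, k: 32 },
                            );
                        }
                        // `specialized` and `pipelined` are generated alongside.
                    }
                }
            }
        }
    }

+
+pub trait TestPrecision {
+    type EG: Float + 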
CubeElement + Display + CastInto; + type ES: Float + CubeElement + Display + CastInto; +} + +impl TestPrecision for (EG, ES) +where + EG: Float + CubeElement + Display + CastInto, + ES: Float + CubeElement + Display + CastInto, +{ + type EG = EG; + type ES = ES; +} diff --git a/crates/cubecl-linalg/src/matmul/tests/test_macros/mod.rs b/crates/cubecl-linalg/src/matmul/tests/test_macros/mod.rs index 00e8457da..79ef80a21 100644 --- a/crates/cubecl-linalg/src/matmul/tests/test_macros/mod.rs +++ b/crates/cubecl-linalg/src/matmul/tests/test_macros/mod.rs @@ -1,3 +1,3 @@ -mod cmma; +pub mod cmma; mod simple; mod tiling2d; diff --git a/crates/cubecl-linalg/src/tensor/base.rs b/crates/cubecl-linalg/src/tensor/base.rs index d970e1ee0..c50045633 100644 --- a/crates/cubecl-linalg/src/tensor/base.rs +++ b/crates/cubecl-linalg/src/tensor/base.rs @@ -92,7 +92,7 @@ where handle: &self.handle, strides: &self.strides, shape: &self.shape, - elem_size: E::as_elem().size(), + elem_size: E::size().expect("Should be a native type"), runtime: PhantomData, } } @@ -130,7 +130,7 @@ where { pub fn empty(client: &ComputeClient, shape: Vec) -> Self { let num_elements: usize = shape.iter().product(); - let size = E::as_elem().size(); + let size = E::size().expect("To be a native type"); let handle = client.empty(size * num_elements); let strides = Self::contiguous_strides(&shape); diff --git a/crates/cubecl-linalg/src/tensor/contiguous.rs b/crates/cubecl-linalg/src/tensor/contiguous.rs index 956355c70..264799799 100644 --- a/crates/cubecl-linalg/src/tensor/contiguous.rs +++ b/crates/cubecl-linalg/src/tensor/contiguous.rs @@ -105,7 +105,7 @@ pub fn into_contiguous_prefetch( let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim); - let handle = client.empty(num_elems * E::as_elem().size()); + let handle = client.empty(num_elems * E::size().expect("To be a native type")); let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle); into_contiguous_kernel::launch::, R>( diff --git a/crates/cubecl-linalg/src/tensor/virtual.rs b/crates/cubecl-linalg/src/tensor/virtual.rs index 4f29350ba..11e904ccf 100644 --- a/crates/cubecl-linalg/src/tensor/virtual.rs +++ b/crates/cubecl-linalg/src/tensor/virtual.rs @@ -1,4 +1,4 @@ -use cubecl::prelude::*; +use cubecl::prelude::{CubeContext, CubeType, *}; use cubecl_core::{self as cubecl, unexpanded}; use std::{marker::PhantomData, sync::Arc}; @@ -13,10 +13,13 @@ pub struct ReadWrite; /// Tensor representation that is decoupled from how the tensor is stored. #[derive(Clone)] pub struct VirtualTensor { - state: Arc>, + // state: Arc>, + _e: PhantomData, _p: PhantomData, } +impl Copy for VirtualTensor {} + /// Expand type for [VirtualTensor]. #[derive(Clone)] pub struct VirtualTensorExpand { @@ -24,37 +27,161 @@ pub struct VirtualTensorExpand { _p: PhantomData, } -#[cube] +#[allow(unused, clippy::all)] impl VirtualTensor { /// Read the tensor at the given index. pub fn read(&self, index: u32) -> Line { - self.state.read(index) + unexpanded!(); } - /// Get the shape of the tensor at the given axis. pub fn shape(&self, axis: u32) -> u32 { - self.state.shape(axis) + unexpanded!(); } - /// Get the stride of the tensor at the given axis. pub fn stride(&self, axis: u32) -> u32 { - self.state.stride(axis) + unexpanded!(); } - /// Get the rank of the tensor. 
pub fn rank(&self) -> u32 { - self.state.rank() + unexpanded!(); + } + pub fn __expand_read( + context: &mut CubeContext, + this: ::ExpandType, + index: ::ExpandType, + ) -> as CubeType>::ExpandType { + this.__expand_read_method(context, index) + } + pub fn __expand_shape( + context: &mut CubeContext, + this: ::ExpandType, + axis: ::ExpandType, + ) -> ::ExpandType { + this.__expand_shape_method(context, axis) + } + pub fn __expand_stride( + context: &mut CubeContext, + this: ::ExpandType, + axis: ::ExpandType, + ) -> ::ExpandType { + this.__expand_stride_method(context, axis) + } + pub fn __expand_rank( + context: &mut CubeContext, + this: ::ExpandType, + ) -> ::ExpandType { + this.__expand_rank_method(context) + } +} + +#[allow(unused, clippy::all)] +impl VirtualTensorExpand { + pub fn __expand_read_method( + self, + context: &mut CubeContext, + index: ::ExpandType, + ) -> as CubeType>::ExpandType { + let _arg_0 = index; + self.state + .clone() + .__expand_read_method(context, _arg_0.into()) + } + + pub fn __expand_shape_method( + self, + context: &mut CubeContext, + axis: ::ExpandType, + ) -> ::ExpandType { + let _arg_0 = axis; + self.state + .clone() + .__expand_shape_method(context, _arg_0.into()) + } + + pub fn __expand_stride_method( + self, + context: &mut CubeContext, + axis: ::ExpandType, + ) -> ::ExpandType { + let _arg_0 = axis; + self.state + .clone() + .__expand_stride_method(context, _arg_0.into()) + } + + pub fn __expand_rank_method(self, context: &mut CubeContext) -> ::ExpandType { + self.state.clone().__expand_rank_method(context) + } + + pub fn __expand_read( + context: &mut CubeContext, + this: Self, + index: ::ExpandType, + ) -> as CubeType>::ExpandType { + VirtualTensor::::__expand_read(context, this, index) + } + + pub fn __expand_shape( + context: &mut CubeContext, + this: Self, + axis: ::ExpandType, + ) -> ::ExpandType { + VirtualTensor::::__expand_shape(context, this, axis) + } + + pub fn __expand_stride( + context: &mut CubeContext, + this: Self, + axis: ::ExpandType, + ) -> ::ExpandType { + VirtualTensor::::__expand_stride(context, this, axis) + } + + pub fn __expand_rank(context: &mut CubeContext, this: Self) -> ::ExpandType { + VirtualTensor::::__expand_rank(context, this) } } -#[cube] +#[allow(unused, clippy::all)] impl VirtualTensor { - /// Write the tensor at the given index. + #[doc = " Write the tensor at the given index."] pub fn write(&mut self, index: u32, value: Line) { - self.state.write(index, value) + unexpanded!() + } + + pub fn __expand_write( + context: &mut CubeContext, + this: ::ExpandType, + index: ::ExpandType, + value: as CubeType>::ExpandType, + ) -> <() as CubeType>::ExpandType { + this.__expand_write_method(context, index, value) } } +impl VirtualTensorExpand { + pub fn __expand_write_method( + self, + context: &mut CubeContext, + index: ::ExpandType, + value: as CubeType>::ExpandType, + ) -> <() as CubeType>::ExpandType { + let _arg_0 = index; + let _arg_1 = value; + self.state + .clone() + .__expand_write_method(context, _arg_0, _arg_1) + } + + pub fn __expand_write( + context: &mut CubeContext, + this: Self, + index: ::ExpandType, + value: as CubeType>::ExpandType, + ) -> <() as CubeType>::ExpandType { + VirtualTensor::::__expand_write(context, this, index, value) + } +} impl VirtualTensor { /// Create a new [read only](Read) [virtual tensor](VirtualTensor). 
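The hand-written expansion above follows one fixed convention, worth seeing in isolation. A minimal sketch for a single invented method (`len` is hypothetical, and the `<E: Numeric, IO: Clone>` bounds are an assumption where the diff's generics were stripped): the user-facing method is a host-side stub that panics through `unexpanded!()`, while the paired `__expand_*` free function forwards to an `__expand_*_method` on the expand type, which replays the call on the inner `state` against the `CubeContext`.

    impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
        /// Host-side stub: kernels only ever execute the expanded form.
        pub fn len(&self) -> u32 {
            unexpanded!()
        }

        /// Associated-function form, as the #[cube] macro would have emitted it.
        pub fn __expand_len(
            context: &mut CubeContext,
            this: <Self as CubeType>::ExpandType,
        ) -> <u32 as CubeType>::ExpandType {
            this.__expand_len_method(context)
        }
    }
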
pub fn new + 'static>(_v: &V) -> Self { diff --git a/crates/cubecl-macros/src/generate/cube_type/generate_struct.rs b/crates/cubecl-macros/src/generate/cube_type/generate_struct.rs index 7c45c2cdf..794a7e854 100644 --- a/crates/cubecl-macros/src/generate/cube_type/generate_struct.rs +++ b/crates/cubecl-macros/src/generate/cube_type/generate_struct.rs @@ -184,6 +184,8 @@ impl CubeTypeStruct { }); quote! { + #[derive(serde::Serialize, serde::Deserialize)] + #[serde(bound(serialize = "", deserialize = ""))] #vis struct #name #generics { #(#fields),* } @@ -196,6 +198,8 @@ impl CubeTypeStruct { } } + impl #type_generics_names CompilationArg for #name #impl_generics #where_generics { } + impl #type_generics_names core::hash::Hash for #name #impl_generics #where_generics { fn hash(&self, state: &mut H) { #(#hash;)* diff --git a/crates/cubecl-macros/src/generate/expression.rs b/crates/cubecl-macros/src/generate/expression.rs index 919169121..ffd0058a2 100644 --- a/crates/cubecl-macros/src/generate/expression.rs +++ b/crates/cubecl-macros/src/generate/expression.rs @@ -31,11 +31,11 @@ impl Expression { let array = array.to_tokens(context); let index = index .as_const(context) - .map(|as_const| quote![#elem::from_lit(#as_const)]) + .map(|as_const| quote![#elem::from_lit(context, #as_const)]) .unwrap_or_else(|| index.to_tokens(context)); let right = right .as_const(context) - .map(|as_const| quote![#elem::from_lit(#as_const)]) + .map(|as_const| quote![#elem::from_lit(context, #as_const)]) .unwrap_or_else(|| right.to_tokens(context)); let op = format_ident!("{}", operator.array_op_name()); let expand = with_span( @@ -103,7 +103,7 @@ impl Expression { if var.is_const { let name = &var.name; let expand_elem = frontend_type("ExpandElementTyped"); - quote![#expand_elem::from_lit(#name)] + quote![#expand_elem::from_lit(context, #name)] } else { let name = &var.name; if var.try_consume(context) { @@ -121,7 +121,7 @@ impl Expression { } Expression::Literal { value, .. } => { let expand_elem = frontend_type("ExpandElementTyped"); - quote![#expand_elem::from_lit(#value)] + quote![#expand_elem::from_lit(context, #value)] } Expression::Assignment { left, right, .. } diff --git a/crates/cubecl-macros/src/generate/kernel.rs b/crates/cubecl-macros/src/generate/kernel.rs index 241c2e803..09486ff65 100644 --- a/crates/cubecl-macros/src/generate/kernel.rs +++ b/crates/cubecl-macros/src/generate/kernel.rs @@ -173,8 +173,10 @@ impl Launch { let mut define = quote! {}; let expand_fn = |ident, expand_name, ty| { + let ty = self.analysis.process_ty(&ty); + quote! { - let #ident = <#ty as #launch_arg_expand>::#expand_name(&self.#ident, &mut builder); + let #ident = <#ty as #launch_arg_expand>::#expand_name(&self.#ident.dynamic_cast(), &mut builder); } }; for input in self.runtime_inputs() { @@ -203,10 +205,10 @@ impl Launch { let runtime = prelude_type("Runtime"); let compiler = core_type("Compiler"); let io_map = self.io_mappings(); + let register_type = self.analysis.register_elems(); let runtime_args = self.runtime_params().map(|it| &it.name); let comptime_args = self.comptime_params().map(|it| &it.name); - let (_, generics, _) = self.func.sig.generics.split_for_impl(); - let generics = generics.as_turbofish(); + let generics = self.analysis.process_generics(&self.func.sig.generics); let allocator = self.args.local_allocator.as_ref(); let allocator = allocator.map(|it| it.to_token_stream()).unwrap_or_else( || quote![<<__R as #runtime>::Compiler as #compiler>::local_allocator()], @@ -214,6 +216,7 @@ impl Launch { quote! 
{ let mut builder = #kernel_builder::with_local_allocator(#allocator); + #register_type #io_map expand #generics(&mut builder.context, #(#runtime_args.clone(),)* #(self.#comptime_args.clone()),*); builder.build(self.settings.clone()) diff --git a/crates/cubecl-macros/src/generate/statement.rs b/crates/cubecl-macros/src/generate/statement.rs index afa52e326..dabaf3bb4 100644 --- a/crates/cubecl-macros/src/generate/statement.rs +++ b/crates/cubecl-macros/src/generate/statement.rs @@ -17,7 +17,9 @@ impl Statement { init.as_ref().and_then(|it| it.as_const_primitive(context)) { let expand = frontend_type("ExpandElementTyped"); - Some(quote_spanned![as_const.span()=> #expand::from_lit(#as_const)]) + Some( + quote_spanned![as_const.span()=> #expand::from_lit(context, #as_const)], + ) } else if let Some(as_const) = init.as_ref().and_then(|it| it.as_const(context)) { Some(quote_spanned![as_const.span()=> #as_const.clone()]) diff --git a/crates/cubecl-macros/src/lib.rs b/crates/cubecl-macros/src/lib.rs index 0040c3892..de0cec7a7 100644 --- a/crates/cubecl-macros/src/lib.rs +++ b/crates/cubecl-macros/src/lib.rs @@ -107,6 +107,7 @@ fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result /// Derive macro to define a cube type that is launched with a kernel #[proc_macro_derive(CubeLaunch, attributes(expand, cube))] pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream { + // panic!("{gen}"); gen_cube_type(input, true) } diff --git a/crates/cubecl-macros/src/parse/expression.rs b/crates/cubecl-macros/src/parse/expression.rs index a44b46416..0e6835560 100644 --- a/crates/cubecl-macros/src/parse/expression.rs +++ b/crates/cubecl-macros/src/parse/expression.rs @@ -454,7 +454,9 @@ fn is_slice(index: &Expression) -> bool { fn fn_associated_type(path: &Expression) -> Option<(Path, PathSegment)> { // All supported primitives. Primitives don't start with an uppercase letter - const PRIMITIVES: &[&str] = &["bool", "i32", "i64", "u32", "f16", "bf16", "f32", "f64"]; + const PRIMITIVES: &[&str] = &[ + "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f16", "bf16", "f32", "f64", + ]; if !matches!(path, Expression::Path { .. }) { panic!("path: {path:?}"); } diff --git a/crates/cubecl-macros/src/parse/kernel.rs b/crates/cubecl-macros/src/parse/kernel.rs index e0f9b2a58..75ec31a07 100644 --- a/crates/cubecl-macros/src/parse/kernel.rs +++ b/crates/cubecl-macros/src/parse/kernel.rs @@ -1,8 +1,8 @@ use crate::{expression::Block, paths::prelude_type, scope::Context, statement::Pattern}; use darling::{ast::NestedMeta, util::Flag, FromMeta}; use proc_macro2::{Span, TokenStream}; -use quote::ToTokens; -use std::iter; +use quote::{quote, ToTokens}; +use std::{collections::HashMap, iter}; use syn::{ parse_quote, punctuated::Punctuated, spanned::Spanned, visit_mut::VisitMut, Expr, FnArg, Generics, Ident, ItemFn, ReturnType, Signature, TraitItemFn, Type, Visibility, @@ -30,12 +30,150 @@ impl KernelArgs { } } +pub struct GenericAnalysis { + pub map: HashMap, +} + +impl GenericAnalysis { + pub fn process_generics(&self, ty: &syn::Generics) -> TokenStream { + let mut output = quote![]; + + if ty.params.is_empty() { + return output; + } + + for param in ty.params.pairs() { + match param.value() { + syn::GenericParam::Type(type_param) => { + if let Some(ty) = self.map.get(&type_param.ident) { + output.extend(quote![#ty,]); + } else { + let ident = &type_param.ident; + output.extend(quote![#ident,]); + } + } + _ => todo!(), + } + } + + quote! 
{ + ::<#output> + } + } + + pub fn register_elems(&self) -> TokenStream { + let mut output = quote![]; + + for (name, ty) in self.map.iter() { + output.extend(quote! { + builder + .context + .register_elem::<#ty>(#name::as_elem_native_unchecked()); + }); + } + + output + } + + pub fn process_ty(&self, ty: &syn::Type) -> syn::Type { + let type_path = match &ty { + Type::Path(type_path) => type_path, + _ => return ty.clone(), + }; + let path = &type_path.path; + + let mut returned = syn::Path { + leading_colon: path.leading_colon, + segments: syn::punctuated::Punctuated::new(), + }; + + for pair in path.segments.pairs() { + let segment = pair.value(); + let punc = pair.punct(); + + if let Some(segment) = self.map.get(&segment.ident) { + returned.segments.push_value(segment.clone()); + } else { + match &segment.arguments { + syn::PathArguments::AngleBracketed(arg) => { + let mut args = syn::punctuated::Punctuated::new(); + arg.args.iter().for_each(|arg| match arg { + syn::GenericArgument::Type(ty) => { + let ty = self.process_ty(ty); + args.push(syn::GenericArgument::Type(ty)); + } + _ => args.push_value(arg.clone()), + }); + + let segment = syn::PathSegment { + ident: segment.ident.clone(), + arguments: syn::PathArguments::AngleBracketed( + syn::AngleBracketedGenericArguments { + colon2_token: arg.colon2_token, + lt_token: arg.lt_token, + args, + gt_token: arg.gt_token, + }, + ), + }; + returned.segments.push_value(segment); + } + _ => returned.segments.push_value((*segment).clone()), + } + } + + if let Some(punc) = punc { + returned.segments.push_punct(**punc) + } + } + + syn::Type::Path(syn::TypePath { + qself: type_path.qself.clone(), + path: returned, + }) + } + + pub fn from_generics(generics: &syn::Generics) -> Self { + let mut map = HashMap::new(); + + for param in generics.params.pairs() { + if let syn::GenericParam::Type(type_param) = param.value() { + if let Some(syn::TypeParamBound::Trait(trait_bound)) = type_param.bounds.first() { + if let Some(bound) = trait_bound.path.get_ident() { + let name = bound.to_string(); + let index = map.len() as u8; + + match name.as_str() { + "Float" => { + map.insert( + type_param.ident.clone(), + parse_quote!(FloatExpand<#index>), + ); + } + "Numeric" => { + map.insert( + type_param.ident.clone(), + parse_quote!(NumericExpand<#index>), + ); + } + _ => {} + }; + } + } + }; + } + + Self { map } + } +} + pub struct Launch { pub args: KernelArgs, pub vis: Visibility, pub func: KernelFn, pub kernel_generics: Generics, pub launch_generics: Generics, + pub analysis: GenericAnalysis, } #[derive(Clone)] @@ -277,6 +415,7 @@ impl Launch { let mut expand_generics = kernel_generics.clone(); expand_generics.params = Punctuated::from_iter(iter::once(parse_quote!['kernel]).chain(expand_generics.params)); + let analysis = GenericAnalysis::from_generics(&func.sig.generics); Ok(Launch { args, @@ -284,6 +423,7 @@ impl Launch { func, kernel_generics, launch_generics: expand_generics, + analysis, }) } } diff --git a/crates/cubecl-opt/src/control_flow.rs b/crates/cubecl-opt/src/control_flow.rs index c00715d70..8d005db04 100644 --- a/crates/cubecl-opt/src/control_flow.rs +++ b/crates/cubecl-opt/src/control_flow.rs @@ -253,7 +253,7 @@ impl Optimizer { let step = range_loop.step.unwrap_or(1.into()); let i_id = match range_loop.i.kind { - VariableKind::Local { id, depth, .. } => (id, depth), + VariableKind::LocalMut { id, depth, .. 
} => (id, depth), _ => unreachable!(), }; let i = range_loop.i; diff --git a/crates/cubecl-opt/src/debug.rs b/crates/cubecl-opt/src/debug.rs index aebd47cf5..e956f44a6 100644 --- a/crates/cubecl-opt/src/debug.rs +++ b/crates/cubecl-opt/src/debug.rs @@ -336,6 +336,8 @@ impl Display for Instruction { OpId::BitwiseAnd => write!(f, "{} & {}", args[0], args[1]), OpId::BitwiseOr => write!(f, "{} | {}", args[0], args[1]), OpId::BitwiseXor => write!(f, "{} ^ {}", args[0], args[1]), + OpId::CountOnes => write!(f, "{}.count_ones()", args[0]), + OpId::ReverseBits => write!(f, "{}.reverse_bits()", args[0]), OpId::ShiftLeft => write!(f, "{} << {}", args[0], args[1]), OpId::ShiftRight => write!(f, "{} >> {}", args[0], args[1]), OpId::Remainder => write!(f, "{} % {}", args[0], args[1]), diff --git a/crates/cubecl-opt/src/gvn/base.rs b/crates/cubecl-opt/src/gvn/base.rs index b026913e7..1025337db 100644 --- a/crates/cubecl-opt/src/gvn/base.rs +++ b/crates/cubecl-opt/src/gvn/base.rs @@ -1,9 +1,6 @@ use std::collections::HashMap; -use cubecl_core::{ - ir::{Builtin, ConstantScalarValue, Elem, FloatKind, IntKind, Item, UIntKind}, - prelude::CubePrimitive, -}; +use cubecl_core::ir::{Builtin, ConstantScalarValue, Elem, FloatKind, IntKind, Item, UIntKind}; use float_ord::FloatOrd; use petgraph::graph::NodeIndex; use smallvec::SmallVec; @@ -142,7 +139,7 @@ impl Value { Value::Input(_, item) => *item, Value::Scalar(_, elem) => Item::new(*elem), Value::ConstArray(_, item, _) => *item, - Value::Builtin(_) => Item::new(u32::as_elem()), + Value::Builtin(_) => Item::new(Elem::UInt(UIntKind::U32)), Value::Output(_, item) => *item, Value::Slice(_, _, item) => *item, } @@ -230,6 +227,8 @@ pub enum OpId { BitwiseAnd, BitwiseOr, BitwiseXor, + CountOnes, + ReverseBits, ShiftLeft, ShiftRight, Remainder, diff --git a/crates/cubecl-opt/src/gvn/convert.rs b/crates/cubecl-opt/src/gvn/convert.rs index 5b7799a57..ee4981692 100644 --- a/crates/cubecl-opt/src/gvn/convert.rs +++ b/crates/cubecl-opt/src/gvn/convert.rs @@ -157,6 +157,10 @@ impl Expression { rhs: args[1], }) .into(), + OpId::CountOnes => Operator::CountOnes(UnaryOperator { input: args[0] }).into(), + OpId::ReverseBits => { + Operator::ReverseBits(UnaryOperator { input: args[0] }).into() + } OpId::ShiftLeft => Operator::ShiftLeft(BinaryOperator { lhs: args[0], rhs: args[1], @@ -217,7 +221,7 @@ impl Value { version: 0, item, }) => Variable::new( - VariableKind::LocalBinding { + VariableKind::LocalConst { id: *id, depth: *depth, }, @@ -272,7 +276,7 @@ pub fn value_of_var(var: &Variable) -> Option { version, item, }), - VariableKind::LocalBinding { id, depth } => Value::Local(Local { + VariableKind::LocalConst { id, depth } => Value::Local(Local { id, depth, version: 0, @@ -280,7 +284,7 @@ pub fn value_of_var(var: &Variable) -> Option { }), VariableKind::ConstantScalar(val) => Value::Constant(val.into()), VariableKind::ConstantArray { id, length } => Value::ConstArray(id, item, length), - VariableKind::Local { .. } + VariableKind::LocalMut { .. } | VariableKind::SharedMemory { .. } | VariableKind::LocalArray { .. } | VariableKind::Matrix { .. 
} => None?, @@ -331,6 +335,8 @@ pub fn id_of_op(op: &Operator) -> OpId { Operator::BitwiseAnd(_) => OpId::BitwiseAnd, Operator::BitwiseOr(_) => OpId::BitwiseOr, Operator::BitwiseXor(_) => OpId::BitwiseXor, + Operator::CountOnes(_) => OpId::CountOnes, + Operator::ReverseBits(_) => OpId::ReverseBits, Operator::ShiftLeft(_) => OpId::ShiftLeft, Operator::ShiftRight(_) => OpId::ShiftRight, Operator::Remainder(_) => OpId::Remainder, diff --git a/crates/cubecl-opt/src/gvn/numbering.rs b/crates/cubecl-opt/src/gvn/numbering.rs index a0c7d5d6d..edb0f3fee 100644 --- a/crates/cubecl-opt/src/gvn/numbering.rs +++ b/crates/cubecl-opt/src/gvn/numbering.rs @@ -3,9 +3,8 @@ use std::{ mem::swap, }; -use cubecl_core::{ - ir::{self, Item, Metadata, Operation, Operator, Variable, VariableKind}, - prelude::CubePrimitive, +use cubecl_core::ir::{ + self, Elem, Item, Metadata, Operation, Operator, UIntKind, Variable, VariableKind, }; use crate::PhiInstruction; @@ -211,7 +210,9 @@ impl ValueTable { | Operator::Not(op) | Operator::Neg(op) | Operator::Magnitude(op) - | Operator::Normalize(op) => { + | Operator::Normalize(op) + | Operator::CountOnes(op) + | Operator::ReverseBits(op) => { let input = self.lookup_or_add_var(&op.input)?; let item = out.item; let out = value_of_var(&out); @@ -320,7 +321,7 @@ impl ValueTable { | VariableKind::LocalArray { length, .. } => { let constant = length.into(); let num = self.lookup_or_add_var(&constant)?; - let expr = Expression::Copy(num, Item::new(u32::as_elem())); + let expr = Expression::Copy(num, Item::new(Elem::UInt(UIntKind::U32))); return Ok((expr, out)); } _ => unreachable!("Length only available on array"), diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index 1bb1a1b5e..3378d0e85 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -109,7 +109,9 @@ impl Optimizer { | Operator::Cast(unary_operator) | Operator::Bitcast(unary_operator) | Operator::Magnitude(unary_operator) - | Operator::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read), + | Operator::Normalize(unary_operator) + | Operator::CountOnes(unary_operator) + | Operator::ReverseBits(unary_operator) => self.visit_unop(unary_operator, visit_read), Operator::Clamp(clamp_operator) => { visit_read(self, &mut clamp_operator.input); diff --git a/crates/cubecl-opt/src/lib.rs b/crates/cubecl-opt/src/lib.rs index d47dbda10..1fed23d90 100644 --- a/crates/cubecl-opt/src/lib.rs +++ b/crates/cubecl-opt/src/lib.rs @@ -299,7 +299,7 @@ impl Optimizer { let ops = self.program[node].ops.clone(); for op in ops.borrow().values() { if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation { - if let VariableKind::Local { id, depth } = &op.out().kind { + if let VariableKind::LocalMut { id, depth } = &op.out().kind { self.program.variables.remove(&(*id, *depth)); } } @@ -358,7 +358,7 @@ impl Optimizer { let processed = scope.process(); for var in processed.variables { - if let VariableKind::Local { id, depth } = var.kind { + if let VariableKind::LocalMut { id, depth } = var.kind { self.program.variables.insert((id, depth), var.item); } } @@ -411,7 +411,7 @@ impl Optimizer { /// Gets the `id` and `depth` of the variable if it's a `Local` and not atomic, `None` otherwise. 
pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<(u16, u8)> { match variable.kind { - core::VariableKind::Local { id, depth } if !variable.item.elem.is_atomic() => { + core::VariableKind::LocalMut { id, depth } if !variable.item.elem.is_atomic() => { Some((id, depth)) } _ => None, @@ -423,7 +423,7 @@ impl Optimizer { pub fn create_temporary(&self, item: Item) -> Variable { let next_id = self.program.temp_id.inc() as u16; Variable::new( - VariableKind::LocalBinding { + VariableKind::LocalConst { id: u16::MAX - next_id, depth: u8::MAX, }, @@ -453,10 +453,11 @@ pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {} #[cfg(test)] mod test { + use cubecl::prelude::*; use cubecl_core::{ self as cubecl, - ir::{HybridAllocator, Item, Variable, VariableKind}, - prelude::{Array, CubeContext, CubePrimitive, ExpandElement}, + ir::{Allocator, Elem, Item, UIntKind, Variable, VariableKind}, + prelude::{Array, CubeContext, ExpandElement}, }; use cubecl_core::{cube, CubeDim, ExecutionMode}; @@ -478,18 +479,18 @@ mod test { #[test] #[ignore = "no good way to assert opt is applied"] fn test_pre() { - let mut ctx = CubeContext::root(HybridAllocator::default()); + let mut ctx = CubeContext::root(Allocator::new()); let x = ExpandElement::Plain(Variable::new( VariableKind::GlobalScalar(0), - Item::new(u32::as_elem()), + Item::new(Elem::UInt(UIntKind::U32)), )); let cond = ExpandElement::Plain(Variable::new( VariableKind::GlobalScalar(1), - Item::new(u32::as_elem()), + Item::new(Elem::UInt(UIntKind::U32)), )); let arr = ExpandElement::Plain(Variable::new( VariableKind::GlobalOutputArray(0), - Item::new(u32::as_elem()), + Item::new(Elem::UInt(UIntKind::U32)), )); pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into()); diff --git a/crates/cubecl-opt/src/passes/array_copy_propagate.rs b/crates/cubecl-opt/src/passes/array_copy_propagate.rs index 8830faa90..82580ef2c 100644 --- a/crates/cubecl-opt/src/passes/array_copy_propagate.rs +++ b/crates/cubecl-opt/src/passes/array_copy_propagate.rs @@ -34,7 +34,7 @@ impl OptimizerPass for CopyPropagateArray { for Array { id, length, item } in arrays { let arr_id = id; let vars = (0..length) - .map(|_| opt.root_scope.create_local_undeclared(item)) + .map(|_| opt.root_scope.create_local_restricted(item)) .collect::>(); for var in &vars { let local_id = opt.local_variable_id(var).unwrap(); diff --git a/crates/cubecl-opt/src/passes/composite.rs b/crates/cubecl-opt/src/passes/composite.rs index c6ff4fe79..dc56fb11b 100644 --- a/crates/cubecl-opt/src/passes/composite.rs +++ b/crates/cubecl-opt/src/passes/composite.rs @@ -46,7 +46,7 @@ impl OptimizerPass for CompositeMerge { let op = { ops.borrow()[idx].clone() }; if let ( Operation::Operator(Operator::IndexAssign(BinaryOperator { lhs, rhs })), - Some(VariableKind::Local { id, depth }), + Some(VariableKind::LocalMut { id, depth }), ) = (op.operation, op.out.map(|it| it.kind)) { let item = op.out.unwrap().item; @@ -70,7 +70,7 @@ impl OptimizerPass for CompositeMerge { assert_eq!(index, 0, "Can't index into scalar"); opt.program[block].ops.borrow_mut()[idx] = Instruction::new( Operation::Copy(rhs), - Variable::new(VariableKind::Local { id, depth }, item), + Variable::new(VariableKind::LocalMut { id, depth }, item), ) } } @@ -93,7 +93,7 @@ fn merge_assigns( let last = assigns.iter().map(|it| it.0).max().unwrap(); assigns.sort_by_key(|it| it.1); let inputs = assigns.iter().map(|it| it.2).collect::>(); - let out = Variable::new(VariableKind::Local { id, depth }, item); + let out = 
Variable::new(VariableKind::LocalMut { id, depth }, item); ops.insert( last, Instruction::new(Operator::InitLine(LineInitOperator { inputs }), out), diff --git a/crates/cubecl-opt/src/passes/index_merge.rs b/crates/cubecl-opt/src/passes/index_merge.rs index 4b3e474ef..5053f5a0c 100644 --- a/crates/cubecl-opt/src/passes/index_merge.rs +++ b/crates/cubecl-opt/src/passes/index_merge.rs @@ -62,7 +62,7 @@ impl OptimizerPass for CopyTransform { fn as_versioned(var: &Variable) -> Option<(u16, u8, u16)> { match var.kind { - VariableKind::LocalBinding { id, depth } => Some((id, depth, 0)), + VariableKind::LocalConst { id, depth } => Some((id, depth, 0)), VariableKind::Versioned { id, depth, version } => Some((id, depth, version)), _ => None, } diff --git a/crates/cubecl-opt/src/passes/integer_range_analysis.rs b/crates/cubecl-opt/src/passes/integer_range_analysis.rs index 50f0ee347..2eced266e 100644 --- a/crates/cubecl-opt/src/passes/integer_range_analysis.rs +++ b/crates/cubecl-opt/src/passes/integer_range_analysis.rs @@ -1,8 +1,7 @@ use std::ops::{Add, Mul, Sub}; -use cubecl_core::{ - ir::{Builtin, ConstantScalarValue, Operation, Operator, Variable, VariableKind}, - prelude::CubePrimitive, +use cubecl_core::ir::{ + Builtin, ConstantScalarValue, Elem, Operation, Operator, UIntKind, Variable, VariableKind, }; use crate::{AtomicCounter, Optimizer, Range}; @@ -109,31 +108,35 @@ impl OptimizerPass for IntegerRangeAnalysis { /// can be determined, or the type is not an integer. pub(crate) fn range_of(opt: &Optimizer, var: &Variable) -> Range { match var.kind { - VariableKind::Versioned { id, depth, version } if var.item.elem() == u32::as_elem() => opt - .program - .int_ranges - .get(&(id, depth, version)) - .copied() - .unwrap_or(Range { - lower_bound: Some(0), - upper_bound: None, - }), + VariableKind::Versioned { id, depth, version } + if var.item.elem() == Elem::UInt(UIntKind::U32) => + { + opt.program + .int_ranges + .get(&(id, depth, version)) + .copied() + .unwrap_or(Range { + lower_bound: Some(0), + upper_bound: None, + }) + } VariableKind::Versioned { id, depth, version } => opt .program .int_ranges .get(&(id, depth, version)) .copied() .unwrap_or_default(), - VariableKind::LocalBinding { id, depth } if var.item.elem() == u32::as_elem() => opt - .program - .int_ranges - .get(&(id, depth, 0)) - .copied() - .unwrap_or(Range { - lower_bound: Some(0), - upper_bound: None, - }), - VariableKind::LocalBinding { id, depth } => opt + VariableKind::LocalConst { id, depth } if var.item.elem() == Elem::UInt(UIntKind::U32) => { + opt.program + .int_ranges + .get(&(id, depth, 0)) + .copied() + .unwrap_or(Range { + lower_bound: Some(0), + upper_bound: None, + }) + } + VariableKind::LocalConst { id, depth } => opt .program .int_ranges .get(&(id, depth, 0)) @@ -161,7 +164,7 @@ pub(crate) fn range_of(opt: &Optimizer, var: &Variable) -> Range { pub(crate) fn var_id(var: &Variable) -> Option<(u16, u8, u16)> { match var.kind { VariableKind::Versioned { id, depth, version } => Some((id, depth, version)), - VariableKind::LocalBinding { id, depth } => Some((id, depth, 0)), + VariableKind::LocalConst { id, depth } => Some((id, depth, 0)), _ => None, } } diff --git a/crates/cubecl-opt/src/passes/reduce_strength.rs b/crates/cubecl-opt/src/passes/reduce_strength.rs index b50e941ac..1d42119a8 100644 --- a/crates/cubecl-opt/src/passes/reduce_strength.rs +++ b/crates/cubecl-opt/src/passes/reduce_strength.rs @@ -1,9 +1,6 @@ use std::mem::take; -use cubecl_core::{ - ir::{BinaryOperator, Instruction, Operation, Operator, 
Variable}, - prelude::CubePrimitive, -}; +use cubecl_core::ir::{BinaryOperator, Elem, Instruction, Operation, Operator, UIntKind, Variable}; use crate::{AtomicCounter, Optimizer}; @@ -38,7 +35,7 @@ impl OptimizerPass for ReduceStrength { } }; match op { - Operator::Mul(op) if inst.item().elem() == u32::as_elem() => { + Operator::Mul(op) if inst.item().elem() == Elem::UInt(UIntKind::U32) => { let (const_val, dyn_val) = match (op.lhs.as_const(), op.rhs.as_const()) { (None, Some(val)) => (val.as_u32(), op.lhs), (Some(val), None) => (val.as_u32(), op.rhs), @@ -146,7 +143,7 @@ impl OptimizerPass for ReduceStrength { } fn is_pow2(var: Variable) -> bool { - var.item.elem() == u32::as_elem() + var.item.elem() == Elem::UInt(UIntKind::U32) && var .as_const() .map(|it| it.as_u32().is_power_of_two()) diff --git a/crates/cubecl-opt/src/version.rs b/crates/cubecl-opt/src/version.rs index 749dfcd88..7c9636eda 100644 --- a/crates/cubecl-opt/src/version.rs +++ b/crates/cubecl-opt/src/version.rs @@ -175,7 +175,7 @@ impl Optimizer { fn version_writes(&mut self, op: &mut Instruction, state: &mut SsaState<'_>) { self.visit_out(&mut op.out, |_, var| match var.kind { - VariableKind::Local { id, depth } | VariableKind::Versioned { id, depth, .. } => { + VariableKind::LocalMut { id, depth } | VariableKind::Versioned { id, depth, .. } => { if let Some(version) = state.versions.get_mut(&(id, depth)) { let max_version = state.max_versions.get_mut(&(id, depth)).unwrap(); *max_version += 1; @@ -196,7 +196,7 @@ impl Optimizer { fn version_read(&self, var: &mut Variable, state: &mut SsaState<'_>) { match var.kind { - VariableKind::Local { id, depth } | VariableKind::Versioned { id, depth, .. } => { + VariableKind::LocalMut { id, depth } | VariableKind::Versioned { id, depth, .. } => { if self.program.variables.contains_key(&(id, depth)) { if let Some(version) = state.versions.get(&(id, depth)) { *var = Variable::new( diff --git a/crates/cubecl-reduce/src/instructions/argmax.rs b/crates/cubecl-reduce/src/instructions/argmax.rs index 23ba64133..cb1976062 100644 --- a/crates/cubecl-reduce/src/instructions/argmax.rs +++ b/crates/cubecl-reduce/src/instructions/argmax.rs @@ -1,7 +1,7 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use super::{lowest_coordinate_matching, ArgAccumulator, ReduceInstruction}; +use super::{lowest_coordinate_matching, ArgAccumulator, Reduce, ReduceInstruction}; /// Compute the coordinate of the maximum item returning the smallest coordinate in case of equality. 
pub struct ArgMax; @@ -28,13 +28,17 @@ impl ArgMax { } } +impl Reduce for ArgMax { + type Instruction = Self; +} + #[cube] impl ReduceInstruction for ArgMax { type AccumulatorItem = (Line, Line); type SharedAccumulator = ArgAccumulator; fn null_input(#[comptime] line_size: u32) -> Line { - Line::empty(line_size).fill(In::MIN) + Line::empty(line_size).fill(In::min_value()) } fn null_accumulator(#[comptime] line_size: u32) -> Self::AccumulatorItem { @@ -83,7 +87,7 @@ impl ReduceInstruction for ArgMax { ) -> Out { let line_size = accumulator.0.size(); if comptime!(line_size > 1) { - let mut max = In::MIN.runtime(); + let mut max = In::min_value(); let mut coordinate = u32::MAX.runtime(); #[unroll] for k in 0..line_size { diff --git a/crates/cubecl-reduce/src/instructions/argmin.rs b/crates/cubecl-reduce/src/instructions/argmin.rs index cc1d210d6..abfc2692b 100644 --- a/crates/cubecl-reduce/src/instructions/argmin.rs +++ b/crates/cubecl-reduce/src/instructions/argmin.rs @@ -1,11 +1,15 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use super::{lowest_coordinate_matching, ArgAccumulator, ReduceInstruction}; +use super::{lowest_coordinate_matching, ArgAccumulator, Reduce, ReduceInstruction}; /// Compute the coordinate of the minimum item returning the smallest coordinate in case of equality. pub struct ArgMin; +impl Reduce for ArgMin { + type Instruction = Self; +} + #[cube] impl ArgMin { /// Compare two pairs of items and coordinates and return a new pair @@ -34,7 +38,7 @@ impl ReduceInstruction for ArgMin { type SharedAccumulator = ArgAccumulator; fn null_input(#[comptime] line_size: u32) -> Line { - Line::empty(line_size).fill(In::MAX) + Line::empty(line_size).fill(In::max_value()) } fn null_accumulator(#[comptime] line_size: u32) -> Self::AccumulatorItem { @@ -83,8 +87,9 @@ impl ReduceInstruction for ArgMin { ) -> Out { let line_size = accumulator.0.size(); if comptime!(line_size > 1) { - let mut min = In::MAX.runtime(); - let mut coordinate = 0; + let mut min = In::max_value(); + let mut coordinate = u32::MAX.runtime(); + #[unroll] for k in 0..line_size { let acc_element = accumulator.0[k]; diff --git a/crates/cubecl-reduce/src/instructions/base.rs b/crates/cubecl-reduce/src/instructions/base.rs index 8f91c58f9..5a46f12d5 100644 --- a/crates/cubecl-reduce/src/instructions/base.rs +++ b/crates/cubecl-reduce/src/instructions/base.rs @@ -1,6 +1,10 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; +pub trait Reduce: Send + Sync + 'static { + type Instruction: ReduceInstruction; +} + /// An instruction for a reduce algorithm that works with [`Line`].
/// /// See a provided implementation, such as [`Sum`] or [`ArgMax`], for an example of how to implement diff --git a/crates/cubecl-reduce/src/instructions/mean.rs b/crates/cubecl-reduce/src/instructions/mean.rs index f99f1991c..3afaa4b75 100644 --- a/crates/cubecl-reduce/src/instructions/mean.rs +++ b/crates/cubecl-reduce/src/instructions/mean.rs @@ -1,10 +1,14 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use super::{ReduceInstruction, Sum}; +use super::{Reduce, ReduceInstruction, Sum}; pub struct Mean; +impl Reduce for Mean { + type Instruction = Self; +} + #[cube] impl ReduceInstruction for Mean { type AccumulatorItem = Line; diff --git a/crates/cubecl-reduce/src/instructions/prod.rs b/crates/cubecl-reduce/src/instructions/prod.rs index 27cca40bd..a8142d157 100644 --- a/crates/cubecl-reduce/src/instructions/prod.rs +++ b/crates/cubecl-reduce/src/instructions/prod.rs @@ -1,10 +1,14 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use super::ReduceInstruction; +use super::{Reduce, ReduceInstruction}; pub struct Prod; +impl Reduce for Prod { + type Instruction = Self; +} + #[cube] impl ReduceInstruction for Prod { type AccumulatorItem = Line; diff --git a/crates/cubecl-reduce/src/instructions/sum.rs b/crates/cubecl-reduce/src/instructions/sum.rs index 63c965ce3..acc2550dd 100644 --- a/crates/cubecl-reduce/src/instructions/sum.rs +++ b/crates/cubecl-reduce/src/instructions/sum.rs @@ -1,10 +1,14 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use super::ReduceInstruction; +use super::{Reduce, ReduceInstruction}; pub struct Sum; +impl Reduce for Sum { + type Instruction = Self; +} + #[cube] impl ReduceInstruction for Sum { type AccumulatorItem = Line; diff --git a/crates/cubecl-reduce/src/launch.rs b/crates/cubecl-reduce/src/launch.rs index c7763bae4..9ee3e2be3 100644 --- a/crates/cubecl-reduce/src/launch.rs +++ b/crates/cubecl-reduce/src/launch.rs @@ -8,7 +8,7 @@ use crate::{LineMode, ReduceConfig, ReduceStrategy}; /// Launch a reduce kernel. This function assumes that all parameters are already validated. /// See the main entrypoint `reduce` in `lib.rs` for an example of how to call this function /// with the appropriate assumptions.
-pub(crate) fn launch_reduce>( +pub(crate) fn launch_reduce( client: &ComputeClient, input: TensorHandleRef, output: TensorHandleRef, @@ -52,7 +52,7 @@ struct ReduceParams { } #[cube(launch_unchecked)] -fn reduce_kernel>( +fn reduce_kernel( input: &Tensor>, output: &mut Tensor, axis_reduce: u32, @@ -75,7 +75,7 @@ fn reduce_kernel>( let accumulator = match comptime!((params.shared, params.use_planes)) { (Some(accumulator_size), use_planes) => { - let mut accumulator = reduce_slice_shared::( + let mut accumulator = reduce_slice_shared::>( input.to_slice(), range, accumulator_size, @@ -84,18 +84,24 @@ fn reduce_kernel>( use_planes, ); sync_units(); - reduce_tree::(&mut accumulator, accumulator_size) - } - (None, true) => { - reduce_slice_plane::(input.to_slice(), range, params.line_size, params.line_mode) - } - (None, false) => { - reduce_slice::(input.to_slice(), range, params.line_size, params.line_mode) + reduce_tree::>(&mut accumulator, accumulator_size) } + (None, true) => reduce_slice_plane::>( + input.to_slice(), + range, + params.line_size, + params.line_mode, + ), + (None, false) => reduce_slice::>( + input.to_slice(), + range, + params.line_size, + params.line_mode, + ), }; if elected_writer(params) { - write_to_output::( + write_to_output::>( output, accumulator, reduce_index, diff --git a/crates/cubecl-reduce/src/lib.rs b/crates/cubecl-reduce/src/lib.rs index 71b65e48b..194619154 100644 --- a/crates/cubecl-reduce/src/lib.rs +++ b/crates/cubecl-reduce/src/lib.rs @@ -20,6 +20,7 @@ mod strategy; pub use config::*; pub use error::*; +use instructions::Reduce; pub use instructions::ReduceInstruction; pub use strategy::*; @@ -83,7 +84,7 @@ use cubecl_core::prelude::*; /// println!("Output = {:?}", output_values); // Should print [1, 5]. /// } /// ``` -pub fn reduce>( +pub fn reduce( client: &ComputeClient, input: TensorHandleRef, output: TensorHandleRef, diff --git a/crates/cubecl-reduce/src/test.rs b/crates/cubecl-reduce/src/test.rs index df2c8639f..c590da87a 100644 --- a/crates/cubecl-reduce/src/test.rs +++ b/crates/cubecl-reduce/src/test.rs @@ -7,7 +7,7 @@ use rand::{ SeedableRng, }; -use crate::{instructions::*, reduce, ReduceError, ReduceInstruction, ReduceStrategy}; +use crate::{instructions::*, reduce, ReduceError, ReduceStrategy}; // All random values generated for tests will be in the set // {-2, -2 + E, -2 + 2E, ..., 2 - E, 2} with E = 1 / PRECISION. 
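To make the new `Reduce` indirection concrete, here is a minimal sketch. It is illustrative only: the crate's real traits are generic over the input numeric type and operate on `Line`s, and `reduce_entry` is a hypothetical stand-in for the `reduce`/`launch_reduce` entry points; only the `type Instruction = Self` pattern is taken from the hunks above.

// Simplified sketch of the marker-trait indirection; not the real signatures.
trait ReduceInstruction {
    fn null_input() -> f32;
    fn fold(accumulator: f32, item: f32) -> f32;
}

// Public marker trait: entry points bound on `K: Reduce` and recover the
// concrete instruction through the associated type.
trait Reduce: Send + Sync + 'static {
    type Instruction: ReduceInstruction;
}

struct Sum;

impl ReduceInstruction for Sum {
    fn null_input() -> f32 {
        0.0
    }
    fn fold(accumulator: f32, item: f32) -> f32 {
        accumulator + item
    }
}

// Each built-in instruction acts as its own marker, as in the hunks above.
impl Reduce for Sum {
    type Instruction = Self;
}

// Hypothetical entry point: callers only ever name the marker type.
fn reduce_entry<K: Reduce>(values: &[f32]) -> f32 {
    values.iter().copied().fold(
        <K::Instruction as ReduceInstruction>::null_input(),
        <K::Instruction as ReduceInstruction>::fold,
    )
}

fn main() {
    // The reduction is chosen by the marker type alone.
    assert_eq!(reduce_entry::<Sum>(&[1.0, 2.0, 3.0]), 6.0);
}

The payoff is visible in the `launch.rs` and `lib.rs` hunks above: the public signatures shrink to a single `K: Reduce` bound, while the instruction's own generics stay internal to the kernel.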
@@ -245,7 +245,7 @@ impl TestCase { } fn cpu_argmax(&self, values: &[F]) -> Vec { - let mut expected = vec![(F::MIN, 0_u32); self.num_output_values()]; + let mut expected = vec![(F::min_value(), 0_u32); self.num_output_values()]; for (input_index, &value) in values.iter().enumerate() { let output_index = self.to_output_index(input_index); let (best, _) = expected[output_index]; @@ -268,7 +268,7 @@ impl TestCase { } fn cpu_argmin(&self, values: &[F]) -> Vec { - let mut expected = vec![(F::MAX, 0_u32); self.num_output_values()]; + let mut expected = vec![(F::max_value(), 0_u32); self.num_output_values()]; for (input_index, &value) in values.iter().enumerate() { let output_index = self.to_output_index(input_index); let (best, _) = expected[output_index]; @@ -346,7 +346,7 @@ impl TestCase { I: Numeric + CubeElement + std::fmt::Display, O: Numeric + CubeElement + std::fmt::Display, R: Runtime, - K: ReduceInstruction, + K: Reduce, { let client = R::client(device); diff --git a/crates/cubecl-spirv/src/compiler.rs b/crates/cubecl-spirv/src/compiler.rs index 32b1a308b..30e6de357 100644 --- a/crates/cubecl-spirv/src/compiler.rs +++ b/crates/cubecl-spirv/src/compiler.rs @@ -9,7 +9,7 @@ use std::{ }; use cubecl_core::{ - ir::{HybridAllocator, KernelDefinition, LocalAllocator}, + ir::{Allocator, KernelDefinition}, Compiler, ExecutionMode, }; use rspirv::{ @@ -159,8 +159,8 @@ impl Compiler for SpirvCompiler { elem.size() } - fn local_allocator() -> impl LocalAllocator { - HybridAllocator::default() + fn local_allocator() -> Allocator { + Allocator::new() } fn max_shared_memory_size() -> usize { diff --git a/crates/cubecl-spirv/src/debug.rs b/crates/cubecl-spirv/src/debug.rs index a05c671d0..c3145d95f 100644 --- a/crates/cubecl-spirv/src/debug.rs +++ b/crates/cubecl-spirv/src/debug.rs @@ -539,6 +539,7 @@ impl SpirvCompiler { .unwrap() } + #[track_caller] pub fn debug_info(&mut self) -> &mut DebugInfo { self.debug_info.as_mut().unwrap() } diff --git a/crates/cubecl-spirv/src/instruction.rs b/crates/cubecl-spirv/src/instruction.rs index 8c721607f..e3e7de615 100644 --- a/crates/cubecl-spirv/src/instruction.rs +++ b/crates/cubecl-spirv/src/instruction.rs @@ -433,6 +433,18 @@ impl SpirvCompiler { b.bitwise_xor(ty, Some(out), lhs, rhs).unwrap(); }) } + Operator::CountOnes(op) => { + // While the spec theoretically allows arbitrary integers, Vulkan only supports i32/u32 + self.compile_unary_op_cast(op, out, |b, _, ty, input, out| { + b.bit_count(ty, Some(out), input).unwrap(); + }); + } + Operator::ReverseBits(op) => { + self.capabilities.insert(Capability::BitInstructions); + self.compile_unary_op(op, out, |b, _, ty, input, out| { + b.bit_reverse(ty, Some(out), input).unwrap(); + }); + } Operator::ShiftLeft(op) => { self.compile_binary_op(op, out, |b, _, ty, lhs, rhs, out| { b.shift_left_logical(ty, Some(out), lhs, rhs).unwrap(); diff --git a/crates/cubecl-spirv/src/item.rs b/crates/cubecl-spirv/src/item.rs index 65726ee83..5fbfdfef4 100644 --- a/crates/cubecl-spirv/src/item.rs +++ b/crates/cubecl-spirv/src/item.rs @@ -159,42 +159,6 @@ impl Item { } }; - let swap_sign = |b: &mut SpirvCompiler, - obj: Word, - out_id: Option, - width: u32, - target_sign: bool| match (width, target_sign) { - (_, false) => { - let zero = self.const_u32(b, 0); - let id = out_id.unwrap_or_else(|| b.id()); - let ty = self.id(b); - T::s_max(b, ty, obj, zero, id); - id - } - (64, true) => { - let max = ConstVal::Bit64(i64::MAX as u64); - let max = b.static_cast(max, &Elem::Int(64, true), self); - let id = out_id.unwrap_or_else(|| 
b.id()); - let ty = self.id(b); - T::u_min(b, ty, obj, max, id); - id - } - (width, true) => { - let max = match width { - 32 => i32::MAX as u32, - 16 => i16::MAX as u32, - 8 => i8::MAX as u32, - _ => unimplemented!("Invalid width"), - }; - let max = ConstVal::Bit32(max); - let max = b.static_cast(max, &Elem::Int(32, true), self); - let id = out_id.unwrap_or_else(|| b.id()); - let ty = self.id(b); - T::u_min(b, ty, obj, max, id); - id - } - }; - let convert_i_width = |b: &mut SpirvCompiler, obj: Word, out_id: Option, signed: bool| { if signed { @@ -209,16 +173,11 @@ impl Item { out_id: Option, (width_self, signed_self), (width_other, signed_other)| { - let sign_differs = signed_self != signed_other; let width_differs = width_self != width_other; - match (sign_differs, width_differs) { - (true, true) => { - let sign_swap = swap_sign(b, obj, None, width_self, signed_other); - convert_i_width(b, sign_swap, out_id, signed_other) - } - (true, false) => swap_sign(b, obj, out_id, width_self, signed_other), - (false, true) => convert_i_width(b, obj, out_id, signed_other), - (false, false) => b.copy_object(ty, out_id, obj).unwrap(), + let sign_extend = signed_self && signed_other; + match width_differs { + true => convert_i_width(b, obj, out_id, sign_extend), + false => b.copy_object(ty, out_id, obj).unwrap(), } }; diff --git a/crates/cubecl-spirv/src/variable.rs b/crates/cubecl-spirv/src/variable.rs index 6d382fbbc..3e12b224f 100644 --- a/crates/cubecl-spirv/src/variable.rs +++ b/crates/cubecl-spirv/src/variable.rs @@ -378,7 +378,7 @@ impl SpirvCompiler { Variable::GlobalOutputArray(id, self.compile_item(item), pos) } core::VariableKind::GlobalScalar(id) => self.global_scalar(id, item.elem), - core::VariableKind::Local { id, depth } => { + core::VariableKind::LocalMut { id, depth } => { let item = self.compile_item(item); let var = self.get_local((id, depth), &item); Variable::Local { id: var, item } @@ -388,7 +388,7 @@ impl SpirvCompiler { let id = (id, depth, version); Variable::Versioned { id, item } } - core::VariableKind::LocalBinding { id, depth } => { + core::VariableKind::LocalConst { id, depth } => { let item = self.compile_item(item); let id = (id, depth); Variable::LocalBinding { id, item } diff --git a/crates/cubecl-wgpu/src/compiler/spirv.rs b/crates/cubecl-wgpu/src/compiler/spirv.rs index 91051dd65..cdf6d419e 100644 --- a/crates/cubecl-wgpu/src/compiler/spirv.rs +++ b/crates/cubecl-wgpu/src/compiler/spirv.rs @@ -4,9 +4,10 @@ use ash::{ khr::cooperative_matrix, vk::{ ComponentTypeKHR, DeviceCreateInfo, DeviceQueueCreateInfo, - PhysicalDevice16BitStorageFeatures, PhysicalDeviceCooperativeMatrixFeaturesKHR, - PhysicalDeviceShaderFloat16Int8Features, PhysicalDeviceVulkanMemoryModelFeatures, ScopeKHR, - EXT_ROBUSTNESS2_NAME, KHR_COOPERATIVE_MATRIX_NAME, + PhysicalDevice16BitStorageFeatures, PhysicalDevice8BitStorageFeatures, + PhysicalDeviceCooperativeMatrixFeaturesKHR, PhysicalDeviceShaderFloat16Int8Features, + PhysicalDeviceVulkanMemoryModelFeatures, ScopeKHR, EXT_ROBUSTNESS2_NAME, + KHR_COOPERATIVE_MATRIX_NAME, }, }; use cubecl_core::{ @@ -229,6 +230,7 @@ fn request_device( .shader_int8(true); let mut buf_16 = PhysicalDevice16BitStorageFeatures::default().storage_buffer16_bit_access(true); + let mut buf_8 = PhysicalDevice8BitStorageFeatures::default().storage_buffer8_bit_access(true); if has_cmma { device_extensions.push(KHR_COOPERATIVE_MATRIX_NAME); @@ -265,6 +267,7 @@ fn request_device( info = info.push_next(&mut mem_model); info = info.push_next(&mut f16_i8); info = 
info.push_next(&mut buf_16); + info = info.push_next(&mut buf_8); if let Some(cmma) = &mut cmma { info = info.push_next(cmma); diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/base.rs b/crates/cubecl-wgpu/src/compiler/wgsl/base.rs index 0b7d1dfc1..67e2a3817 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/base.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/base.rs @@ -7,14 +7,15 @@ pub enum Variable { GlobalOutputArray(u16, Item), GlobalScalar(u16, Elem, cube::Elem), ConstantScalar(ConstantScalarValue, Elem), - Local { + LocalMut { id: u16, item: Item, depth: u8, }, - LocalBinding { + LocalConst { id: u16, item: Item, + depth: u8, }, Named { name: String, @@ -98,8 +99,8 @@ impl Variable { Variable::SharedMemory(_, _, _) => false, Variable::ConstantArray(_, _, _) => false, Variable::LocalArray(_, _, _, _) => false, - Variable::Local { .. } => false, - Variable::LocalBinding { .. } => false, + Variable::LocalMut { .. } => false, + Variable::LocalConst { .. } => false, Variable::Named { .. } => false, Variable::Slice { .. } => false, Variable::WorkgroupIdX => true, @@ -132,7 +133,7 @@ impl Variable { Variable::GlobalInputArray(_, item) => item.elem().is_atomic(), Variable::GlobalOutputArray(_, item) => item.elem().is_atomic(), Variable::GlobalScalar(_, elem, _) => elem.is_atomic(), - Variable::Local { item, .. } => item.elem().is_atomic(), + Variable::LocalMut { item, .. } => item.elem().is_atomic(), Variable::Named { item, .. } => item.elem().is_atomic(), Variable::Slice { item, .. } => item.elem().is_atomic(), Variable::LocalScalar { elem, .. } => elem.is_atomic(), @@ -149,8 +150,8 @@ impl Variable { Self::SharedMemory(_, e, _) => *e, Self::ConstantArray(_, e, _) => *e, Self::LocalArray(_, e, _, _) => *e, - Self::Local { item, .. } => *item, - Self::LocalBinding { item, .. } => *item, + Self::LocalMut { item, .. } => *item, + Self::LocalConst { item, .. } => *item, Self::Slice { item, .. } => *item, Self::Named { item, .. } => *item, Self::ConstantScalar(_, e) => Item::Scalar(*e), @@ -279,12 +280,12 @@ impl Display for Variable { depth: scope_depth, .. } => write!(f, "s_{scope_depth}_{index}"), - Variable::Local { + Variable::LocalMut { id: index, depth: scope_depth, .. - } => write!(f, "l_{scope_depth}_{index}"), - Variable::LocalBinding { id: index, .. } => write!(f, "_{index}"), + } => write!(f, "l_mut_{scope_depth}_{index}"), + Variable::LocalConst { id, depth, .. } => write!(f, "l_{depth}_{id}"), Variable::Named { name, .. } => f.write_str(name), Variable::Slice { id: index, @@ -301,7 +302,7 @@ impl Display for Variable { // precision related problems. Variable::ConstantScalar(number, _elem) => match number { ConstantScalarValue::Int(val, kind) => match kind { - IntKind::I32 => write!(f, "{}i", *val as i32), + IntKind::I32 => write!(f, "{}", *val as i32), _ => unimplemented!("{:?} not supported in WGSL", kind), }, ConstantScalarValue::Float(val, kind) => match kind { @@ -365,8 +366,8 @@ impl Display for IndexedVariable { impl Variable { pub fn fmt_left(&self) -> String { match self { - Variable::LocalBinding { id, .. } => { - format!("let _{id}") + Variable::LocalConst { .. 
} => { + format!("let {self}") } var => format!("{}", var), } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 4cfe13b54..970bd19d1 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -8,11 +8,10 @@ use crate::{ WgpuServer, }; -use cubecl_core::ir::{expand_checked_index, expand_checked_index_assign}; +use cubecl_core::ir::expand_checked_index_assign; use cubecl_core::{ - ir::{self as cube, HybridAllocator, UIntKind}, + ir::{self as cube, Allocator, UIntKind}, prelude::CompiledKernel, - prelude::CubePrimitive, server::ComputeServer, Feature, Metadata, }; @@ -82,8 +81,8 @@ impl cubecl_core::Compiler for WgslCompiler { 32768 } - fn local_allocator() -> impl cube::LocalAllocator { - HybridAllocator::default() + fn local_allocator() -> Allocator { + Allocator::new() } } @@ -358,15 +357,16 @@ impl WgslCompiler { cube::VariableKind::GlobalScalar(id) => { wgsl::Variable::GlobalScalar(id, Self::compile_elem(item.elem), item.elem) } - cube::VariableKind::Local { id, depth } - | cube::VariableKind::Versioned { id, depth, .. } => wgsl::Variable::Local { + cube::VariableKind::LocalMut { id, depth } + | cube::VariableKind::Versioned { id, depth, .. } => wgsl::Variable::LocalMut { id, item: Self::compile_item(item), depth, }, - cube::VariableKind::LocalBinding { id, .. } => wgsl::Variable::LocalBinding { + cube::VariableKind::LocalConst { id, depth } => wgsl::Variable::LocalConst { id, item: Self::compile_item(item), + depth, }, cube::VariableKind::Slice { id, depth } => wgsl::Variable::Slice { id, @@ -891,18 +891,32 @@ impl WgslCompiler { out: self.compile_variable(out), }), cube::Operator::Index(op) => { - if let ExecutionMode::Checked = self.strategy { - if op.lhs.has_length() { - expand_checked_index(scope, op.lhs, op.rhs, out); - instructions.extend(self.compile_scope(scope)); - return; - } - }; - instructions.push(wgsl::Instruction::Index { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(out), - }); + if matches!(self.strategy, ExecutionMode::Checked) && op.lhs.has_length() { + let lhs = op.lhs; + let rhs = op.rhs; + let array_len = + scope.create_local(cube::Item::new(cube::Elem::UInt(cube::UIntKind::U32))); + + instructions.extend(self.compile_scope(scope)); + + let length = match lhs.has_buffer_length() { + true => cube::Metadata::BufferLength { var: lhs }, + false => cube::Metadata::Length { var: lhs }, + }; + instructions.push(self.compile_metadata(length, Some(array_len))); + instructions.push(wgsl::Instruction::CheckedIndex { + len: self.compile_variable(array_len), + lhs: self.compile_variable(lhs), + rhs: self.compile_variable(rhs), + out: self.compile_variable(out), + }); + } else { + instructions.push(wgsl::Instruction::Index { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(out), + }); + } } cube::Operator::UncheckedIndex(op) => instructions.push(wgsl::Instruction::Index { lhs: self.compile_variable(op.lhs), @@ -959,6 +973,14 @@ impl WgslCompiler { rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), }), + cube::Operator::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Operator::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits { + input: self.compile_variable(op.input), + out: 
self.compile_variable(out), + }), cube::Operator::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), @@ -977,8 +999,8 @@ impl WgslCompiler { cube::Operator::Slice(op) => { if matches!(self.strategy, ExecutionMode::Checked) && op.input.has_length() { let input = op.input; - let input_len = scope.create_local(cube::Item::new(u32::as_elem())); - + let input_len = scope + .create_local_mut(cube::Item::new(cube::Elem::UInt(cube::UIntKind::U32))); instructions.extend(self.compile_scope(scope)); let length = match input.has_buffer_length() { diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index 84c8c5dcc..9f3727f61 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -68,6 +68,13 @@ pub enum Instruction { rhs: Variable, out: Variable, }, + // Index handles casting to correct local variable. + CheckedIndex { + len: Variable, + lhs: Variable, + rhs: Variable, + out: Variable, + }, // Assign handle casting to correct output variable. Assign { input: Variable, @@ -227,6 +234,14 @@ pub enum Instruction { rhs: Variable, out: Variable, }, + CountBits { + input: Variable, + out: Variable, + }, + ReverseBits { + input: Variable, + out: Variable, + }, ShiftLeft { lhs: Variable, rhs: Variable, @@ -480,6 +495,22 @@ impl Display for Instruction { index_assign(f, lhs, rhs, out, None) } } + Instruction::CheckedIndex { len, lhs, rhs, out } => match lhs { + Variable::Slice { item, .. } => { + let offset = Variable::Named { + name: format!("{lhs}_offset"), + item: Item::Scalar(Elem::U32), + is_array: false, + }; + let lhs = Variable::Named { + name: format!("(*{lhs}_ptr)"), + item: *item, + is_array: true, + }; + index(f, &lhs, rhs, out, Some(offset), Some(len)) + } + _ => index(f, lhs, rhs, out, None, Some(len)), + }, Instruction::Copy { input, in_index, @@ -794,6 +825,14 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{ let out = out.fmt_left(); writeln!(f, "{out} = {lhs} ^ {rhs};") } + Instruction::CountBits { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = countOneBits({input});") + } + Instruction::ReverseBits { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = reverseBits({input});") + } Instruction::ShiftLeft { lhs, rhs, out } => { let out = out.fmt_left(); writeln!(f, "{out} = {lhs} << {rhs};") @@ -1043,8 +1082,8 @@ fn index( len: Option<&Variable>, ) -> core::fmt::Result { let is_scalar = match lhs { - Variable::Local { item, .. } => item.vectorization_factor() == 1, - Variable::LocalBinding { item, .. } => item.vectorization_factor() == 1, + Variable::LocalMut { item, .. } => item.vectorization_factor() == 1, + Variable::LocalConst { item, .. 
} => item.vectorization_factor() == 1, _ => false, }; diff --git a/crates/cubecl-wgpu/src/lib.rs b/crates/cubecl-wgpu/src/lib.rs index f54be579e..50dea369c 100644 --- a/crates/cubecl-wgpu/src/lib.rs +++ b/crates/cubecl-wgpu/src/lib.rs @@ -25,7 +25,8 @@ mod tests { pub type TestRuntime = crate::WgpuRuntime; cubecl_core::testgen_all!(); - cubecl_linalg::testgen_matmul_plane_mma!([flex32, f32], f32); + cubecl_linalg::testgen_matmul_plane!([f32]); + cubecl_linalg::testgen_matmul_accelerated!([f32]); cubecl_linalg::testgen_matmul_tiling2d!([flex32, f32]); cubecl_linalg::testgen_matmul_simple!([flex32, f32]); cubecl_reduce::testgen_reduce!(); @@ -38,8 +39,8 @@ mod tests_spirv { use half::f16; cubecl_core::testgen_all!(f32: [f16, flex32, f32, f64], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, u64]); - cubecl_linalg::testgen_matmul_plane_mma!([f16, flex32, f32], f32); + cubecl_linalg::testgen_matmul_plane!([f16, flex32, f32]); cubecl_linalg::testgen_matmul_tiling2d!([f16, flex32, f32, f64]); cubecl_linalg::testgen_matmul_simple!([flex32, f32]); - cubecl_linalg::testgen_matmul_cmma!([f16]); + cubecl_linalg::testgen_matmul_accelerated!([f16]); } diff --git a/crates/cubecl-wgpu/tests/common.rs b/crates/cubecl-wgpu/tests/common.rs deleted file mode 100644 index f602d2c1e..000000000 --- a/crates/cubecl-wgpu/tests/common.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::num::NonZero; - -use cubecl_core::{ - prelude::{ArrayCompilationArg, TensorCompilationArg}, - Compiler, CubeDim, ExecutionMode, Kernel, KernelSettings, Runtime, -}; -use cubecl_wgpu::{WgpuRuntime, WgslCompiler}; - -pub type TestRuntime = WgpuRuntime; - -pub fn settings(dim_x: u32, dim_y: u32) -> KernelSettings { - KernelSettings::default().cube_dim(CubeDim::new(dim_x, dim_y, 1)) -} - -#[allow(unused)] -pub fn tensor() -> TensorCompilationArg { - TensorCompilationArg { - inplace: None, - vectorisation: NonZero::new(1), - } -} - -#[allow(unused)] -pub fn tensor_vec(vectorization: u8) -> TensorCompilationArg { - TensorCompilationArg { - inplace: None, - vectorisation: NonZero::new(vectorization), - } -} - -#[allow(unused)] -pub fn array() -> ArrayCompilationArg { - ArrayCompilationArg { - inplace: None, - vectorisation: NonZero::new(1), - } -} - -pub fn compile(kernel: impl Kernel) -> String { - let comp_opts = Default::default(); - <::Compiler as Compiler>::compile( - kernel.define(), - &comp_opts, - ExecutionMode::Checked, - ) - .to_string() -} - -#[macro_export] -macro_rules! 
load_kernel_string { - ($file:expr) => { - include_str!($file) - .replace("\r\n", "\n") - .trim_end() - .to_string() - }; -} diff --git a/crates/cubecl-wgpu/tests/constant_array.wgsl b/crates/cubecl-wgpu/tests/constant_array.wgsl deleted file mode 100644 index c52ced70c..000000000 --- a/crates/cubecl-wgpu/tests/constant_array.wgsl +++ /dev/null @@ -1,34 +0,0 @@ -@group(0) -@binding(0) -var output_0_global: array; - -@group(0) -@binding(1) -var info: array; - -const arrays_0: array = array(f32(3u),f32(5u),f32(1u),); - -const WORKGROUP_SIZE_X = 16u; -const WORKGROUP_SIZE_Y = 16u; -const WORKGROUP_SIZE_Z = 1u; - -@compute -@workgroup_size(16, 16, 1) -fn constant_array_kernel_f32( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { -let id = (global_id.z * num_workgroups.x * WORKGROUP_SIZE_X * num_workgroups.y * WORKGROUP_SIZE_Y) + (global_id.y * num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; -let _0 = info[1u]; -let _1 = id < _0; -if _1 { -let _2 = arrays_0[id]; -var l_1_0: u32; -var l_1_1: bool; -l_1_0 = info[0u]; -l_1_1 = id < l_1_0; -if l_1_1 { -output_0_global[id] = _2; -} -} -} diff --git a/crates/cubecl-wgpu/tests/main.rs b/crates/cubecl-wgpu/tests/main.rs deleted file mode 100644 index 2fe2322c8..000000000 --- a/crates/cubecl-wgpu/tests/main.rs +++ /dev/null @@ -1,149 +0,0 @@ -use common::*; -use constant_array_kernel::ConstantArrayKernel; -use cubecl_core as cubecl; -use cubecl_core::prelude::*; -use cubecl_wgpu::WgpuRuntime; -use execute_unary_kernel::ExecuteUnaryKernel; -use half::bf16; -use kernel_elect::KernelElect; -use kernel_sum::KernelSum; -use naming_kernel::NamingKernel; -use pretty_assertions::assert_eq; -use sequence_for_loop_kernel::SequenceForLoopKernel; -use slice_assign_kernel::SliceAssignKernel; - -mod common; - -#[cube(launch_unchecked, create_dummy_kernel)] -pub fn slice_assign_kernel(input: &Tensor, output: &mut Tensor) { - if UNIT_POS == 0 { - let mut slice_1 = output.slice_mut(2, 3); - slice_1[0] = input[0]; - } -} - -#[test] -pub fn slice_assign() { - let kernel = SliceAssignKernel::::new(settings(1, 1), tensor(), tensor()); - let expected = load_kernel_string!("slice_assign.wgsl"); - let compiled = compile(kernel); - assert_eq!(compiled, expected); -} - -#[cube(launch, create_dummy_kernel)] -pub fn kernel_sum(output: &mut Tensor) { - let val = output[UNIT_POS]; - let val2 = cubecl_core::prelude::plane_sum(val); - - if UNIT_POS == 0 { - output[0] = val2; - } -} - -#[test] -pub fn plane_sum() { - let kernel = KernelSum::::new(settings(4, 1), tensor()); - let expected = load_kernel_string!("plane_sum.wgsl"); - let compiled = compile(kernel); - assert_eq!(compiled, expected); -} - -#[cube(launch, create_dummy_kernel)] -pub fn kernel_elect(output: &mut Tensor) { - let elected = cubecl_core::prelude::plane_elect(); - output[UNIT_POS] = elected as u32; -} - -#[test] -pub fn plane_elect() { - let kernel = KernelElect::::new(settings(4, 1), tensor()); - let expected = load_kernel_string!("plane_elect.wgsl"); - let compiled = compile(kernel); - assert_eq!(compiled, expected); -} - -#[cube(launch, create_dummy_kernel)] -pub fn sequence_for_loop_kernel(output: &mut Array) { - if UNIT_POS != 0 { - return; - } - - let mut sequence = Sequence::::new(); - sequence.push(1.0); - sequence.push(4.0); - - for value in sequence { - output[0] += value; - } -} - -#[test] -pub fn sequence_for_loop() { - let kernel = SequenceForLoopKernel::::new(settings(16, 16), array()); - let expected = 
load_kernel_string!("sequence_for_loop.wgsl"); - let compiled = compile(kernel); - assert_eq!(compiled, expected); -} - -#[cube(launch, create_dummy_kernel)] -fn execute_unary_kernel(lhs: &Tensor, rhs: &Tensor, out: &mut Tensor) { - if ABSOLUTE_POS < out.len() { - for i in 0..256u32 { - if i % 2 == 0 { - out[ABSOLUTE_POS] -= F::cos(lhs[ABSOLUTE_POS] * rhs[ABSOLUTE_POS]); - } else { - out[ABSOLUTE_POS] += F::cos(lhs[ABSOLUTE_POS] * rhs[ABSOLUTE_POS]); - } - } - } -} - -#[test] -pub fn unary_bench() { - let kernel = ExecuteUnaryKernel::::new( - settings(16, 16), - tensor_vec(4), - tensor_vec(4), - tensor_vec(4), - ); - let expected = load_kernel_string!("unary_bench.wgsl"); - let compiled = compile(kernel); - assert_eq!(compiled, expected); -} - -#[cube(launch, create_dummy_kernel)] -fn constant_array_kernel(out: &mut Tensor, #[comptime] data: Vec) { - let array = Array::::from_data(data); - - if ABSOLUTE_POS < out.len() { - out[ABSOLUTE_POS] = array[ABSOLUTE_POS]; - } -} - -#[test] -pub fn constant_array() { - let data: Vec = vec![3, 5, 1]; - - let kernel = ConstantArrayKernel::::new(settings(16, 16), tensor(), data); - let expected = load_kernel_string!("constant_array.wgsl"); - let compiled = compile(kernel); - assert_eq!(compiled, expected); -} - -// This kernel just exists to have a few generics in order to observe -// that the generics get propagated into the WGSL kernel name -#[allow(clippy::extra_unused_type_parameters)] -#[cube(launch, create_dummy_kernel)] -fn naming_kernel(out: &mut Array) { - if ABSOLUTE_POS < out.len() { - out[ABSOLUTE_POS] = F1::from_int(0); - } -} - -#[test] -pub fn naming() { - let kernel = NamingKernel::::new(settings(16, 16), array()); - let expected = load_kernel_string!("naming.wgsl"); - let compiled = compile(kernel); - assert_eq!(compiled, expected); -} diff --git a/crates/cubecl-wgpu/tests/naming.wgsl b/crates/cubecl-wgpu/tests/naming.wgsl deleted file mode 100644 index 9b70b5c11..000000000 --- a/crates/cubecl-wgpu/tests/naming.wgsl +++ /dev/null @@ -1,31 +0,0 @@ -@group(0) -@binding(0) -var output_0_global: array; - -@group(0) -@binding(1) -var info: array; - -const WORKGROUP_SIZE_X = 16u; -const WORKGROUP_SIZE_Y = 16u; -const WORKGROUP_SIZE_Z = 1u; - -@compute -@workgroup_size(16, 16, 1) -fn naming_kernel_f32_u8_bf16_i64( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { -let id = (global_id.z * num_workgroups.x * WORKGROUP_SIZE_X * num_workgroups.y * WORKGROUP_SIZE_Y) + (global_id.y * num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; -let _0 = info[1u]; -let _1 = id < _0; -if _1 { -var l_1_0: u32; -var l_1_1: bool; -l_1_0 = info[0u]; -l_1_1 = id < l_1_0; -if l_1_1 { -output_0_global[id] = 0f; -} -} -} \ No newline at end of file diff --git a/crates/cubecl-wgpu/tests/plane_elect.wgsl b/crates/cubecl-wgpu/tests/plane_elect.wgsl deleted file mode 100644 index ee899265f..000000000 --- a/crates/cubecl-wgpu/tests/plane_elect.wgsl +++ /dev/null @@ -1,27 +0,0 @@ -@group(0) -@binding(0) -var output_0_global: array; - -@group(0) -@binding(1) -var info: array; - -const WORKGROUP_SIZE_X = 4u; -const WORKGROUP_SIZE_Y = 1u; -const WORKGROUP_SIZE_Z = 1u; - -@compute -@workgroup_size(4, 1, 1) -fn kernel_elect( - @builtin(local_invocation_index) local_idx: u32, -) { -let _0 = subgroupElect(); -let _1 = u32(_0); -var l_0_0: u32; -var l_0_1: bool; -l_0_0 = info[0u]; -l_0_1 = local_idx < l_0_0; -if l_0_1 { -output_0_global[local_idx] = _1; -} -} diff --git a/crates/cubecl-wgpu/tests/plane_sum.wgsl 
b/crates/cubecl-wgpu/tests/plane_sum.wgsl deleted file mode 100644 index 0383fb56a..000000000 --- a/crates/cubecl-wgpu/tests/plane_sum.wgsl +++ /dev/null @@ -1,36 +0,0 @@ -@group(0) -@binding(0) -var output_0_global: array; - -@group(0) -@binding(1) -var info: array; - -const WORKGROUP_SIZE_X = 4u; -const WORKGROUP_SIZE_Y = 1u; -const WORKGROUP_SIZE_Z = 1u; - -@compute -@workgroup_size(4, 1, 1) -fn kernel_sum( - @builtin(local_invocation_index) local_idx: u32, -) { -var l_0_0: u32; -var l_0_1: bool; -var l_0_2: f32; -l_0_0 = info[0u]; -l_0_1 = local_idx < l_0_0; -l_0_2 = output_0_global[local_idx]; -let _0 = select(0f, l_0_2, l_0_1); -let _1 = subgroupAdd(_0); -let _2 = local_idx == 0u; -if _2 { -var l_1_0: u32; -var l_1_1: bool; -l_1_0 = info[0u]; -l_1_1 = 0u < l_1_0; -if l_1_1 { -output_0_global[0u] = _1; -} -} -} diff --git a/crates/cubecl-wgpu/tests/sequence_for_loop.wgsl b/crates/cubecl-wgpu/tests/sequence_for_loop.wgsl deleted file mode 100644 index 64c9b0def..000000000 --- a/crates/cubecl-wgpu/tests/sequence_for_loop.wgsl +++ /dev/null @@ -1,52 +0,0 @@ -@group(0) -@binding(0) -var output_0_global: array; - -@group(0) -@binding(1) -var info: array; - -const WORKGROUP_SIZE_X = 16u; -const WORKGROUP_SIZE_Y = 16u; -const WORKGROUP_SIZE_Z = 1u; - -@compute -@workgroup_size(16, 16, 1) -fn sequence_for_loop_kernel( - @builtin(local_invocation_index) local_idx: u32, -) { -let _0 = local_idx != 0u; -if _0 { -return; -} -var l_0_0: u32; -var l_0_1: bool; -var l_0_2: f32; -l_0_0 = info[0u]; -l_0_1 = 0u < l_0_0; -l_0_2 = output_0_global[0u]; -let _1 = select(0f, l_0_2, l_0_1); -let _2 = _1 + 1f; -var l_0_3: u32; -var l_0_4: bool; -l_0_3 = info[0u]; -l_0_4 = 0u < l_0_3; -if l_0_4 { -output_0_global[0u] = _2; -} -var l_0_5: u32; -var l_0_6: bool; -var l_0_7: f32; -l_0_5 = info[0u]; -l_0_6 = 0u < l_0_5; -l_0_7 = output_0_global[0u]; -let _3 = select(0f, l_0_7, l_0_6); -let _4 = _3 + 4f; -var l_0_8: u32; -var l_0_9: bool; -l_0_8 = info[0u]; -l_0_9 = 0u < l_0_8; -if l_0_9 { -output_0_global[0u] = _4; -} -} diff --git a/crates/cubecl-wgpu/tests/slice_assign.wgsl b/crates/cubecl-wgpu/tests/slice_assign.wgsl deleted file mode 100644 index ae6a3cc65..000000000 --- a/crates/cubecl-wgpu/tests/slice_assign.wgsl +++ /dev/null @@ -1,44 +0,0 @@ -@group(0) -@binding(0) -var input_0_global: array; - -@group(0) -@binding(1) -var output_0_global: array; - -@group(0) -@binding(2) -var info: array; - -const WORKGROUP_SIZE_X = 1u; -const WORKGROUP_SIZE_Y = 1u; -const WORKGROUP_SIZE_Z = 1u; - -@compute -@workgroup_size(1, 1, 1) -fn slice_assign_kernel( - @builtin(local_invocation_index) local_idx: u32, -) { -let _0 = local_idx == 0u; -if _0 { -var l_1_0: u32; -l_1_0 = info[1u]; -let slice_1_0_offset = 2u; -let slice_1_0_length = min(l_1_0, 3u) - 2u; -let slice_1_0_ptr = &output_0_global; -var l_1_1: u32; -var l_1_2: bool; -var l_1_3: f32; -l_1_1 = info[0u]; -l_1_2 = 0u < l_1_1; -l_1_3 = input_0_global[0u]; -let _1 = select(0f, l_1_3, l_1_2); -var l_1_4: u32; -var l_1_5: bool; -l_1_4 = slice_1_0_length; -l_1_5 = 0u < l_1_4; -if l_1_5 { -(*slice_1_0_ptr)[0u + slice_1_0_offset] = _1; -} -} -} diff --git a/crates/cubecl-wgpu/tests/unary_bench.wgsl b/crates/cubecl-wgpu/tests/unary_bench.wgsl deleted file mode 100644 index 5c6028062..000000000 --- a/crates/cubecl-wgpu/tests/unary_bench.wgsl +++ /dev/null @@ -1,132 +0,0 @@ -@group(0) -@binding(0) -var input_0_global: array>; - -@group(0) -@binding(1) -var input_1_global: array>; - -@group(0) -@binding(2) -var output_0_global: array>; - -@group(0) -@binding(3) -var info: 
array; - -const WORKGROUP_SIZE_X = 16u; -const WORKGROUP_SIZE_Y = 16u; -const WORKGROUP_SIZE_Z = 1u; - -@compute -@workgroup_size(16, 16, 1) -fn execute_unary_kernel_f32( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { -let id = (global_id.z * num_workgroups.x * WORKGROUP_SIZE_X * num_workgroups.y * WORKGROUP_SIZE_Y) + (global_id.y * num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; -let _0 = info[5u]; -let _1 = id < _0; -if _1 { - -for (var l_2_2: u32 = 0u; l_2_2 < 256u; l_2_2++) { -let _3 = l_2_2 % 2u; -let _4 = _3 == 0u; -if _4 { -var l_3_0: u32; -var l_3_1: bool; -var l_3_2: vec4; -l_3_0 = info[0u]; -l_3_1 = id < l_3_0; -l_3_2 = input_0_global[id]; -let _5 = vec4( -select(0f, l_3_2[0], l_3_1), -select(0f, l_3_2[1], l_3_1), -select(0f, l_3_2[2], l_3_1), -select(0f, l_3_2[3], l_3_1), -); -var l_3_3: u32; -var l_3_4: bool; -var l_3_5: vec4; -l_3_3 = info[1u]; -l_3_4 = id < l_3_3; -l_3_5 = input_1_global[id]; -let _6 = vec4( -select(0f, l_3_5[0], l_3_4), -select(0f, l_3_5[1], l_3_4), -select(0f, l_3_5[2], l_3_4), -select(0f, l_3_5[3], l_3_4), -); -let _7 = _5 * _6; -let _8 = cos(_7); -var l_3_6: u32; -var l_3_7: bool; -var l_3_8: vec4; -l_3_6 = info[2u]; -l_3_7 = id < l_3_6; -l_3_8 = output_0_global[id]; -let _9 = vec4( -select(0f, l_3_8[0], l_3_7), -select(0f, l_3_8[1], l_3_7), -select(0f, l_3_8[2], l_3_7), -select(0f, l_3_8[3], l_3_7), -); -let _10 = _9 - _8; -var l_3_9: u32; -var l_3_10: bool; -l_3_9 = info[2u]; -l_3_10 = id < l_3_9; -if l_3_10 { -output_0_global[id] = _10; -} -} else { -var l_3_0: u32; -var l_3_1: bool; -var l_3_2: vec4; -l_3_0 = info[0u]; -l_3_1 = id < l_3_0; -l_3_2 = input_0_global[id]; -let _11 = vec4( -select(0f, l_3_2[0], l_3_1), -select(0f, l_3_2[1], l_3_1), -select(0f, l_3_2[2], l_3_1), -select(0f, l_3_2[3], l_3_1), -); -var l_3_3: u32; -var l_3_4: bool; -var l_3_5: vec4; -l_3_3 = info[1u]; -l_3_4 = id < l_3_3; -l_3_5 = input_1_global[id]; -let _12 = vec4( -select(0f, l_3_5[0], l_3_4), -select(0f, l_3_5[1], l_3_4), -select(0f, l_3_5[2], l_3_4), -select(0f, l_3_5[3], l_3_4), -); -let _13 = _11 * _12; -let _14 = cos(_13); -var l_3_6: u32; -var l_3_7: bool; -var l_3_8: vec4; -l_3_6 = info[2u]; -l_3_7 = id < l_3_6; -l_3_8 = output_0_global[id]; -let _15 = vec4( -select(0f, l_3_8[0], l_3_7), -select(0f, l_3_8[1], l_3_7), -select(0f, l_3_8[2], l_3_7), -select(0f, l_3_8[3], l_3_7), -); -let _16 = _15 + _14; -var l_3_9: u32; -var l_3_10: bool; -l_3_9 = info[2u]; -l_3_10 = id < l_3_9; -if l_3_10 { -output_0_global[id] = _16; -} -} -} -} -} diff --git a/crates/cubecl/benches/matmul.rs b/crates/cubecl/benches/matmul.rs index dea28e07f..283f33570 100644 --- a/crates/cubecl/benches/matmul.rs +++ b/crates/cubecl/benches/matmul.rs @@ -27,7 +27,13 @@ impl Benchmark for MatmulBench { } fn name(&self) -> String { - format!("matmul-{}-{}-{:?}", R::name(), E::as_elem(), self.strategy).to_lowercase() + format!( + "matmul-{}-{}-{:?}", + R::name(), + E::as_elem_native_unchecked(), + self.strategy + ) + .to_lowercase() } fn sync(&self) { @@ -55,7 +61,13 @@ struct MatmulBench { fn run(device: R::Device, strategy: matmul::Strategy) { let client = R::client(&device); - for (b, m, n, k) in [(2, 4096, 4096, 4096), (2, 4096, 2040, 4096)] { + for (b, m, n, k) in [ + (1, 6144, 6144, 6144), + (1, 5000, 5000, 5000), + (2, 4096, 4096, 4096), + (16, 6144, 2048, 513), + (32, 256, 256, 256), + ] { let bench = MatmulBench:: { b, m, @@ -73,24 +85,29 @@ fn run(device: R::Device, strategy: matmul::Strategy) { } fn main() { - // #[cfg(feature = 
"wgpu")] - // { - // run::( - // Default::default(), - // matmul::Strategy::Tiling2D(Default::default()), - // ); - // run::(Default::default(), matmul::Strategy::PlaneMma); - // } - - #[cfg(feature = "wgpu-spirv")] + #[cfg(feature = "wgpu")] { - type R = cubecl::wgpu::WgpuRuntime; - - run::( + run::( Default::default(), matmul::Strategy::Tiling2D(Default::default()), ); - run::(Default::default(), matmul::Strategy::Accelerated); + run::(Default::default(), matmul::Strategy::PlaneMma); + } + + #[cfg(feature = "wgpu-spirv")] + { + type R = cubecl::wgpu::WgpuRuntime; + use half::f16; + + run::(Default::default(), matmul::Strategy::Standard); + run::(Default::default(), matmul::Strategy::Specialized); + run::(Default::default(), matmul::Strategy::Pipelined); + run::(Default::default(), matmul::Strategy::Standard); + run::(Default::default(), matmul::Strategy::Specialized); + run::(Default::default(), matmul::Strategy::Pipelined); + run::(Default::default(), matmul::Strategy::Standard); + run::(Default::default(), matmul::Strategy::Specialized); + run::(Default::default(), matmul::Strategy::Pipelined); } #[cfg(all(feature = "hip", target_os = "linux"))] @@ -107,7 +124,7 @@ fn main() { // CmmaOld // run:: f32>(Default::default(), matmul::Strategy::CmmaOld(Default::default())); // Accelerated - run::(Default::default(), matmul::Strategy::Accelerated); + run::(Default::default(), matmul::Strategy::Standard); // Half-precision ---------------------------------------------------- // Tiling2D run::( @@ -119,23 +136,21 @@ fn main() { // CmmaOld: OOM // run::(Default::default(), matmul::Strategy::CmmaOld(Default::default())); // Accelerated - run::( - Default::default(), - matmul::Strategy::Accelerated, - ); + run::(Default::default(), matmul::Strategy::Standard); } #[cfg(feature = "cuda")] { - run::( - Default::default(), - matmul::Strategy::Tiling2D(Default::default()), - ); - run::(Default::default(), matmul::Strategy::Accelerated); - run::(Default::default(), matmul::Strategy::Accelerated); - run::( - Default::default(), - matmul::Strategy::Accelerated, - ); + use half::f16; + + run::(Default::default(), matmul::Strategy::Standard); + run::(Default::default(), matmul::Strategy::Specialized); + run::(Default::default(), matmul::Strategy::Pipelined); + run::(Default::default(), matmul::Strategy::Standard); + run::(Default::default(), matmul::Strategy::Specialized); + run::(Default::default(), matmul::Strategy::Pipelined); + run::(Default::default(), matmul::Strategy::Standard); + run::(Default::default(), matmul::Strategy::Specialized); + run::(Default::default(), matmul::Strategy::Pipelined); } } diff --git a/crates/cubecl/benches/unary.rs b/crates/cubecl/benches/unary.rs index d2152b62b..7d2b46afb 100644 --- a/crates/cubecl/benches/unary.rs +++ b/crates/cubecl/benches/unary.rs @@ -55,7 +55,7 @@ impl Benchmark for UnaryBench { format!( "unary-{}-{}-{:?}", R::name(), - E::as_elem(), + E::as_elem_native_unchecked(), self.vectorization ) .to_lowercase()