diff --git a/README.md b/README.md
index e32e9a836..567cfbf8c 100644
--- a/README.md
+++ b/README.md
@@ -109,6 +109,21 @@
 This is just the beginning. We plan to include more utilities such as convolutions, random number generation, fast Fourier transforms, and other essential algorithms. We are a small team also building [Burn](https://burn.dev), so don't hesitate to contribute and port algorithms; it can help more than you would imagine!
 
+## How it works
+
+CubeCL leverages Rust's proc macro system in a unique two-step process:
+
+1. Parsing: The proc macro parses the GPU kernel code using the syn crate.
+2. Expansion: Instead of immediately generating an Intermediate Representation (IR), the macro generates a new Rust function.
+
+The generated function, semantically similar to the original, is responsible for creating the IR when called.
+This approach differs from traditional compilers, which typically generate IR directly after parsing.
+Our method enables several key features:
+
+- **Comptime**: By not transforming the original code, it becomes remarkably easy to integrate compile-time optimizations.
+- **Automatic Vectorization**: By simply vectorizing the inputs of a CubeCL function, we can determine the vectorization factor of each intermediate variable during the expansion.
+- **Rust Integration**: The generated code remains valid Rust code, allowing it to be bundled without any dependency on the specific runtime.
+
 ## Design
 
 CubeCL is designed around - you guessed it - Cubes! More specifically, it's based on cuboids, because not all axes are the same size.
diff --git a/crates/cubecl-core/src/compute/launcher.rs b/crates/cubecl-core/src/compute/launcher.rs
index f05c8b288..54e733b64 100644
--- a/crates/cubecl-core/src/compute/launcher.rs
+++ b/crates/cubecl-core/src/compute/launcher.rs
@@ -80,9 +80,9 @@ impl KernelLauncher {
         self,
         cube_count: CubeCount,
         kernel: K,
-        client: ComputeClient,
+        client: &ComputeClient,
     ) {
-        let bindings = self.into_bindings(&client);
+        let bindings = self.into_bindings(client);
 
         let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
diff --git a/crates/cubecl-core/src/frontend/element/base.rs b/crates/cubecl-core/src/frontend/element/base.rs
index 520748a1f..7e869d1ca 100644
--- a/crates/cubecl-core/src/frontend/element/base.rs
+++ b/crates/cubecl-core/src/frontend/element/base.rs
@@ -1,4 +1,4 @@
-use super::{Bool, Numeric, UInt, Vectorized, F32, F64, I32, I64};
+use super::{Bool, CubePrimitive, Numeric, UInt, Vectorized, F32, F64, I32, I64};
 use crate::{
     ir::{ConstantScalarValue, Elem, Item, Operator, Variable, Vectorization},
     prelude::{index_assign, init_expand, CubeContext, KernelBuilder, KernelLauncher},
@@ -200,10 +200,11 @@ impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
     }
 }
 
-impl<T: CubeType> ExpandElementTyped<T> {
+impl<T: CubePrimitive> ExpandElementTyped<T> {
     /// Create an [ExpandElementTyped] from a value that is normally a literal.
     pub fn from_lit<L: Into<Variable>>(lit: L) -> Self {
         let variable: Variable = lit.into();
+        let variable = T::as_elem().from_constant(variable);
 
         ExpandElementTyped::new(ExpandElement::Plain(variable))
     }
diff --git a/crates/cubecl-core/src/ir/kernel.rs b/crates/cubecl-core/src/ir/kernel.rs
index 215fdc889..5133a79ef 100644
--- a/crates/cubecl-core/src/ir/kernel.rs
+++ b/crates/cubecl-core/src/ir/kernel.rs
@@ -112,6 +112,23 @@ impl Elem {
             ConstantScalarValue::Bool(val) => self.constant_from_bool(val),
         }
     }
+    /// Get the size in bytes.
+    pub fn size(&self) -> usize {
+        match self {
+            Elem::Float(kind) => match kind {
+                FloatKind::F16 => core::mem::size_of::<half::f16>(),
+                FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
+                FloatKind::F32 => core::mem::size_of::<f32>(),
+                FloatKind::F64 => core::mem::size_of::<f64>(),
+            },
+            Elem::Int(kind) => match kind {
+                IntKind::I32 => core::mem::size_of::<i32>(),
+                IntKind::I64 => core::mem::size_of::<i64>(),
+            },
+            Elem::UInt => core::mem::size_of::<u32>(),
+            Elem::Bool => core::mem::size_of::<bool>(),
+        }
+    }
 }
 
 impl From<Elem> for Item {
diff --git a/crates/cubecl-core/src/lib.rs b/crates/cubecl-core/src/lib.rs
index d71c4c479..dfdedbe1c 100644
--- a/crates/cubecl-core/src/lib.rs
+++ b/crates/cubecl-core/src/lib.rs
@@ -26,6 +26,7 @@ pub use runtime::*;
 pub use cubecl_macros::cube;
 pub use cubecl_macros::CubeLaunch;
 pub use cubecl_macros::CubeType;
+pub use cubecl_runtime::benchmark;
 
 /// An approximation of the subcube dimension.
 pub const SUBCUBE_DIM_APPROX: usize = 16;
diff --git a/crates/cubecl-core/src/runtime_tests/cmma.rs b/crates/cubecl-core/src/runtime_tests/cmma.rs
index 94439a6a3..c7e5961b2 100644
--- a/crates/cubecl-core/src/runtime_tests/cmma.rs
+++ b/crates/cubecl-core/src/runtime_tests/cmma.rs
@@ -66,7 +66,7 @@ pub fn test_simple_1(client: ComputeClient) {
     let out = client.empty(core::mem::size_of::() * 256);
 
     kernel_simple_1::launch::(
-        client.clone(),
+        &client,
         CubeCount::Static(1, 1, 1),
         CubeDim::new(16, 16, 1),
         ArrayArg::new(&lhs, 256),
diff --git a/crates/cubecl-core/src/runtime_tests/launch.rs b/crates/cubecl-core/src/runtime_tests/launch.rs
index b96924ff0..62bc50e2f 100644
--- a/crates/cubecl-core/src/runtime_tests/launch.rs
+++ b/crates/cubecl-core/src/runtime_tests/launch.rs
@@ -20,7 +20,7 @@ pub fn test_kernel_with_generics(client: ComputeClient(
-        client.clone(),
+        &client,
         CubeCount::Static(1, 1, 1),
         CubeDim::default(),
         ArrayArg::new(&handle, 2),
@@ -36,7 +36,7 @@ pub fn test_kernel_without_generics(client: ComputeClient(
-        client.clone(),
+        &client,
         CubeCount::Static(1, 1, 1),
         CubeDim::default(),
         ArrayArg::new(&handle, 2),
diff --git a/crates/cubecl-core/src/runtime_tests/slice.rs b/crates/cubecl-core/src/runtime_tests/slice.rs
index 335858b09..0facb53cb 100644
--- a/crates/cubecl-core/src/runtime_tests/slice.rs
+++ b/crates/cubecl-core/src/runtime_tests/slice.rs
@@ -31,7 +31,7 @@ pub fn test_slice_select(client: ComputeClient());
     slice_select::launch::(
-        client.clone(),
+        &client,
         CubeCount::Static(1, 1, 1),
         CubeDim::new(1, 1, 1),
         ArrayArg::new(&input, 5),
@@ -49,7 +49,7 @@ pub fn test_slice_len(client: ComputeClient)
     let output = client.empty(core::mem::size_of::());
 
     slice_len::launch::(
-        client.clone(),
+        &client,
         CubeCount::Static(1, 1, 1),
         CubeDim::new(1, 1, 1),
         ArrayArg::new(&input, 5),
@@ -67,7 +67,7 @@ pub fn test_slice_assign(client: ComputeClient(
-        client.clone(),
+        &client,
         CubeCount::Static(1, 1, 1),
         CubeDim::new(1, 1, 1),
         ArrayArg::new(&input, 5),
diff --git a/crates/cubecl-core/src/runtime_tests/subcube.rs b/crates/cubecl-core/src/runtime_tests/subcube.rs
index d01507071..7fc50b2d8 100644
--- a/crates/cubecl-core/src/runtime_tests/subcube.rs
+++ b/crates/cubecl-core/src/runtime_tests/subcube.rs
@@ -50,7 +50,7 @@ pub fn test_subcube_sum(
         &[17.0, 5.0, 7.0, 1.0],
         client.clone(),
         |cube_count, cube_dim, handle| {
-            kernel_sum::launch::(client.clone(), cube_count, cube_dim, handle)
+            kernel_sum::launch::(&client, cube_count, cube_dim, handle)
         },
     );
 }
@@ -63,7 +63,7 @@ pub fn test_subcube_prod(
         &[140.0, 5.0, 7.0, 1.0],
         client.clone(),
         |cube_dim, settings, handle| {
-
kernel_prod::launch::(client.clone(), cube_dim, settings, handle) + kernel_prod::launch::(&client, cube_dim, settings, handle) }, ); } @@ -75,7 +75,7 @@ pub fn test_subcube_max( &[7.0, 5.0, 7.0, 1.0], client.clone(), |cube_dim, settings, handle| { - kernel_max::launch::(client.clone(), cube_dim, settings, handle) + kernel_max::launch::(&client, cube_dim, settings, handle) }, ); } @@ -88,7 +88,7 @@ pub fn test_subcube_min( &[1.0, 5.0, 7.0, 1.0], client.clone(), |cube_dim, settings, handle| { - kernel_min::launch::(client.clone(), cube_dim, settings, handle) + kernel_min::launch::(&client, cube_dim, settings, handle) }, ); } diff --git a/crates/cubecl-core/tests/frontend/tensor.rs b/crates/cubecl-core/tests/frontend/tensor.rs index 898a171d2..d7d905bdb 100644 --- a/crates/cubecl-core/tests/frontend/tensor.rs +++ b/crates/cubecl-core/tests/frontend/tensor.rs @@ -37,8 +37,8 @@ mod tests { let y = scope.create_local(Item::new(UInt::as_elem())); let z = scope.create_local(Item::new(UInt::as_elem())); - cpa!(&mut scope, x = shape(input, 1)); - cpa!(&mut scope, y = stride(input, 1)); + cpa!(&mut scope, x = shape(input, 1u32)); + cpa!(&mut scope, y = stride(input, 1u32)); cpa!(&mut scope, z = len(input)); scope.operations diff --git a/crates/cubecl-cuda/src/compiler/base.rs b/crates/cubecl-cuda/src/compiler/base.rs index 555644c99..ded8b3f02 100644 --- a/crates/cubecl-cuda/src/compiler/base.rs +++ b/crates/cubecl-cuda/src/compiler/base.rs @@ -16,6 +16,8 @@ pub struct CudaCompiler { absolute_idx: (bool, bool, bool), wrap_size_checked: bool, wmma: bool, + bf16: bool, + f16: bool, shape: bool, stride: bool, num_inputs: usize, @@ -31,7 +33,7 @@ impl Compiler for CudaCompiler { } fn elem_size(elem: gpu::Elem) -> usize { - Self::compile_elem(elem).size() + elem.size() } fn max_shared_memory_size() -> usize { @@ -46,6 +48,22 @@ impl CudaCompiler { self.num_outputs = value.outputs.len(); let instructions = self.compile_scope(&mut value.body); + let inputs = value + .inputs + .into_iter() + .map(|b| self.compile_binding(b)) + .collect(); + let outputs = value + .outputs + .into_iter() + .map(|b| self.compile_binding(b)) + .collect(); + let named = value + .named + .into_iter() + .map(|(name, binding)| (name, self.compile_binding(binding))) + .collect(); + let body = super::Body { instructions, stride: true, @@ -60,24 +78,14 @@ impl CudaCompiler { }; super::ComputeKernel { - inputs: value - .inputs - .into_iter() - .map(Self::compile_binding) - .collect(), - outputs: value - .outputs - .into_iter() - .map(Self::compile_binding) - .collect(), - named: value - .named - .into_iter() - .map(|(name, binding)| (name, Self::compile_binding(binding))) - .collect(), + inputs, + outputs, + named, cube_dim: value.cube_dim, body, wmma_activated: self.wmma, + bf16: self.bf16, + f16: self.f16, } } @@ -187,7 +195,8 @@ impl CudaCompiler { output: self.compile_variable(output), frag: self.compile_variable(mat), stride: self.compile_variable(stride), - layout: Self::compile_matrix_layout(layout) + layout: self + .compile_matrix_layout(layout) .expect("Layout required for store instruction"), }), } @@ -364,7 +373,7 @@ impl CudaCompiler { gpu::Elem::Bool => ConstantScalarValue::Bool(true), }; Instruction::Div(super::BinaryInstruction { - lhs: super::Variable::ConstantScalar(lhs, Self::compile_elem(elem)), + lhs: super::Variable::ConstantScalar(lhs, self.compile_elem(elem)), rhs: self.compile_variable(op.input), out: self.compile_variable(op.out), }) @@ -399,34 +408,34 @@ impl CudaCompiler { fn compile_variable(&mut self, value: 
gpu::Variable) -> super::Variable { match value { gpu::Variable::GlobalInputArray { id, item } => { - super::Variable::GlobalInputArray(id, Self::compile_item(item)) + super::Variable::GlobalInputArray(id, self.compile_item(item)) } gpu::Variable::GlobalScalar { id, elem } => { - super::Variable::GlobalScalar(id, Self::compile_elem(elem), elem) + super::Variable::GlobalScalar(id, self.compile_elem(elem), elem) } gpu::Variable::Local { id, item, depth } => super::Variable::Local { id, - item: Self::compile_item(item), + item: self.compile_item(item), depth, }, gpu::Variable::Slice { id, item, depth } => super::Variable::Slice { id, - item: Self::compile_item(item), + item: self.compile_item(item), depth, }, gpu::Variable::LocalScalar { id, elem, depth } => super::Variable::LocalScalar { id, - elem: Self::compile_elem(elem), + elem: self.compile_elem(elem), depth, }, gpu::Variable::GlobalOutputArray { id, item } => { - super::Variable::GlobalOutputArray(id, Self::compile_item(item)) + super::Variable::GlobalOutputArray(id, self.compile_item(item)) } gpu::Variable::ConstantScalar(value) => { - super::Variable::ConstantScalar(value, Self::compile_elem(value.elem())) + super::Variable::ConstantScalar(value, self.compile_elem(value.elem())) } gpu::Variable::SharedMemory { id, item, length } => { - let item = Self::compile_item(item); + let item = self.compile_item(item); if !self.shared_memories.iter().any(|s| s.index == id) { self.shared_memories .push(super::SharedMemory::new(id, item, length)); @@ -475,7 +484,7 @@ impl CudaCompiler { depth, length, } => { - let item = Self::compile_item(item); + let item = self.compile_item(item); if !self .local_arrays .iter() @@ -494,25 +503,25 @@ impl CudaCompiler { self.wmma = true; super::Variable::WmmaFragment { id, - frag: Self::compile_matrix(mat), + frag: self.compile_matrix(mat), depth, } } } } - fn compile_matrix(matrix: gpu::Matrix) -> super::Fragment { + fn compile_matrix(&mut self, matrix: gpu::Matrix) -> super::Fragment { super::Fragment { - ident: Self::compile_matrix_ident(matrix.ident), + ident: self.compile_matrix_ident(matrix.ident), m: matrix.m, n: matrix.n, k: matrix.k, - elem: Self::compile_elem(matrix.elem), - layout: Self::compile_matrix_layout(matrix.layout), + elem: self.compile_elem(matrix.elem), + layout: self.compile_matrix_layout(matrix.layout), } } - fn compile_matrix_ident(ident: gpu::MatrixIdent) -> super::FragmentIdent { + fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> super::FragmentIdent { match ident { gpu::MatrixIdent::A => super::FragmentIdent::A, gpu::MatrixIdent::B => super::FragmentIdent::B, @@ -520,7 +529,10 @@ impl CudaCompiler { } } - fn compile_matrix_layout(layout: gpu::MatrixLayout) -> Option { + fn compile_matrix_layout( + &mut self, + layout: gpu::MatrixLayout, + ) -> Option { match layout { gpu::MatrixLayout::ColMajor => Some(super::FragmentLayout::ColMajor), gpu::MatrixLayout::RowMajor => Some(super::FragmentLayout::RowMajor), @@ -528,28 +540,34 @@ impl CudaCompiler { } } - fn compile_binding(binding: gpu::Binding) -> super::Binding { + fn compile_binding(&mut self, binding: gpu::Binding) -> super::Binding { super::Binding { - item: Self::compile_item(binding.item), + item: self.compile_item(binding.item), size: binding.size, } } - fn compile_item(item: gpu::Item) -> super::Item { + fn compile_item(&mut self, item: gpu::Item) -> super::Item { match item.vectorization { - 4 => super::Item::Vec4(Self::compile_elem(item.elem)), - 3 => super::Item::Vec3(Self::compile_elem(item.elem)), - 2 => 
super::Item::Vec2(Self::compile_elem(item.elem)), - 1 => super::Item::Scalar(Self::compile_elem(item.elem)), + 4 => super::Item::Vec4(self.compile_elem(item.elem)), + 3 => super::Item::Vec3(self.compile_elem(item.elem)), + 2 => super::Item::Vec2(self.compile_elem(item.elem)), + 1 => super::Item::Scalar(self.compile_elem(item.elem)), _ => panic!("Vectorization factor unsupported {:?}", item.vectorization), } } - fn compile_elem(value: gpu::Elem) -> super::Elem { + fn compile_elem(&mut self, value: gpu::Elem) -> super::Elem { match value { gpu::Elem::Float(kind) => match kind { - gpu::FloatKind::F16 => super::Elem::F16, - gpu::FloatKind::BF16 => super::Elem::BF16, + gpu::FloatKind::F16 => { + self.f16 = true; + super::Elem::F16 + } + gpu::FloatKind::BF16 => { + self.bf16 = true; + super::Elem::BF16 + } gpu::FloatKind::F32 => super::Elem::F32, gpu::FloatKind::F64 => panic!("f64 isn't supported yet"), }, diff --git a/crates/cubecl-cuda/src/compiler/element.rs b/crates/cubecl-cuda/src/compiler/element.rs index ddfbd941e..ff797a24b 100644 --- a/crates/cubecl-cuda/src/compiler/element.rs +++ b/crates/cubecl-cuda/src/compiler/element.rs @@ -25,7 +25,7 @@ pub enum Item { impl Display for Elem { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Elem::F16 => f.write_str("half"), + Elem::F16 => f.write_str("__half"), Elem::F32 => f.write_str("float"), Elem::BF16 => f.write_str("__nv_bfloat16"), Elem::I32 => f.write_str("int"), @@ -44,7 +44,7 @@ impl Display for Item { Elem::U32 => f.write_str("uint4"), Elem::Bool => f.write_str("bool4"), Elem::BF16 => f.write_str("__nv_bfloat164"), - Elem::F16 => f.write_str("half4"), + Elem::F16 => f.write_str("__half4"), }, Item::Vec3(elem) => match elem { Elem::F32 => f.write_str("float3"), @@ -52,7 +52,7 @@ impl Display for Item { Elem::U32 => f.write_str("uint3"), Elem::Bool => f.write_str("bool3"), Elem::BF16 => f.write_str("__nv_bfloat164"), - Elem::F16 => f.write_str("half3"), + Elem::F16 => f.write_str("__half3"), }, Item::Vec2(elem) => match elem { Elem::F32 => f.write_str("float2"), @@ -60,7 +60,7 @@ impl Display for Item { Elem::U32 => f.write_str("uint2"), Elem::Bool => f.write_str("bool2"), Elem::BF16 => f.write_str("__nv_bfloat162"), - Elem::F16 => f.write_str("half2"), + Elem::F16 => f.write_str("__half2"), }, Item::Scalar(elem) => f.write_fmt(format_args!("{elem}")), } @@ -194,13 +194,13 @@ impl Display for Variable { }, ConstantScalarValue::Float(val, kind) => match kind { gpu::FloatKind::F16 => { - f.write_fmt(format_args!("{elem}({})", half::f16::from_f64(*val))) + f.write_fmt(format_args!("{elem}({:?})", half::f16::from_f64(*val))) } gpu::FloatKind::BF16 => { - f.write_fmt(format_args!("{elem}({})", half::bf16::from_f64(*val))) + f.write_fmt(format_args!("{elem}({:?})", half::bf16::from_f64(*val))) } - gpu::FloatKind::F32 => f.write_fmt(format_args!("{elem}({})", *val as f32)), - gpu::FloatKind::F64 => f.write_fmt(format_args!("{elem}({})", { *val })), + gpu::FloatKind::F32 => f.write_fmt(format_args!("{elem}({:?})", *val as f32)), + gpu::FloatKind::F64 => f.write_fmt(format_args!("{elem}({:?})", { *val })), }, ConstantScalarValue::UInt(val) => { f.write_fmt(format_args!("{elem}({})", *val as u32)) diff --git a/crates/cubecl-cuda/src/compiler/kernel.rs b/crates/cubecl-cuda/src/compiler/kernel.rs index a5e8d73ef..0251b7945 100644 --- a/crates/cubecl-cuda/src/compiler/kernel.rs +++ b/crates/cubecl-cuda/src/compiler/kernel.rs @@ -48,6 +48,8 @@ pub struct ComputeKernel { pub cube_dim: CubeDim, pub body: Body, pub 
wmma_activated: bool, + pub bf16: bool, + pub f16: bool, } impl CompilerRepresentation for ComputeKernel { @@ -72,9 +74,33 @@ impl CompilerRepresentation for ComputeKernel { impl Display for ComputeKernel { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - //if self.wmma_activated { - f.write_str("#include \nusing namespace nvcuda;\n")?; - //} + if self.wmma_activated { + f.write_str("#include \n")?; + } + if self.bf16 { + f.write_str("#include \n")?; + } + + if self.f16 { + f.write_str("#include \n")?; + } + + if self.wmma_activated { + f.write_str("using namespace nvcuda;\n")?; + } + + if self.f16 { + f.write_str( + " +extern \"C\" struct __half4 { + __half x; + __half y; + __half z; + __half w; +}; +", + )?; + } f.write_fmt(format_args!( " @@ -87,6 +113,7 @@ extern \"C\" struct bool4 {{ bool w; }}; + extern \"C\" __global__ void kernel( ", ))?; diff --git a/crates/cubecl-linalg/src/matmul/cmma/launch.rs b/crates/cubecl-linalg/src/matmul/cmma/launch.rs index 412b1fd77..d2a19b6a3 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/launch.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/launch.rs @@ -80,7 +80,7 @@ pub fn matmul_cmma( ); cmma_kernel::launch::( - client, + &client, cube_count, cube_dim, TensorArg::vectorized(lhs_vectorization, &lhs.handle, &lhs.strides, &lhs.shape), diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs index 7d2366e59..4f8161cca 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs @@ -84,7 +84,7 @@ pub fn compute_loop_k_test(device: &R::Device) { }; compute_loop_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), @@ -159,7 +159,7 @@ pub fn compute_loop_warp_test(device: &R::Device) { }; compute_loop_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), @@ -263,7 +263,7 @@ pub fn cmma_compute_loop_two_warps_same_tile_row_test(device: &R::De }; compute_loop_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs index cd73fb54c..ab1542b5c 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs @@ -93,7 +93,7 @@ pub fn load_shared_memory_lhs_unit_test(device: &R::Device) { }; load_lhs_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized( @@ -149,7 +149,7 @@ pub fn load_shared_memory_rhs_unit_test(device: &R::Device) { }; load_rhs_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized( @@ -205,7 +205,7 @@ pub fn load_shared_memory_lhs_warp_test(device: &R::Device) { }; load_lhs_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized( @@ -266,7 +266,7 @@ pub fn load_shared_memory_lhs_vertical_out_of_bound_warp_test(device }; load_lhs_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized( @@ -325,7 +325,7 @@ pub fn load_shared_memory_lhs_horizontal_out_of_bound_warp_test(devi }; load_lhs_test::launch::( - R::client(device), 
+ &R::client(device), cube_count, cube_dim, TensorArg::vectorized( @@ -384,7 +384,7 @@ pub fn load_shared_memory_lhs_whole_out_of_bound_warp_test(device: & }; load_lhs_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized( @@ -442,7 +442,7 @@ pub fn load_shared_memory_rhs_warp_test(device: &R::Device) { }; load_rhs_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized( @@ -503,7 +503,7 @@ pub fn load_shared_memory_lhs_second_warp_test(device: &R::Device) { }; load_lhs_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized( @@ -563,7 +563,7 @@ pub fn load_shared_memory_rhs_second_warp_test(device: &R::Device) { }; load_rhs_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized( @@ -626,7 +626,7 @@ pub fn load_shared_memory_lhs_third_warp_test(device: &R::Device) { }; load_lhs_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized( @@ -689,7 +689,7 @@ pub fn load_shared_memory_rhs_third_warp_test(device: &R::Device) { }; load_rhs_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized( @@ -749,7 +749,7 @@ pub fn load_shared_memory_lhs_k_offset_test(device: &R::Device) { }; load_lhs_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized( @@ -809,7 +809,7 @@ pub fn load_shared_memory_rhs_k_offset_test(device: &R::Device) { }; load_rhs_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized( diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs index 61aeca00c..e656849a3 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs @@ -60,7 +60,7 @@ pub fn cmma_write_output_unit_test(device: &R::Device) { }; write_output_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), @@ -125,7 +125,7 @@ pub fn cmma_write_output_warp_test(device: &R::Device) { }; write_output_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), @@ -200,7 +200,7 @@ pub fn cmma_write_output_warp_horizontal_out_of_bounds_test(device: }; write_output_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), @@ -270,7 +270,7 @@ pub fn cmma_write_output_warp_vertical_out_of_bounds_test(device: &R }; write_output_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), @@ -340,7 +340,7 @@ pub fn cmma_write_output_warp_whole_out_of_bounds_test(device: &R::D }; write_output_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), @@ -406,7 +406,7 @@ pub fn cmma_write_output_second_warp_test(device: &R::Device) { }; write_output_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), @@ -521,7 +521,7 @@ pub fn cmma_write_output_third_fourth_warps_test(device: &R::Device) }; 
write_output_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs index 42d91bac9..a67351483 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs @@ -51,7 +51,7 @@ pub fn tile_outer_product_vectorized_unit_test_2(device: &R::Device) let config = make_tiling2d_config(SOME_DIM, SOME_DIM, SOME_DIM); tile_outer_product_test::launch::( - client.clone(), + &client, cube_count, cube_dim, ArrayArg::new(®ister_m, 4), @@ -125,7 +125,7 @@ pub fn tile_outer_product_vectorized_unit_test(device: &R::Device) { let config = make_tiling2d_config(SOME_DIM, SOME_DIM, SOME_DIM); tile_outer_product_test::launch::( - client.clone(), + &client, cube_count, cube_dim, ArrayArg::new(®ister_m, 4), @@ -152,7 +152,7 @@ pub fn compute_loop_unit_test(device: &R::Device) { let config = make_tiling2d_config(SOME_DIM, SOME_DIM, SOME_DIM); compute_loop_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape), @@ -183,7 +183,7 @@ pub fn compute_loop_unit_offset_test(device: &R::Device) { let config = make_tiling2d_config(4, 8, 4); compute_loop_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape), diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs index 567512426..02bc37589 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs @@ -222,7 +222,7 @@ pub fn load_lhs_transposed_unit_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); load_tensor_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), @@ -254,7 +254,7 @@ pub fn load_lhs_transposed_out_of_bounds_cube_test(device: &R::Devic let config = make_tiling2d_config(5, 1, 1); load_tensor_multiple_tiles_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized( @@ -288,7 +288,7 @@ pub fn load_lhs_transposed_cube_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 8); load_tensor_multiple_tiles_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), @@ -318,7 +318,7 @@ pub fn load_lhs_transposed_offset_cube_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 16); load_tensor_multiple_tiles_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), @@ -348,7 +348,7 @@ pub fn load_rhs_plain_unit_test(device: &R::Device) { let config = make_tiling2d_config(8, 16, 16); load_tensor_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), @@ -379,7 +379,7 @@ pub fn load_rhs_plain_cube_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 8); load_tensor_multiple_tiles_test::launch::( - R::client(device), + 
&R::client(device), cube_count, cube_dim, TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), @@ -409,7 +409,7 @@ pub fn load_rhs_plain_cube_offset_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); load_tensor_multiple_tiles_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), @@ -439,7 +439,7 @@ pub fn load_lhs_plain_unit_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); load_tensor_permuted_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), @@ -471,7 +471,7 @@ pub fn load_lhs_plain_out_of_bounds_unit_test(device: &R::Device) { let config = make_tiling2d_config(m, k, 8); load_tensor_permuted_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), @@ -502,7 +502,7 @@ pub fn load_rhs_transposed_unit_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); load_tensor_permuted_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), @@ -534,7 +534,7 @@ pub fn load_rhs_transposed_out_of_bounds_unit_test(device: &R::Devic let config = make_tiling2d_config(8, k, n); load_tensor_permuted_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs index a4fecfc72..fd61da094 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs @@ -66,7 +66,7 @@ pub fn write_to_output_over_height_unit_test(device: &R::Device) { let config = make_tiling2d_config(6, 8, 8); write_to_output_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape), @@ -92,7 +92,7 @@ pub fn write_to_output_over_width_unit_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 4); write_to_output_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape), @@ -118,7 +118,7 @@ pub fn write_to_output_vectorized_less_than_tile_unit_test(device: & let config = make_tiling2d_config(8, 8, 8); write_to_output_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(vectorization as u8, &out.handle, &out.strides, &out.shape), @@ -146,7 +146,7 @@ pub fn write_to_output_scalar_unit_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 8); write_to_output_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(vectorization as u8, &out.handle, &out.strides, &out.shape), @@ -174,7 +174,7 @@ pub fn write_to_output_scalar_out_of_bounds_cube_test(device: &R::De let config = make_tiling2d_config(5, 8, 1); write_results_to_output_out_of_bounds_test::launch::( - R::client(device), + &R::client(device), cube_count, cube_dim, TensorArg::vectorized(vectorization, &out.handle, &out.strides, &out.shape), diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs 
b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs index 3b448fef6..9ac0a4be4 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs @@ -70,7 +70,7 @@ pub fn matmul_tiling_2d( let cube_config = CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed); tiling2d_cube_kernel::launch::( - client, + &client, cube_count, cube_dim, TensorArg::vectorized(lhs_vectorization, &lhs.handle, &lhs.strides, &lhs.shape), diff --git a/crates/cubecl-linalg/src/tensor/base.rs b/crates/cubecl-linalg/src/tensor/base.rs index 2a2367b87..43f96781c 100644 --- a/crates/cubecl-linalg/src/tensor/base.rs +++ b/crates/cubecl-linalg/src/tensor/base.rs @@ -1,5 +1,8 @@ +use cubecl_core::calculate_cube_count_elemwise; use cubecl_core::prelude::*; +use cubecl_core::tensor_vectorization_factor; use cubecl_core::Runtime; +use cubecl_core::SUBCUBE_DIM_APPROX; use cubecl_runtime::server::Handle; use std::marker::PhantomData; @@ -56,17 +59,19 @@ where R: Runtime, E: CubePrimitive, { + /// Create a new tensor. + pub fn new(shape: Vec, strides: Vec, handle: Handle) -> Self { + Self { + shape, + strides, + handle, + elem: PhantomData, + } + } + /// Create a new tensor with a contiguous memory layout. pub fn new_contiguous(shape: Vec, handle: Handle) -> Self { - let d = shape.len(); - let mut strides = Vec::with_capacity(d); - - let mut current = 1; - shape.iter().enumerate().rev().for_each(|(_, val)| { - strides.push(current); - current *= val; - }); - strides.reverse(); + let strides = Self::contiguous_strides(&shape); Self { handle, @@ -93,4 +98,58 @@ where pub(crate) fn rank(&self) -> usize { self.shape.len() } + + fn contiguous_strides(shape: &[usize]) -> Vec { + let mut strides = Vec::with_capacity(shape.len()); + + let mut current = 1; + shape.iter().enumerate().rev().for_each(|(_, val)| { + strides.push(current); + current *= val; + }); + strides.reverse(); + strides + } +} +impl TensorHandle +where + R: Runtime, + E: Numeric, +{ + pub fn zeros(client: ComputeClient, shape: Vec) -> Self { + let num_elements: usize = shape.iter().product(); + let size = E::as_elem().size(); + + let handle = client.empty(size * num_elements); + let strides = Self::contiguous_strides(&shape); + + let vectorization_factor = + tensor_vectorization_factor(&[4, 2], &shape, &strides, shape.len() - 1); + + let cube_count = calculate_cube_count_elemwise::( + num_elements / vectorization_factor as usize, + SUBCUBE_DIM_APPROX, + ); + + init::zeros_array::launch::( + &client, + cube_count, + CubeDim::default(), + ArrayArg::new(&handle, num_elements), + ); + + Self::new(shape, strides, handle) + } +} + +pub(crate) mod init { + use cubecl::prelude::*; + use cubecl_core as cubecl; + + #[cube(launch)] + pub fn zeros_array(output: &mut Array) { + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = C::from_int(0); + } + } } diff --git a/crates/cubecl-macros/src/codegen_function/launch.rs b/crates/cubecl-macros/src/codegen_function/launch.rs index 958eca6a7..fbb4a283f 100644 --- a/crates/cubecl-macros/src/codegen_function/launch.rs +++ b/crates/cubecl-macros/src/codegen_function/launch.rs @@ -472,7 +472,7 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream { #[allow(clippy::too_many_arguments)] #[doc = #doc] pub fn launch #generics ( - client: ComputeClient, + client: &ComputeClient, cube_count: CubeCount, cube_dim: CubeDim, #inputs diff --git a/crates/cubecl-runtime/src/lib.rs b/crates/cubecl-runtime/src/lib.rs index 255328b1e..48689d0fc 100644 --- 
a/crates/cubecl-runtime/src/lib.rs +++ b/crates/cubecl-runtime/src/lib.rs @@ -27,3 +27,4 @@ pub mod storage; mod compute; pub use compute::*; +pub use cubecl_common::benchmark; diff --git a/crates/cubecl/Cargo.toml b/crates/cubecl/Cargo.toml index e4488a4ac..a75980cf0 100644 --- a/crates/cubecl/Cargo.toml +++ b/crates/cubecl/Cargo.toml @@ -19,7 +19,7 @@ version.workspace = true rust-version = "1.75" [features] -default = ["cubecl-wgpu?/default", "cubecl-cuda?/default"] +default = ["cubecl-wgpu?/default", "cubecl-cuda?/default", "linalg"] std = ["cubecl-wgpu?/std", "cubecl-cuda?/std"] template = ["cubecl-core/template"] linalg = ["dep:cubecl-linalg"] @@ -33,3 +33,7 @@ cubecl-core = { path = "../cubecl-core", version = "0.1.0", default-features = f cubecl-wgpu = { path = "../cubecl-wgpu", version = "0.1.0", default-features = false, optional = true } cubecl-cuda = { path = "../cubecl-cuda", version = "0.1.0", default-features = false, optional = true } cubecl-linalg = { path = "../cubecl-linalg", version = "0.1.0", default-features = false, optional = true } + +[[bench]] +name = "matmul" +harness = false diff --git a/crates/cubecl/benches/matmul.rs b/crates/cubecl/benches/matmul.rs new file mode 100644 index 000000000..b31a8886d --- /dev/null +++ b/crates/cubecl/benches/matmul.rs @@ -0,0 +1,87 @@ +use cubecl::prelude::*; +use std::marker::PhantomData; + +use cubecl::benchmark::Benchmark; +use cubecl::client::SyncType; +use cubecl::frontend::Float; +use cubecl_linalg::matmul::cmma::matmul_cmma; +use cubecl_linalg::matmul::tiling2d::matmul_tiling_2d; +use cubecl_linalg::tensor::TensorHandle; + +impl Benchmark for Tiling2dBench { + type Args = (TensorHandle, TensorHandle, TensorHandle); + + fn prepare(&self) -> Self::Args { + let (b, m, k, n) = (self.b, self.m, self.k, self.n); + let client = R::client(&self.device); + let lhs = TensorHandle::zeros(client.clone(), vec![b, m, k]); + let rhs = TensorHandle::zeros(client.clone(), vec![b, k, n]); + let out = TensorHandle::zeros(client.clone(), vec![b, m, n]); + + (lhs, rhs, out) + } + + fn execute(&self, (lhs, rhs, out): Self::Args) { + match self.kind { + MatmulKind::Tiling2d => { + matmul_tiling_2d(lhs, rhs, out, Default::default(), &self.device); + } + MatmulKind::Cmma => { + matmul_cmma(lhs, rhs, out, &self.device); + } + } + } + + fn name(&self) -> String { + let elem = E::as_elem(); + format!("tiling2d-{}-{:?}-{:?}", R::name(), elem, self.kind) + } + + fn sync(&self) { + let client = R::client(&self.device); + client.sync(SyncType::Wait); + } +} + +#[allow(dead_code)] +struct Tiling2dBench { + b: usize, + m: usize, + k: usize, + n: usize, + kind: MatmulKind, + device: R::Device, + _e: PhantomData, +} + +#[allow(dead_code)] +#[derive(Debug)] +enum MatmulKind { + Tiling2d, + Cmma, +} + +#[allow(dead_code)] +fn run(device: R::Device, kind: MatmulKind) { + let bench = Tiling2dBench:: { + b: 32, + m: 1024, + k: 1024, + n: 1024, + device, + kind, + _e: PhantomData, + }; + println!("{}", bench.name()); + println!("{}", bench.run()); +} + +fn main() { + #[cfg(feature = "wgpu")] + run::(Default::default(), MatmulKind::Tiling2d); + #[cfg(feature = "cuda")] + run::(Default::default(), MatmulKind::Tiling2d); + + #[cfg(feature = "cuda")] + run::(Default::default(), MatmulKind::Cmma); +} diff --git a/examples/gelu/Cargo.toml b/examples/gelu/Cargo.toml index 0acaf9c6d..8875481c7 100644 --- a/examples/gelu/Cargo.toml +++ b/examples/gelu/Cargo.toml @@ -13,3 +13,4 @@ cuda = ["cubecl/cuda"] [dependencies] cubecl = { path = "../../crates/cubecl", version = 
"0.1.0" } +half = { workspace = true } diff --git a/examples/gelu/src/lib.rs b/examples/gelu/src/lib.rs index 6917c8904..3d4e0c552 100644 --- a/examples/gelu/src/lib.rs +++ b/examples/gelu/src/lib.rs @@ -15,15 +15,13 @@ fn gelu_scalar(x: F) -> F { pub fn launch(device: &R::Device) { let client = R::client(device); let input = &[-1., 0., 1., 5.]; - let output_handle = client.empty(input.len() * core::mem::size_of::()); - let input_handle = client.create(f32::as_bytes(input)); gelu_array::launch::( - client.clone(), + &client, CubeCount::Static(1, 1, 1), CubeDim::new(input.len() as u32, 1, 1), - ArrayArg::new(&input_handle, input.len()), + ArrayArg::new(&client.create(f32::as_bytes(input)), input.len()), ArrayArg::new(&output_handle, input.len()), );