From 30d090c8b725e83a829a6459d9c09578d9a4ba68 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Thu, 8 Aug 2024 14:52:02 -0400 Subject: [PATCH] Unchecked Execution Mode (#51) --- README.md | 24 +- crates/cubecl-core/src/codegen/compiler.rs | 3 +- crates/cubecl-core/src/compute/kernel.rs | 19 +- crates/cubecl-core/src/compute/launcher.rs | 61 ++- .../cubecl-core/src/frontend/element/array.rs | 31 +- .../src/frontend/element/tensor.rs | 56 ++- crates/cubecl-core/src/id.rs | 8 + crates/cubecl-core/src/ir/processing.rs | 12 +- crates/cubecl-core/src/runtime.rs | 1 + .../cubecl-core/src/runtime_tests/assign.rs | 2 +- crates/cubecl-core/src/runtime_tests/cmma.rs | 18 +- .../cubecl-core/src/runtime_tests/launch.rs | 4 +- .../cubecl-core/src/runtime_tests/sequence.rs | 4 +- crates/cubecl-core/src/runtime_tests/slice.rs | 48 +- .../cubecl-core/src/runtime_tests/subcube.rs | 12 +- .../cubecl-core/src/runtime_tests/topology.rs | 16 +- crates/cubecl-cuda/src/compiler/base.rs | 202 ++++++-- crates/cubecl-cuda/src/compiler/binary.rs | 94 +++- crates/cubecl-cuda/src/compiler/element.rs | 9 + crates/cubecl-cuda/src/compute/server.rs | 14 +- crates/cubecl-linalg/src/matmul/cmma/base.rs | 2 +- .../cubecl-linalg/src/matmul/cmma/launch.rs | 20 +- .../src/matmul/tests/cmma/compute_loop.rs | 80 +-- .../matmul/tests/cmma/load_shared_memory.rs | 472 +++++++++--------- .../src/matmul/tests/cmma/write_output.rs | 156 +++--- .../src/matmul/tests/tiling2d/compute_loop.rs | 100 ++-- .../tests/tiling2d/load_shared_memory.rs | 282 ++++++----- .../src/matmul/tests/tiling2d/write_output.rs | 94 ++-- .../cubecl-linalg/src/matmul/tiling2d/base.rs | 2 +- .../src/matmul/tiling2d/launch.rs | 20 +- crates/cubecl-linalg/src/tensor/base.rs | 25 +- crates/cubecl-linalg/src/tensor/contiguous.rs | 14 +- .../src/codegen_function/launch.rs | 35 +- crates/cubecl-macros/src/lib.rs | 33 +- .../src/{compute.rs => base.rs} | 10 + crates/cubecl-runtime/src/channel/base.rs | 8 +- crates/cubecl-runtime/src/channel/cell.rs | 6 +- crates/cubecl-runtime/src/channel/mpsc.rs | 14 +- crates/cubecl-runtime/src/channel/mutex.rs | 6 +- crates/cubecl-runtime/src/client.rs | 21 +- crates/cubecl-runtime/src/lib.rs | 4 +- crates/cubecl-runtime/src/server.rs | 8 +- crates/cubecl-runtime/tests/dummy/server.rs | 4 +- crates/cubecl-wgpu/Cargo.toml | 1 - .../cubecl-wgpu/src/compiler/wgsl/compiler.rs | 3 +- crates/cubecl-wgpu/src/compute/server.rs | 38 +- crates/cubecl/benches/unary.rs | 6 +- crates/cubecl/src/runtime.rs | 0 examples/gelu/src/lib.rs | 18 +- 49 files changed, 1271 insertions(+), 849 deletions(-) rename crates/cubecl-runtime/src/{compute.rs => base.rs} (92%) delete mode 100644 crates/cubecl/src/runtime.rs diff --git a/README.md b/README.md index dcab44293..3ea54fbce 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Simply annotate functions with the `cube` attribute to indicate that they should ```rust use cubecl::prelude::*; -#[cube(launch)] +#[cube(launch_unchecked)] fn gelu_array(input: &Array, output: &mut Array) { if ABSOLUTE_POS < input.len() { output[ABSOLUTE_POS] = gelu_scalar::(input[ABSOLUTE_POS]); @@ -38,24 +38,26 @@ fn gelu_array(input: &Array, output: &mut Array) { fn gelu_scalar(x: F) -> F { x * (F::erf(x / F::sqrt(2.0.into())) + 1.0) / 2.0 } - ``` -You can then launch the kernel using the autogenerated `gelu_array::launch` function. +You can then launch the kernel using the autogenerated `gelu_array::launch_unchecked` function. 
```rust fn launch(device: &R::Device) { let client = R::client(device); let input = &[-1., 0., 1., 5.]; let output_handle = client.empty(input.len() * core::mem::size_of::()); - - gelu_array::launch::( - client.clone(), - CubeCount::Static(1, 1, 1), - CubeDim::new(input.len() as u32, 1, 1), - ArrayArg::new(&client.create(f32::as_bytes(input)), input.len()), - ArrayArg::new(&output_handle, input.len()), - ); + let input_handle = client.create(f32::as_bytes(input)); + + unsafe { + gelu_array::launch_unchecked::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(input.len() as u32, 1, 1), + ArrayArg::from_raw_parts(&input_handle, input.len(), 1), + ArrayArg::from_raw_parts(&output_handle, input.len(), 1), + ) + }; let bytes = client.read(output_handle.binding()); let output = f32::from_bytes(&bytes); diff --git a/crates/cubecl-core/src/codegen/compiler.rs b/crates/cubecl-core/src/codegen/compiler.rs index 4f1e45105..2370dd79c 100644 --- a/crates/cubecl-core/src/codegen/compiler.rs +++ b/crates/cubecl-core/src/codegen/compiler.rs @@ -1,4 +1,5 @@ use crate::ir::{Elem, KernelDefinition}; +use cubecl_runtime::ExecutionMode; use std::fmt::Display; /// Trait for compiled code representation @@ -13,7 +14,7 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug { type Representation: CompilerRepresentation; /// Compiles the [kernel definition](KernelDefinition) into the compiler's representation. - fn compile(kernel: KernelDefinition) -> Self::Representation; + fn compile(kernel: KernelDefinition, mode: ExecutionMode) -> Self::Representation; /// The size of the given element in bytes. fn elem_size(elem: Elem) -> usize; /// The maximal size of a shared memory diff --git a/crates/cubecl-core/src/compute/kernel.rs b/crates/cubecl-core/src/compute/kernel.rs index 2658d25d9..3e3175631 100644 --- a/crates/cubecl-core/src/compute/kernel.rs +++ b/crates/cubecl-core/src/compute/kernel.rs @@ -5,7 +5,10 @@ use std::{ use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel, KernelId}; use alloc::sync::Arc; -use cubecl_runtime::server::{Binding, ComputeServer}; +use cubecl_runtime::{ + server::{Binding, ComputeServer}, + ExecutionMode, +}; /// A kernel, compiled in the target language pub struct CompiledKernel { @@ -157,7 +160,7 @@ pub trait CubeTask: Send + Sync { /// Identifier for the kernel, used for caching kernel compilation. fn id(&self) -> KernelId; /// Compile the kernel into source - fn compile(&self) -> CompiledKernel; + fn compile(&self, mode: ExecutionMode) -> CompiledKernel; } /// Wraps a [kernel](Kernel) to create a [cube task](CubeTask). 
@@ -168,10 +171,10 @@ pub struct KernelTask { } impl CubeTask for KernelTask { - fn compile(&self) -> CompiledKernel { + fn compile(&self, mode: ExecutionMode) -> CompiledKernel { let gpu_ir = self.kernel_definition.define(); let cube_dim = gpu_ir.cube_dim; - let lower_level_ir = C::compile(gpu_ir); + let lower_level_ir = C::compile(gpu_ir, mode); let shared_mem_bytes = lower_level_ir.shared_memory_size(); let source = lower_level_ir.to_string(); @@ -190,8 +193,8 @@ impl CubeTask for KernelTask { } impl CubeTask for Arc { - fn compile(&self) -> CompiledKernel { - self.as_ref().compile() + fn compile(&self, mode: ExecutionMode) -> CompiledKernel { + self.as_ref().compile(mode) } fn id(&self) -> KernelId { @@ -200,8 +203,8 @@ impl CubeTask for Arc { } impl CubeTask for Box { - fn compile(&self) -> CompiledKernel { - self.as_ref().compile() + fn compile(&self, mode: ExecutionMode) -> CompiledKernel { + self.as_ref().compile(mode) } fn id(&self) -> KernelId { diff --git a/crates/cubecl-core/src/compute/launcher.rs b/crates/cubecl-core/src/compute/launcher.rs index 54e733b64..3038a007a 100644 --- a/crates/cubecl-core/src/compute/launcher.rs +++ b/crates/cubecl-core/src/compute/launcher.rs @@ -89,6 +89,24 @@ impl KernelLauncher { client.execute(kernel, cube_count, bindings); } + /// Launch the kernel without check bounds. + /// + /// # Safety + /// + /// Out-of-bounds reads and writes can happen. + pub unsafe fn launch_unchecked( + self, + cube_count: CubeCount, + kernel: K, + client: &ComputeClient, + ) { + let bindings = self.into_bindings(client); + + let kernel = Box::new(KernelTask::::new(kernel)); + + client.execute_unchecked(kernel, cube_count, bindings); + } + /// We need to create the bindings in the same order they are defined in the compilation step. 
/// /// The function [crate::KernelIntegrator::integrate] stars by registering the input tensors followed @@ -174,21 +192,16 @@ impl TensorState { bindings.push(tensor.handle.clone().binding()); - let old_rank = if metadata.is_empty() { + if metadata.is_empty() { let rank = tensor.strides.len() as u32; metadata.push(rank); - None } else if tensor.strides.len() > metadata[0] as usize { - let old_rank = metadata[0]; let rank = tensor.strides.len() as u32; - Self::adjust_rank(metadata, bindings.len(), rank); - Some(old_rank) - } else { - None - }; + Self::adjust_rank(metadata, bindings.len() - 1, rank); + } - Self::register_strides(tensor.strides, tensor.shape, old_rank, metadata); - Self::register_shape(tensor.shape, old_rank, metadata); + Self::register_strides(tensor.strides, tensor.shape, None, metadata); + Self::register_shape(tensor.shape, None, metadata); if R::require_array_lengths() { let len = calculate_num_elems_dyn_rank(tensor.shape); @@ -200,6 +213,7 @@ impl TensorState { let old_rank = metadata[0] as usize; let rank_diff = rank as usize - old_rank; let mut updated_metadata = Vec::with_capacity(2 * rank_diff * num_registered); + updated_metadata.push(rank); for pos in 0..num_registered { let stride_index = (pos * old_rank * 2) + 1; @@ -228,19 +242,14 @@ impl TensorState { ) { let old_rank = if let Some(old_rank) = old_rank { let rank = output[0]; - let rank_diff = old_rank - rank; - let padded_strides = if rank_diff > 0 { - shape - .iter() - .take(old_rank as usize) - .map(|a| a.to_u32().unwrap()) - .sum::() - } else { - 0 - }; + let rank_diff = i32::abs(old_rank as i32 - rank as i32) as usize; + + if rank_diff > 0 { + let padded_strides = shape.iter().map(|a| a.to_u32().unwrap()).sum::(); - for _ in 0..rank_diff { - output.push(padded_strides.to_u32().unwrap()); + for _ in 0..rank_diff { + output.push(padded_strides); + } } old_rank as usize @@ -256,10 +265,12 @@ impl TensorState { fn register_shape(shape: &[T], old_rank: Option, output: &mut Vec) { let old_rank = if let Some(old_rank) = old_rank { let rank = output[0]; - let rank_diff = rank - old_rank; + let rank_diff = i32::abs(old_rank as i32 - rank as i32) as usize; - for _ in 0..rank_diff { - output.push(1); + if rank_diff > 0 { + for _ in 0..rank_diff { + output.push(1); + } } old_rank as usize diff --git a/crates/cubecl-core/src/frontend/element/array.rs b/crates/cubecl-core/src/frontend/element/array.rs index b5e1a36fb..d3cad4bde 100644 --- a/crates/cubecl-core/src/frontend/element/array.rs +++ b/crates/cubecl-core/src/frontend/element/array.rs @@ -144,7 +144,7 @@ impl LaunchArgExpand for Array { /// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle). pub struct ArrayHandleRef<'a, R: Runtime> { pub handle: &'a cubecl_runtime::server::Handle, - pub length: [usize; 1], + pub(crate) length: [usize; 1], } pub enum ArrayArg<'a, R: Runtime> { @@ -205,35 +205,38 @@ impl<'a, R: Runtime> ArgSettings for ArrayArg<'a, R> { impl<'a, R: Runtime> ArrayArg<'a, R> { /// Create a new array argument. /// - /// Equivalent to using the [vectorized constructor](Self::vectorized) with a vectorization - /// factor of 1. - pub fn new(handle: &'a cubecl_runtime::server::Handle, length: usize) -> Self { - ArrayArg::Handle { - handle: ArrayHandleRef::new(handle, length), - vectorization_factor: 1, - } - } - /// Create a new array argument specified with its vectorization factor. 
- pub fn vectorized( - vectorization_factor: u8, + /// # Safety + /// + /// Specifying the wrong lenght may lead to out-of-bounds reads and writes. + pub unsafe fn from_raw_parts( handle: &'a cubecl_runtime::server::Handle, length: usize, + vectorization_factor: u8, ) -> Self { ArrayArg::Handle { - handle: ArrayHandleRef::new(handle, length), + handle: ArrayHandleRef::from_raw_parts(handle, length), vectorization_factor, } } } impl<'a, R: Runtime> ArrayHandleRef<'a, R> { - pub fn new(handle: &'a cubecl_runtime::server::Handle, length: usize) -> Self { + /// Create a new array handle reference. + /// + /// # Safety + /// + /// Specifying the wrong lenght may lead to out-of-bounds reads and writes. + pub unsafe fn from_raw_parts( + handle: &'a cubecl_runtime::server::Handle, + length: usize, + ) -> Self { Self { handle, length: [length], } } + /// Return the handle as a tensor instead of an array. pub fn as_tensor(&self) -> TensorHandleRef<'_, R> { let shape = &self.length; diff --git a/crates/cubecl-core/src/frontend/element/tensor.rs b/crates/cubecl-core/src/frontend/element/tensor.rs index cb6e0dfeb..9ffce8e6d 100644 --- a/crates/cubecl-core/src/frontend/element/tensor.rs +++ b/crates/cubecl-core/src/frontend/element/tensor.rs @@ -52,13 +52,36 @@ impl LaunchArg for Tensor { /// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle), /// the strides and the shape. -#[derive(new)] pub struct TensorHandleRef<'a, R: Runtime> { pub handle: &'a cubecl_runtime::server::Handle, pub strides: &'a [usize], pub shape: &'a [usize], } +impl<'a, R: Runtime> TensorHandleRef<'a, R> { + /// Convert the handle into a [tensor argument](TensorArg). + pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> { + unsafe { TensorArg::from_raw_parts(self.handle, self.strides, self.shape, vectorisation) } + } + /// Create a handle from raw parts. + /// + /// # Safety + /// + /// If you provide wrong strides or shapes, it might create undefined behavior caused by + /// out-of-bounds reads and writes. + pub unsafe fn from_raw_parts( + handle: &'a cubecl_runtime::server::Handle, + strides: &'a [usize], + shape: &'a [usize], + ) -> Self { + Self { + handle, + strides, + shape, + } + } +} + /// Argument to be used for [tensors](Tensor) passed as arguments to kernels. pub enum TensorArg<'a, R: Runtime> { /// The tensor is passed with a tensor handle. @@ -76,32 +99,27 @@ pub enum TensorArg<'a, R: Runtime> { } impl<'a, R: Runtime> TensorArg<'a, R> { - /// Create a new tensor argument. + /// Create a new tensor argument specified with its vectorization factor. + /// + /// # Safety /// - /// Equivalent to using the [vectorized constructor](Self::vectorized) with a vectorization - /// factor of 1. - pub fn new( + /// If you provide wrong strides or shapes, it might create undefined behavior caused by + /// out-of-bound reads and writes. + pub unsafe fn from_raw_parts( handle: &'a cubecl_runtime::server::Handle, strides: &'a [usize], shape: &'a [usize], - ) -> Self { - Self::Handle { - handle: TensorHandleRef::new(handle, strides, shape), - vectorization_factor: 1, - } - } - /// Create a new tensor argument specified with its vectorization factor. 
- pub fn vectorized( factor: u8, - handle: &'a cubecl_runtime::server::Handle, - strides: &'a [usize], - shape: &'a [usize], ) -> Self { - Self::Handle { - handle: TensorHandleRef::new(handle, strides, shape), - vectorization_factor: factor, + unsafe { + Self::Handle { + handle: TensorHandleRef::from_raw_parts(handle, strides, shape), + vectorization_factor: factor, + } } } + + /// Create an alias argument. pub fn alias(position: usize) -> Self { Self::Alias { input_pos: position, diff --git a/crates/cubecl-core/src/id.rs b/crates/cubecl-core/src/id.rs index bfe9132f8..b8607dcbe 100644 --- a/crates/cubecl-core/src/id.rs +++ b/crates/cubecl-core/src/id.rs @@ -1,3 +1,4 @@ +use cubecl_runtime::ExecutionMode; use std::any::{Any, TypeId}; use std::fmt::Display; use std::hash::{DefaultHasher, Hash, Hasher}; @@ -8,6 +9,7 @@ use std::sync::Arc; pub struct KernelId { type_id: core::any::TypeId, info: Option, + mode: Option, } impl Display for KernelId { @@ -25,6 +27,7 @@ impl KernelId { Self { type_id: core::any::TypeId::of::(), info: None, + mode: None, } } @@ -39,6 +42,11 @@ impl KernelId { self.info = Some(Info::new(info)); self } + + /// Set the [execution mode](ExecutionMode). + pub fn mode(&mut self, mode: ExecutionMode) { + self.mode = Some(mode); + } } /// Extra information diff --git a/crates/cubecl-core/src/ir/processing.rs b/crates/cubecl-core/src/ir/processing.rs index 12351820b..6a4a8e165 100644 --- a/crates/cubecl-core/src/ir/processing.rs +++ b/crates/cubecl-core/src/ir/processing.rs @@ -124,15 +124,15 @@ impl ScopeProcessing { sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out); } - Operator::Index(op) => { - sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); - sanitize_constant_scalar_ref_elem(&mut op.rhs, Elem::UInt); - } Operator::Slice(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &op.out); sanitize_constant_scalar_ref_elem(&mut op.start, Elem::UInt); sanitize_constant_scalar_ref_elem(&mut op.end, Elem::UInt); } + Operator::Index(op) => { + sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); + sanitize_constant_scalar_ref_elem(&mut op.rhs, Elem::UInt); + } Operator::UncheckedIndex(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); sanitize_constant_scalar_ref_elem(&mut op.rhs, Elem::UInt); @@ -142,8 +142,8 @@ impl ScopeProcessing { sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out); } Operator::UncheckedIndexAssign(op) => { - sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); - sanitize_constant_scalar_ref_elem(&mut op.rhs, Elem::UInt); + sanitize_constant_scalar_ref_elem(&mut op.lhs, Elem::UInt); + sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out); } Operator::And(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs); diff --git a/crates/cubecl-core/src/runtime.rs b/crates/cubecl-core/src/runtime.rs index 21683dea9..b9a42cbf5 100644 --- a/crates/cubecl-core/src/runtime.rs +++ b/crates/cubecl-core/src/runtime.rs @@ -9,6 +9,7 @@ pub use cubecl_runtime::channel; pub use cubecl_runtime::client; pub use cubecl_runtime::server; pub use cubecl_runtime::tune; +pub use cubecl_runtime::ExecutionMode; /// Runtime for the CubeCL. 
pub trait Runtime: Send + Sync + 'static + core::fmt::Debug { diff --git a/crates/cubecl-core/src/runtime_tests/assign.rs b/crates/cubecl-core/src/runtime_tests/assign.rs index 08dfd77f2..f9c81aae5 100644 --- a/crates/cubecl-core/src/runtime_tests/assign.rs +++ b/crates/cubecl-core/src/runtime_tests/assign.rs @@ -19,7 +19,7 @@ pub fn test_kernel_assign_scalar(client: ComputeClient(client: ComputeClient) { let rhs = client.create(f16::as_bytes(&rhs)); let out = client.empty(core::mem::size_of::() * 256); - kernel_simple_1::launch::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(16, 16, 1), - ArrayArg::new(&lhs, 256), - ArrayArg::new(&rhs, 256), - ArrayArg::new(&out, 256), - ); + unsafe { + kernel_simple_1::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(16, 16, 1), + ArrayArg::from_raw_parts(&lhs, 256, 1), + ArrayArg::from_raw_parts(&rhs, 256, 1), + ArrayArg::from_raw_parts(&out, 256, 1), + ) + }; let actual = client.read(out.binding()); let actual = f32::from_bytes(&actual); diff --git a/crates/cubecl-core/src/runtime_tests/launch.rs b/crates/cubecl-core/src/runtime_tests/launch.rs index 62bc50e2f..38c7d204c 100644 --- a/crates/cubecl-core/src/runtime_tests/launch.rs +++ b/crates/cubecl-core/src/runtime_tests/launch.rs @@ -23,7 +23,7 @@ pub fn test_kernel_with_generics(client: ComputeClient(client: ComputeClient(client: ComputeClient(client: ComputeClient(client: ComputeClient()); - slice_select::launch::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(1, 1, 1), - ArrayArg::new(&input, 5), - ArrayArg::new(&output, 1), - ); + unsafe { + slice_select::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(1, 1, 1), + ArrayArg::from_raw_parts(&input, 5, 1), + ArrayArg::from_raw_parts(&output, 1, 1), + ) + }; let actual = client.read(output.binding()); let actual = f32::from_bytes(&actual); @@ -48,13 +50,15 @@ pub fn test_slice_len(client: ComputeClient) let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0])); let output = client.empty(core::mem::size_of::()); - slice_len::launch::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(1, 1, 1), - ArrayArg::new(&input, 5), - ArrayArg::new(&output, 1), - ); + unsafe { + slice_len::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(1, 1, 1), + ArrayArg::from_raw_parts(&input, 5, 1), + ArrayArg::from_raw_parts(&output, 1, 1), + ) + }; let actual = client.read(output.binding()); let actual = u32::from_bytes(&actual); @@ -66,13 +70,15 @@ pub fn test_slice_assign(client: ComputeClient( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(1, 1, 1), - ArrayArg::new(&input, 5), - ArrayArg::new(&output, 1), - ); + unsafe { + slice_assign::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(1, 1, 1), + ArrayArg::from_raw_parts(&input, 5, 1), + ArrayArg::from_raw_parts(&output, 1, 1), + ) + }; let actual = client.read(output.binding()); let actual = f32::from_bytes(&actual); diff --git a/crates/cubecl-core/src/runtime_tests/subcube.rs b/crates/cubecl-core/src/runtime_tests/subcube.rs index 7fc50b2d8..f9bbc0578 100644 --- a/crates/cubecl-core/src/runtime_tests/subcube.rs +++ b/crates/cubecl-core/src/runtime_tests/subcube.rs @@ -109,11 +109,13 @@ fn test_subcube_operation( let handle = client.create(f32::as_bytes(input)); let (shape, strides) = ([input.len()], [1]); - launch( - CubeCount::Static(1, 1, 1), - CubeDim::new(input.len() as u32, 1, 1), - TensorArg::new(&handle, &strides, &shape), - ); + unsafe { + launch( + CubeCount::Static(1, 1, 1), + 
CubeDim::new(input.len() as u32, 1, 1), + TensorArg::from_raw_parts(&handle, &strides, &shape, 1), + ); + } let actual = client.read(handle.binding()); let actual = f32::from_bytes(&actual); diff --git a/crates/cubecl-core/src/runtime_tests/topology.rs b/crates/cubecl-core/src/runtime_tests/topology.rs index 1b2df08e3..cc9d687e7 100644 --- a/crates/cubecl-core/src/runtime_tests/topology.rs +++ b/crates/cubecl-core/src/runtime_tests/topology.rs @@ -22,13 +22,15 @@ pub fn test_kernel_topology_absolute_pos(client: ComputeClient()); let handle2 = client.empty(length as usize * core::mem::size_of::()); - kernel_absolute_pos::launch::( - &client, - CubeCount::Static(cube_count.0, cube_count.1, cube_count.2), - CubeDim::new(cube_dim.0, cube_dim.1, cube_dim.2), - ArrayArg::new(&handle1, length as usize), - ArrayArg::new(&handle2, length as usize), - ); + unsafe { + kernel_absolute_pos::launch::( + &client, + CubeCount::Static(cube_count.0, cube_count.1, cube_count.2), + CubeDim::new(cube_dim.0, cube_dim.1, cube_dim.2), + ArrayArg::from_raw_parts(&handle1, length as usize, 1), + ArrayArg::from_raw_parts(&handle2, length as usize, 1), + ) + }; let actual = client.read(handle1.binding()); let actual = u32::from_bytes(&actual); diff --git a/crates/cubecl-cuda/src/compiler/base.rs b/crates/cubecl-cuda/src/compiler/base.rs index 5966a1ba7..eada5bddc 100644 --- a/crates/cubecl-cuda/src/compiler/base.rs +++ b/crates/cubecl-cuda/src/compiler/base.rs @@ -4,6 +4,7 @@ use cubecl_core::{ ir::{self as gpu, ConstantScalarValue}, Compiler, }; +use cubecl_runtime::ExecutionMode; use super::{Instruction, WarpInstruction}; @@ -25,13 +26,20 @@ pub struct CudaCompiler { num_inputs: usize, num_outputs: usize, items: HashSet, + strategy: ExecutionMode, } impl Compiler for CudaCompiler { type Representation = super::ComputeKernel; - fn compile(kernel: cubecl_core::ir::KernelDefinition) -> Self::Representation { - let compiler = Self::default(); + fn compile( + kernel: cubecl_core::ir::KernelDefinition, + strategy: ExecutionMode, + ) -> Self::Representation { + let compiler = Self { + strategy, + ..Self::default() + }; compiler.compile_shader(kernel) } @@ -93,9 +101,9 @@ impl CudaCompiler { } } - fn compile_scope(&mut self, value: &mut gpu::Scope) -> Vec { + fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec { let mut instructions = Vec::new(); - let processing = value.process(); + let processing = scope.process(); for var in processing.variables { if let gpu::Variable::Slice { .. 
} = var { @@ -109,7 +117,7 @@ impl CudaCompiler { processing .operations .into_iter() - .for_each(|op| self.compile_operation(&mut instructions, op, value)); + .for_each(|op| self.compile_operation(&mut instructions, op, scope)); instructions } @@ -121,7 +129,7 @@ impl CudaCompiler { scope: &mut gpu::Scope, ) { match operation { - gpu::Operation::Operator(op) => instructions.push(self.compile_instruction(op)), + gpu::Operation::Operator(op) => self.compile_instruction(op, instructions, scope), gpu::Operation::Procedure(proc) => self.compile_procedure(instructions, proc, scope), gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op)), gpu::Operation::Branch(val) => self.compile_branch(instructions, val), @@ -321,57 +329,127 @@ impl CudaCompiler { } } - fn compile_instruction(&mut self, value: gpu::Operator) -> Instruction { + fn compile_instruction( + &mut self, + value: gpu::Operator, + instructions: &mut Vec, + scope: &mut gpu::Scope, + ) { match value { - gpu::Operator::Add(op) => Instruction::Add(self.compile_binary(op)), - gpu::Operator::Mul(op) => Instruction::Mul(self.compile_binary(op)), - gpu::Operator::Div(op) => Instruction::Div(self.compile_binary(op)), - gpu::Operator::Sub(op) => Instruction::Sub(self.compile_binary(op)), - gpu::Operator::Assign(op) => Instruction::Assign(self.compile_unary(op)), - gpu::Operator::Slice(op) => Instruction::Slice { + gpu::Operator::Add(op) => instructions.push(Instruction::Add(self.compile_binary(op))), + gpu::Operator::Mul(op) => instructions.push(Instruction::Mul(self.compile_binary(op))), + gpu::Operator::Div(op) => instructions.push(Instruction::Div(self.compile_binary(op))), + gpu::Operator::Sub(op) => instructions.push(Instruction::Sub(self.compile_binary(op))), + gpu::Operator::Assign(op) => { + instructions.push(Instruction::Assign(self.compile_unary(op))) + } + gpu::Operator::Slice(op) => instructions.push(Instruction::Slice { input: self.compile_variable(op.input), start: self.compile_variable(op.start), end: self.compile_variable(op.end), out: self.compile_variable(op.out), - }, - gpu::Operator::Index(op) => Instruction::Index(self.compile_binary(op)), - gpu::Operator::UncheckedIndex(op) => Instruction::Index(self.compile_binary(op)), - gpu::Operator::IndexAssign(op) => Instruction::IndexAssign(self.compile_binary(op)), + }), + gpu::Operator::Index(op) => { + if let ExecutionMode::Checked = self.strategy { + if has_length(&op.lhs) { + self.compile_procedure( + instructions, + gpu::Procedure::CheckedIndex(gpu::CheckedIndex { + lhs: op.lhs, + rhs: op.rhs, + out: op.out, + }), + scope, + ); + + return; + } + }; + + instructions.push(Instruction::Index(self.compile_binary(op))); + } + gpu::Operator::UncheckedIndex(op) => { + instructions.push(Instruction::Index(self.compile_binary(op))) + } + gpu::Operator::IndexAssign(op) => { + if let ExecutionMode::Checked = self.strategy { + if has_length(&op.out) { + self.compile_procedure( + instructions, + gpu::Procedure::CheckedIndexAssign(gpu::CheckedIndexAssign { + lhs: op.lhs, + rhs: op.rhs, + out: op.out, + }), + scope, + ); + return; + } + }; + + instructions.push(Instruction::IndexAssign(self.compile_binary(op))); + } gpu::Operator::UncheckedIndexAssign(op) => { - Instruction::IndexAssign(self.compile_binary(op)) - } - gpu::Operator::Modulo(op) => Instruction::Modulo(self.compile_binary(op)), - gpu::Operator::Equal(op) => Instruction::Equal(self.compile_binary(op)), - gpu::Operator::Lower(op) => Instruction::Lower(self.compile_binary(op)), - gpu::Operator::Greater(op) => 
Instruction::Greater(self.compile_binary(op)), - gpu::Operator::LowerEqual(op) => Instruction::LowerEqual(self.compile_binary(op)), - gpu::Operator::GreaterEqual(op) => Instruction::GreaterEqual(self.compile_binary(op)), - gpu::Operator::Abs(op) => Instruction::Abs(self.compile_unary(op)), - gpu::Operator::Exp(op) => Instruction::Exp(self.compile_unary(op)), - gpu::Operator::Log(op) => Instruction::Log(self.compile_unary(op)), - gpu::Operator::Log1p(op) => Instruction::Log1p(self.compile_unary(op)), - gpu::Operator::Cos(op) => Instruction::Cos(self.compile_unary(op)), - gpu::Operator::Sin(op) => Instruction::Sin(self.compile_unary(op)), - gpu::Operator::Tanh(op) => Instruction::Tanh(self.compile_unary(op)), - gpu::Operator::Powf(op) => Instruction::Powf(self.compile_binary(op)), - gpu::Operator::Sqrt(op) => Instruction::Sqrt(self.compile_unary(op)), - gpu::Operator::Erf(op) => Instruction::Erf(self.compile_unary(op)), - gpu::Operator::And(op) => Instruction::And(self.compile_binary(op)), - gpu::Operator::Or(op) => Instruction::Or(self.compile_binary(op)), - gpu::Operator::Not(op) => Instruction::Not(self.compile_unary(op)), - gpu::Operator::Max(op) => Instruction::Max(self.compile_binary(op)), - gpu::Operator::Min(op) => Instruction::Min(self.compile_binary(op)), - gpu::Operator::NotEqual(op) => Instruction::NotEqual(self.compile_binary(op)), - gpu::Operator::BitwiseAnd(op) => Instruction::BitwiseAnd(self.compile_binary(op)), - gpu::Operator::BitwiseXor(op) => Instruction::BitwiseXor(self.compile_binary(op)), - gpu::Operator::ShiftLeft(op) => Instruction::ShiftLeft(self.compile_binary(op)), - gpu::Operator::ShiftRight(op) => Instruction::ShiftRight(self.compile_binary(op)), - gpu::Operator::Clamp(op) => Instruction::Clamp { + instructions.push(Instruction::IndexAssign(self.compile_binary(op))) + } + gpu::Operator::Modulo(op) => { + instructions.push(Instruction::Modulo(self.compile_binary(op))) + } + gpu::Operator::Equal(op) => { + instructions.push(Instruction::Equal(self.compile_binary(op))) + } + gpu::Operator::Lower(op) => { + instructions.push(Instruction::Lower(self.compile_binary(op))) + } + gpu::Operator::Greater(op) => { + instructions.push(Instruction::Greater(self.compile_binary(op))) + } + gpu::Operator::LowerEqual(op) => { + instructions.push(Instruction::LowerEqual(self.compile_binary(op))) + } + gpu::Operator::GreaterEqual(op) => { + instructions.push(Instruction::GreaterEqual(self.compile_binary(op))) + } + gpu::Operator::Abs(op) => instructions.push(Instruction::Abs(self.compile_unary(op))), + gpu::Operator::Exp(op) => instructions.push(Instruction::Exp(self.compile_unary(op))), + gpu::Operator::Log(op) => instructions.push(Instruction::Log(self.compile_unary(op))), + gpu::Operator::Log1p(op) => { + instructions.push(Instruction::Log1p(self.compile_unary(op))) + } + gpu::Operator::Cos(op) => instructions.push(Instruction::Cos(self.compile_unary(op))), + gpu::Operator::Sin(op) => instructions.push(Instruction::Sin(self.compile_unary(op))), + gpu::Operator::Tanh(op) => instructions.push(Instruction::Tanh(self.compile_unary(op))), + gpu::Operator::Powf(op) => { + instructions.push(Instruction::Powf(self.compile_binary(op))) + } + gpu::Operator::Sqrt(op) => instructions.push(Instruction::Sqrt(self.compile_unary(op))), + gpu::Operator::Erf(op) => instructions.push(Instruction::Erf(self.compile_unary(op))), + gpu::Operator::And(op) => instructions.push(Instruction::And(self.compile_binary(op))), + gpu::Operator::Or(op) => 
instructions.push(Instruction::Or(self.compile_binary(op))), + gpu::Operator::Not(op) => instructions.push(Instruction::Not(self.compile_unary(op))), + gpu::Operator::Max(op) => instructions.push(Instruction::Max(self.compile_binary(op))), + gpu::Operator::Min(op) => instructions.push(Instruction::Min(self.compile_binary(op))), + gpu::Operator::NotEqual(op) => { + instructions.push(Instruction::NotEqual(self.compile_binary(op))) + } + gpu::Operator::BitwiseAnd(op) => { + instructions.push(Instruction::BitwiseAnd(self.compile_binary(op))) + } + gpu::Operator::BitwiseXor(op) => { + instructions.push(Instruction::BitwiseXor(self.compile_binary(op))) + } + gpu::Operator::ShiftLeft(op) => { + instructions.push(Instruction::ShiftLeft(self.compile_binary(op))) + } + gpu::Operator::ShiftRight(op) => { + instructions.push(Instruction::ShiftRight(self.compile_binary(op))) + } + gpu::Operator::Clamp(op) => instructions.push(Instruction::Clamp { input: self.compile_variable(op.input), min_value: self.compile_variable(op.min_value), max_value: self.compile_variable(op.max_value), out: self.compile_variable(op.out), - }, + }), gpu::Operator::Recip(op) => { let elem = op.input.item().elem(); let lhs = match elem { @@ -380,22 +458,27 @@ impl CudaCompiler { gpu::Elem::UInt => ConstantScalarValue::UInt(1), gpu::Elem::Bool => ConstantScalarValue::Bool(true), }; - Instruction::Div(super::BinaryInstruction { + + instructions.push(Instruction::Div(super::BinaryInstruction { lhs: super::Variable::ConstantScalar(lhs, self.compile_elem(elem)), rhs: self.compile_variable(op.input), out: self.compile_variable(op.out), - }) + })) + } + gpu::Operator::Floor(op) => { + instructions.push(Instruction::Floor(self.compile_unary(op))) } - gpu::Operator::Floor(op) => Instruction::Floor(self.compile_unary(op)), - gpu::Operator::Ceil(op) => Instruction::Ceil(self.compile_unary(op)), - gpu::Operator::Remainder(_op) => todo!(), - gpu::Operator::Fma(op) => Instruction::Fma { + gpu::Operator::Ceil(op) => instructions.push(Instruction::Ceil(self.compile_unary(op))), + gpu::Operator::Remainder(op) => { + instructions.push(Instruction::Modulo(self.compile_binary(op))) + } + gpu::Operator::Fma(op) => instructions.push(Instruction::Fma { a: self.compile_variable(op.a), b: self.compile_variable(op.b), c: self.compile_variable(op.c), out: self.compile_variable(op.out), - }, - } + }), + }; } fn compile_binary(&mut self, value: gpu::BinaryOperator) -> super::BinaryInstruction { @@ -585,3 +668,12 @@ impl CudaCompiler { } } } + +fn has_length(var: &gpu::Variable) -> bool { + matches!( + var, + gpu::Variable::GlobalInputArray { .. } + | gpu::Variable::GlobalOutputArray { .. } + | gpu::Variable::Slice { .. 
} + ) +} diff --git a/crates/cubecl-cuda/src/compiler/binary.rs b/crates/cubecl-cuda/src/compiler/binary.rs index b16f24ab4..44d5fae63 100644 --- a/crates/cubecl-cuda/src/compiler/binary.rs +++ b/crates/cubecl-cuda/src/compiler/binary.rs @@ -1,9 +1,9 @@ use super::{Component, Elem, Variable}; -use std::fmt::Display; +use std::fmt::{Display, Formatter}; pub trait Binary { fn format( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, @@ -13,7 +13,7 @@ pub trait Binary { } fn format_scalar( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: Lhs, rhs: Rhs, out: Out, @@ -25,7 +25,7 @@ pub trait Binary { Out: Component; fn unroll_vec( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, @@ -140,11 +140,11 @@ pub struct Index; impl Binary for IndexAssign { fn format_scalar( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: Lhs, rhs: Rhs, out: Out, - _elem: Elem, + elem: Elem, ) -> std::fmt::Result where Lhs: Component, @@ -154,10 +154,10 @@ impl Binary for IndexAssign { let item_out = out.item(); let item_rhs = rhs.item(); - if item_out.vectorization != item_rhs.vectorization { + let format_vec = |f: &mut Formatter<'_>, cast: bool| { let is_vec_native = item_out.is_vec_native(); f.write_str("{\n")?; - let var = "scalar_broadcasted"; + let var = "broadcasted"; f.write_fmt(format_args!("{item_out} {var};\n"))?; for i in 0..item_out.vectorization { if is_vec_native { @@ -168,14 +168,31 @@ impl Binary for IndexAssign { 3 => 'w', _ => panic!("Invalid"), }; - f.write_fmt(format_args!("{var}.{char} = {rhs};\n"))?; + if cast { + f.write_fmt(format_args!("{var}.{char} = {}({});\n", elem, rhs.index(i)))?; + } else { + f.write_fmt(format_args!("{var}.{char} = {};\n", rhs.index(i)))?; + } + } else if cast { + f.write_fmt(format_args!("{var}.i_{i} = {}({});\n", elem, rhs.index(i)))?; } else { - f.write_fmt(format_args!("{var}.i_{i} = {rhs};\n"))?; + f.write_fmt(format_args!("{var}.i_{i} = {};\n", rhs.index(i)))?; } } f.write_fmt(format_args!("{out}[{lhs}] = {var};\n"))?; f.write_str("}")?; + Ok(()) + }; + + if item_out.vectorization != item_rhs.vectorization { + format_vec(f, item_out != item_rhs) + } else if elem != item_rhs.elem { + if item_out.vectorization > 1 { + format_vec(f, true)?; + } else { + f.write_fmt(format_args!("{out}[{lhs}] = {elem}({rhs});\n"))?; + } Ok(()) } else { f.write_fmt(format_args!("{out}[{lhs}] = {rhs};\n")) @@ -183,7 +200,7 @@ impl Binary for IndexAssign { } fn unroll_vec( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, @@ -195,8 +212,8 @@ impl Binary for IndexAssign { } for i in 0..index { - let lhsi = lhs.index(i, false); - let rhsi = rhs.index(i, false); + let lhsi = lhs.index(i, lhs.item().is_optimized()); + let rhsi = rhs.index(i, rhs.item().is_optimized()); Self::format_scalar(f, lhsi, rhsi, *out, elem)?; } @@ -204,7 +221,7 @@ impl Binary for IndexAssign { } fn format( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, @@ -227,7 +244,7 @@ impl Binary for IndexAssign { impl Binary for Index { fn format( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, @@ -245,18 +262,55 @@ impl Binary for Index { } fn format_scalar( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: Lhs, rhs: Rhs, out: Out, - _elem: Elem, + elem: Elem, ) -> std::fmt::Result where Lhs: 
Component, Rhs: Component, Out: Component, { - f.write_fmt(format_args!("{out} = {lhs}[{rhs}];\n")) + let item_out = out.item(); + let item_lhs = lhs.item(); + + let format_vec = |f: &mut Formatter<'_>| { + let is_vec_native = item_out.is_vec_native(); + f.write_str("{\n")?; + let var = "broadcasted"; + f.write_fmt(format_args!("{item_out} {var};\n"))?; + for i in 0..item_out.vectorization { + if is_vec_native { + let char = match i { + 0 => 'x', + 1 => 'y', + 2 => 'z', + 3 => 'w', + _ => panic!("Invalid"), + }; + f.write_fmt(format_args!("{var}.{char} = {elem}({lhs}[{rhs}].i_{i});\n",))?; + } else { + f.write_fmt(format_args!("{var}.i_{i} = {elem}({lhs}[{rhs}].i_{i});\n",))?; + } + } + f.write_fmt(format_args!("{out} = {var};\n"))?; + f.write_str("}")?; + + Ok(()) + }; + + if elem != item_lhs.elem { + if item_out.vectorization > 1 { + format_vec(f)?; + } else { + f.write_fmt(format_args!("{out} = {elem}({lhs}[{rhs}]);\n"))?; + } + Ok(()) + } else { + f.write_fmt(format_args!("{out} = {lhs}[{rhs}];\n")) + } } } @@ -285,7 +339,7 @@ struct IndexAssignVector; impl IndexVector { fn format( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, @@ -307,7 +361,7 @@ impl IndexVector { impl IndexAssignVector { fn format( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, diff --git a/crates/cubecl-cuda/src/compiler/element.rs b/crates/cubecl-cuda/src/compiler/element.rs index 22ce1ec16..f1570105b 100644 --- a/crates/cubecl-cuda/src/compiler/element.rs +++ b/crates/cubecl-cuda/src/compiler/element.rs @@ -54,6 +54,7 @@ impl Display for Item { pub trait Component: Display { fn item(&self) -> Item; + fn index(&self, index: usize) -> IndexedVariable; fn elem(&self) -> Elem { *self.item().elem() } @@ -63,8 +64,16 @@ impl Component for IndexedVariable { fn item(&self) -> Item { self.var.item() } + + fn index(&self, index: usize) -> IndexedVariable { + self.var.index(index, self.var.is_optimized()) + } } impl Component for Variable { + fn index(&self, index: usize) -> IndexedVariable { + self.index(index, self.is_optimized()) + } + fn item(&self) -> Item { match self { Variable::GlobalInputArray(_, e) => *e, diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs index 47f2621bf..729b47c75 100644 --- a/crates/cubecl-cuda/src/compute/server.rs +++ b/crates/cubecl-cuda/src/compute/server.rs @@ -9,6 +9,7 @@ use cubecl_core::ir::CubeDim; use cubecl_core::FeatureSet; use cubecl_core::{prelude::*, KernelId}; use cubecl_runtime::debug::DebugLogger; +use cubecl_runtime::ExecutionMode; use cubecl_runtime::{ memory_management::MemoryManagement, server::{self, ComputeServer}, @@ -111,15 +112,17 @@ impl> ComputeServer for CudaServer { server::Handle::new(handle) } - fn execute( + unsafe fn execute( &mut self, kernel: Self::Kernel, count: Self::DispatchOptions, bindings: Vec>, + mode: ExecutionMode, ) { let arch = self.minimum_arch_version; - let kernel_id = kernel.id(); + let mut kernel_id = kernel.id(); + kernel_id.mode(mode); let count = match count { CubeCount::Static(x, y, z) => (x, y, z), @@ -140,7 +143,7 @@ impl> ComputeServer for CudaServer { let (ctx, logger) = self.get_context_with_logger(); if !ctx.module_names.contains_key(&kernel_id) { - ctx.compile_kernel(&kernel_id, kernel, arch, logger); + ctx.compile_kernel(&kernel_id, kernel, arch, logger, mode); } let resources = bindings @@ -198,8 +201,9 @@ impl> CudaContext { kernel: Box, arch: i32, logger: &mut 
DebugLogger, + mode: ExecutionMode, ) { - let mut kernel_compiled = kernel.compile(); + let mut kernel_compiled = kernel.compile(mode); if logger.is_activated() { kernel_compiled.debug_info = Some(DebugInformation::new("cpp", kernel_id.clone())); @@ -231,7 +235,7 @@ impl> CudaContext { message += format!("\n {line}").as_str(); } } - let source = kernel.compile().source; + let source = kernel.compile(mode).source; panic!("{message}\n[Source] \n{source}"); }; cudarc::nvrtc::result::get_ptx(program).unwrap() diff --git a/crates/cubecl-linalg/src/matmul/cmma/base.rs b/crates/cubecl-linalg/src/matmul/cmma/base.rs index b876d197d..f534d820d 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/base.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/base.rs @@ -4,7 +4,7 @@ use cubecl_core::prelude::*; use super::block_loop::block_loop; use super::config::CmmaConfig; -#[cube(launch)] +#[cube(launch_unchecked)] #[allow(unused_mut)] pub fn cmma_kernel( lhs: &Tensor, diff --git a/crates/cubecl-linalg/src/matmul/cmma/launch.rs b/crates/cubecl-linalg/src/matmul/cmma/launch.rs index bbbe24852..396d68715 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/launch.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/launch.rs @@ -157,13 +157,15 @@ fn matmul_cmma_ref_no_check( let cube_dim = cmma_cube_dim(); let launch_config = CmmaLaunchConfig::default(); - cmma_kernel::launch::( - client, - cube_count, - cube_dim, - TensorArg::vectorized(lhs_vectorization, lhs.handle, lhs.strides, lhs.shape), - TensorArg::vectorized(rhs_vectorization, rhs.handle, rhs.strides, rhs.shape), - TensorArg::vectorized(out_vectorization, out.handle, out.strides, out.shape), - CmmaConfig::new(m, k, n, launch_config), - ); + unsafe { + cmma_kernel::launch_unchecked::( + client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(lhs.handle, lhs.strides, lhs.shape, lhs_vectorization), + TensorArg::from_raw_parts(rhs.handle, rhs.strides, rhs.shape, rhs_vectorization), + TensorArg::from_raw_parts(out.handle, out.strides, out.shape, out_vectorization), + CmmaConfig::new(m, k, n, launch_config), + ); + } } diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs index de836ee8d..5c208cc0b 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs @@ -10,7 +10,7 @@ use crate::matmul::tests::test_utils::{ assert_equals, cmma_available, create_empty, range_tensor_f16, }; -#[cube(launch)] +#[cube(launch_unchecked)] fn compute_loop_test( lhs_tensor: &Tensor, rhs_tensor: &Tensor, @@ -84,18 +84,20 @@ pub fn compute_loop_k_test(device: &R::Device) { unroll: false, }; - compute_loop_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::new(&results, m * n), - UInt::new(m as u32), - UInt::new(k as u32), - UInt::new(n as u32), - config, - ); + unsafe { + compute_loop_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), + ArrayArg::from_raw_parts(&results, m * n, 1), + UInt::new(m as u32), + UInt::new(k as u32), + UInt::new(n as u32), + config, + ); + }; let expected = &[ 1610496., 1614832., 1619168., 1623504., 1627840., 1632176., 1636512., 1640848., 1645184., @@ -160,18 +162,20 @@ pub fn compute_loop_warp_test(device: 
&R::Device) { unroll: false, }; - compute_loop_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::new(&results, m * n), - UInt::new(m as u32), - UInt::new(k as u32), - UInt::new(n as u32), - config, - ); + unsafe { + compute_loop_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), + ArrayArg::from_raw_parts(&results, m * n, 1), + UInt::new(m as u32), + UInt::new(k as u32), + UInt::new(n as u32), + config, + ); + }; let expected = &[ 1610496., 1614832., 1619168., 1623504., 1627840., 1632176., 1636512., 1640848., 1645184., @@ -265,18 +269,20 @@ pub fn cmma_compute_loop_two_warps_same_tile_row_test(device: &R::De unroll: false, }; - compute_loop_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::new(&results, m * n), - UInt::new(m as u32), - UInt::new(k as u32), - UInt::new(n as u32), - config, - ); + unsafe { + compute_loop_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), + ArrayArg::from_raw_parts(&results, m * n, 1), + UInt::new(m as u32), + UInt::new(k as u32), + UInt::new(n as u32), + config, + ); + }; let expected = &[ 1610496.0, 1614832.0, 1619168.0, 1623504.0, 1627840.0, 1632176.0, 1636512.0, 1640848.0, diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs index 08bed0ed0..33521c561 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs @@ -8,7 +8,7 @@ use crate::matmul::{ tests::test_utils::range_tensor, }; -#[cube(launch)] +#[cube(launch_unchecked)] fn load_lhs_test( lhs_tensor: &Tensor, lhs_sm_arr: &mut Array, @@ -41,7 +41,7 @@ fn load_lhs_test( } } -#[cube(launch)] +#[cube(launch_unchecked)] fn load_rhs_test( rhs_tensor: &Tensor, rhs_sm_arr: &mut Array, @@ -93,23 +93,25 @@ pub fn load_shared_memory_lhs_unit_test(device: &R::Device) { unroll: false, }; - load_lhs_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + 4, + ), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -150,23 +152,25 @@ pub fn load_shared_memory_rhs_unit_test(device: &R::Device) { unroll: false, }; - load_rhs_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &rhs_tensor.handle, - &rhs_tensor.strides, - &rhs_tensor.shape, - ), - ArrayArg::new(&rhs_sm, 64 * 32), - 
ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_rhs_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &rhs_tensor.handle, + &rhs_tensor.strides, + &rhs_tensor.shape, + 4, + ), + ArrayArg::from_raw_parts(&rhs_sm, 64 * 32, 1), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -207,23 +211,25 @@ pub fn load_shared_memory_lhs_warp_test(device: &R::Device) { unroll: false, }; - load_lhs_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + 4, + ), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 64.0, @@ -269,23 +275,25 @@ pub fn load_shared_memory_lhs_vertical_out_of_bound_warp_test(device unroll: false, }; - load_lhs_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(12), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + 4, + ), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), + ScalarArg::new(0), + ScalarArg::new(12), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 64.0, @@ -329,23 +337,25 @@ pub fn load_shared_memory_lhs_horizontal_out_of_bound_warp_test(devi unroll: false, }; - load_lhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(12), - ScalarArg::new(12), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + 4, + ), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), + ScalarArg::new(0), + ScalarArg::new(12), + ScalarArg::new(12), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, @@ -389,23 +399,25 @@ pub fn load_shared_memory_lhs_whole_out_of_bound_warp_test(device: & unroll: false, }; - load_lhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(12), - 
ScalarArg::new(12), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + 4, + ), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), + ScalarArg::new(0), + ScalarArg::new(12), + ScalarArg::new(12), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, @@ -448,23 +460,25 @@ pub fn load_shared_memory_rhs_warp_test(device: &R::Device) { unroll: false, }; - load_rhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &rhs_tensor.handle, - &rhs_tensor.strides, - &rhs_tensor.shape, - ), - ArrayArg::new(&rhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_rhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &rhs_tensor.handle, + &rhs_tensor.strides, + &rhs_tensor.shape, + 4, + ), + ArrayArg::from_raw_parts(&rhs_sm, 64 * 32, 1), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 64.0, @@ -510,23 +524,25 @@ pub fn load_shared_memory_lhs_second_warp_test(device: &R::Device) { unroll: false, }; - load_lhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + 4, + ), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 80., 81., @@ -571,23 +587,25 @@ pub fn load_shared_memory_rhs_second_warp_test(device: &R::Device) { unroll: false, }; - load_rhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &rhs_tensor.handle, - &rhs_tensor.strides, - &rhs_tensor.shape, - ), - ArrayArg::new(&rhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_rhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &rhs_tensor.handle, + &rhs_tensor.strides, + &rhs_tensor.shape, + 4, + ), + ArrayArg::from_raw_parts(&rhs_sm, 64 * 32, 1), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 1024., 1025., 1026., 1027., 1028., 1029., 1030., 1031., 1032., 1033., 1034., 1035., 1036., @@ -635,23 +653,25 @@ pub fn load_shared_memory_lhs_third_warp_test(device: &R::Device) { unroll: false, }; - load_lhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + 
load_lhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + 4, + ), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 1024., 1025., 1026., 1027., 1028., 1029., 1030., 1031., 1032., 1033., 1034., 1035., 1036., @@ -699,23 +719,25 @@ pub fn load_shared_memory_rhs_third_warp_test(device: &R::Device) { unroll: false, }; - load_rhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &rhs_tensor.handle, - &rhs_tensor.strides, - &rhs_tensor.shape, - ), - ArrayArg::new(&rhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_rhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &rhs_tensor.handle, + &rhs_tensor.strides, + &rhs_tensor.shape, + 4, + ), + ArrayArg::from_raw_parts(&rhs_sm, 64 * 32, 1), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 80., 81., @@ -760,23 +782,25 @@ pub fn load_shared_memory_lhs_k_offset_test(device: &R::Device) { unroll: false, }; - load_lhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(32), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + 4, + ), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), + ScalarArg::new(32), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 96., 97., @@ -821,23 +845,25 @@ pub fn load_shared_memory_rhs_k_offset_test(device: &R::Device) { unroll: false, }; - load_rhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &rhs_tensor.handle, - &rhs_tensor.strides, - &rhs_tensor.shape, - ), - ArrayArg::new(&rhs_sm, 64 * 32), - ScalarArg::new(32), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_rhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &rhs_tensor.handle, + &rhs_tensor.strides, + &rhs_tensor.shape, + 4, + ), + ArrayArg::from_raw_parts(&rhs_sm, 64 * 32, 1), + ScalarArg::new(32), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 2048., 2049., 2050., 2051., 2052., 2053., 2054., 2055., 2056., 2057., 2058., 2059., 2060., diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs index d593eb013..c9133eca6 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs @@ -8,7 +8,7 @@ use crate::matmul::{ tests::test_utils::range_tensor, }; -#[cube(launch)] +#[cube(launch_unchecked)] fn write_output_test( out: &mut Tensor, acc_sm_arr: &mut Array, @@ -60,16 +60,18 @@ pub fn 
cmma_write_output_unit_test(device: &R::Device) { unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), + ScalarArg::new(m as u32), + ScalarArg::new(n as u32), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 256.0, @@ -126,16 +128,18 @@ pub fn cmma_write_output_warp_test(device: &R::Device) { unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), + ScalarArg::new(m as u32), + ScalarArg::new(n as u32), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, @@ -202,16 +206,18 @@ pub fn cmma_write_output_warp_horizontal_out_of_bounds_test(device: unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), + ScalarArg::new(m as u32), + ScalarArg::new(n as u32), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, @@ -273,16 +279,18 @@ pub fn cmma_write_output_warp_vertical_out_of_bounds_test(device: &R unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), + ScalarArg::new(m as u32), + ScalarArg::new(n as u32), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, @@ -344,16 +352,18 @@ pub fn cmma_write_output_warp_whole_out_of_bounds_test(device: &R::D unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), + ScalarArg::new(m as 
u32), + ScalarArg::new(n as u32), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, @@ -411,16 +421,18 @@ pub fn cmma_write_output_second_warp_test(device: &R::Device) { unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), + ScalarArg::new(m as u32), + ScalarArg::new(n as u32), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, @@ -527,16 +539,18 @@ pub fn cmma_write_output_third_fourth_warps_test(device: &R::Device) unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), + ScalarArg::new(m as u32), + ScalarArg::new(n as u32), + config, + ); + }; let expected = &[ 1024., 1025., 1026., 1027., 1028., 1029., 1030., 1031., 1032., 1033., 1034., 1035., 1036., diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs index 21447b0cd..7a3db32bd 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs @@ -13,7 +13,7 @@ use crate::matmul::{ }, }; -#[cube(launch)] +#[cube(launch_unchecked)] #[allow(unused_mut)] fn tile_outer_product_test( register_m: Array, @@ -50,15 +50,17 @@ pub fn tile_outer_product_vectorized_unit_test_2(device: &R::Device) const SOME_DIM: usize = 12; let config = make_tiling2d_config(SOME_DIM, SOME_DIM, SOME_DIM); - tile_outer_product_test::launch::( - &client, - cube_count, - cube_dim, - ArrayArg::new(®ister_m, 4), - ArrayArg::new(®ister_n, 4), - ArrayArg::new(&results, 16), - config, - ); + unsafe { + tile_outer_product_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + ArrayArg::from_raw_parts(®ister_m, 4, 1), + ArrayArg::from_raw_parts(®ister_n, 4, 1), + ArrayArg::from_raw_parts(&results, 16, 1), + config, + ); + }; let expected = &[ 64.0, 80.0, 96.0, 112.0, 80.0, 100.0, 120.0, 140.0, 96.0, 120.0, 144.0, 168.0, 112.0, @@ -67,7 +69,7 @@ pub fn tile_outer_product_vectorized_unit_test_2(device: &R::Device) assert_equals::(&client, results, expected); } -#[cube(launch)] +#[cube(launch_unchecked)] fn compute_loop_test( lhs: &Tensor, rhs: &Tensor, @@ -124,15 +126,17 @@ pub fn tile_outer_product_vectorized_unit_test(device: &R::Device) { const SOME_DIM: usize = 12; let config = make_tiling2d_config(SOME_DIM, SOME_DIM, SOME_DIM); - tile_outer_product_test::launch::( - &client, - cube_count, - cube_dim, - ArrayArg::new(®ister_m, 4), - ArrayArg::new(®ister_n, 4), - ArrayArg::new(&results, 16), - config, - ); + unsafe { + tile_outer_product_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + 
ArrayArg::from_raw_parts(®ister_m, 4, 1), + ArrayArg::from_raw_parts(®ister_n, 4, 1), + ArrayArg::from_raw_parts(&results, 16, 1), + config, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0, 3.0, 6.0, 9.0, 12.0, @@ -152,19 +156,21 @@ pub fn compute_loop_unit_test(device: &R::Device) { const SOME_DIM: usize = 12; let config = make_tiling2d_config(SOME_DIM, SOME_DIM, SOME_DIM); - compute_loop_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), - ScalarArg::new(0), - ScalarArg::new(0), - ArrayArg::new(&results, 16), - UInt::new(16), - UInt::new(16), - config, - ); + unsafe { + compute_loop_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, TILE_SIZE as u8), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), + ScalarArg::new(0), + ScalarArg::new(0), + ArrayArg::from_raw_parts(&results, 16, 1), + UInt::new(16), + UInt::new(16), + config, + ); + }; let expected = &[ 8960.0, 9184.0, 9408.0, 9632.0, 9184.0, 9416.0, 9648.0, 9880.0, 9408.0, 9648.0, 9888.0, @@ -184,19 +190,21 @@ pub fn compute_loop_unit_offset_test(device: &R::Device) { let config = make_tiling2d_config(4, 8, 4); - compute_loop_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), - ScalarArg::new(4), - ScalarArg::new(4), - ArrayArg::new(&results, 16), - UInt::new(8), - UInt::new(8), - config, - ); + unsafe { + compute_loop_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, TILE_SIZE as u8), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), + ScalarArg::new(4), + ScalarArg::new(4), + ArrayArg::from_raw_parts(&results, 16, 1), + UInt::new(8), + UInt::new(8), + config, + ); + }; let expected = &[ 1160.0, 1230.0, 1300.0, 1370.0, 1416.0, 1502.0, 1588.0, 1674.0, 1672.0, 1774.0, 1876.0, diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs index b8059b522..1ea5a3fdc 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs @@ -15,7 +15,7 @@ use crate::matmul::{ }, }; -#[cube(launch)] +#[cube(launch_unchecked)] fn load_tensor_test( tensor: &Tensor, sm_out: &mut Array, @@ -80,7 +80,7 @@ fn load_tensor_test( } } -#[cube(launch)] +#[cube(launch_unchecked)] fn load_tensor_permuted_test( tensor: &Tensor, sm_out: &mut Array, @@ -147,7 +147,7 @@ fn load_tensor_permuted_test( } } -#[cube(launch)] +#[cube(launch_unchecked)] fn load_tensor_multiple_tiles_test( tensor: &Tensor, sm_out: &mut Array, @@ -222,18 +222,20 @@ pub fn load_lhs_transposed_unit_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); - load_tensor_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - true, - ); + unsafe { + load_tensor_test::launch_unchecked::( + 
&client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + true, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -255,21 +257,23 @@ pub fn load_lhs_transposed_out_of_bounds_cube_test(device: &R::Devic let config = make_tiling2d_config(5, 1, 1); - load_tensor_multiple_tiles_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - vectorization_factor as u8, - &lhs.handle, - &lhs.strides, - &lhs.shape, - ), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(0), - config, - true, - ); + unsafe { + load_tensor_multiple_tiles_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts( + &lhs.handle, + &lhs.strides, + &lhs.shape, + vectorization_factor as u8, + ), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), + ScalarArg::new(0), + config, + true, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -290,16 +294,18 @@ pub fn load_lhs_transposed_cube_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 8); - load_tensor_multiple_tiles_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(0), - config, - true, - ); + unsafe { + load_tensor_multiple_tiles_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), + ScalarArg::new(0), + config, + true, + ); + }; let expected = &[ 0.0, 8.0, 16.0, 24.0, 32.0, 40.0, 48.0, 56.0, 1.0, 9.0, 17.0, 25.0, 33.0, 41.0, 49.0, 57.0, @@ -321,16 +327,18 @@ pub fn load_lhs_transposed_offset_cube_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 16); - load_tensor_multiple_tiles_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(8), - config, - true, - ); + unsafe { + load_tensor_multiple_tiles_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), + ScalarArg::new(8), + config, + true, + ); + }; let expected = &[ 8.0, 24.0, 40.0, 56.0, 72.0, 88.0, 104.0, 120.0, 9.0, 25.0, 41.0, 57.0, 73.0, 89.0, 105.0, @@ -352,18 +360,20 @@ pub fn load_rhs_plain_unit_test(device: &R::Device) { let config = make_tiling2d_config(8, 16, 16); - load_tensor_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - false, - ); + unsafe { + load_tensor_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + false, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, @@ -384,16 +394,18 @@ pub fn load_rhs_plain_cube_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 8); - load_tensor_multiple_tiles_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(0), - config, - false, - ); + unsafe { + load_tensor_multiple_tiles_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), + ScalarArg::new(0), + config, + false, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, @@ -415,16 +427,18 @@ pub fn load_rhs_plain_cube_offset_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); - load_tensor_multiple_tiles_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(8), - config, - false, - ); + unsafe { + load_tensor_multiple_tiles_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), + ScalarArg::new(8), + config, + false, + ); + }; let expected = &[ 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, @@ -446,18 +460,20 @@ pub fn load_lhs_plain_unit_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); - load_tensor_permuted_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - true, - ); + unsafe { + load_tensor_permuted_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + true, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -479,18 +495,20 @@ pub fn load_lhs_plain_out_of_bounds_unit_test(device: &R::Device) { let config = make_tiling2d_config(m, k, 8); - load_tensor_permuted_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - true, - ); + unsafe { + load_tensor_permuted_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + true, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -511,18 +529,20 @@ pub fn load_rhs_transposed_unit_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); - load_tensor_permuted_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - 
ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - false, - ); + unsafe { + load_tensor_permuted_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + false, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -544,18 +564,20 @@ pub fn load_rhs_transposed_out_of_bounds_unit_test(device: &R::Devic let config = make_tiling2d_config(8, k, n); - load_tensor_permuted_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - false, - ); + unsafe { + load_tensor_permuted_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + false, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs index f0a9e5939..41c2f2931 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs @@ -14,7 +14,7 @@ use crate::matmul::{ }, }; -#[cube(launch)] +#[cube(launch_unchecked)] fn write_to_output_test( out: &mut Tensor, results: &mut Array, @@ -35,7 +35,7 @@ fn write_to_output_test( write_to_output::>(out, results, coordinates, UInt::new(0), dims, config); } -#[cube(launch)] +#[cube(launch_unchecked)] fn write_results_to_output_out_of_bounds_test( out: &mut Tensor, results: &mut Array, @@ -66,14 +66,16 @@ pub fn write_to_output_over_height_unit_test(device: &R::Device) { let config = make_tiling2d_config(6, 8, 8); - write_to_output_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&tile.handle, 16), - config, - ); + unsafe { + write_to_output_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, TILE_SIZE as u8), + ArrayArg::from_raw_parts(&tile.handle, 16, 1), + config, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -93,14 +95,16 @@ pub fn write_to_output_over_width_unit_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 4); - write_to_output_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&tile.handle, 16), - config, - ); + unsafe { + write_to_output_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, TILE_SIZE as u8), + ArrayArg::from_raw_parts(&tile.handle, 16, 1), + config, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -120,14 +124,16 @@ 
pub fn write_to_output_vectorized_less_than_tile_unit_test(device: & let config = make_tiling2d_config(8, 8, 8); - write_to_output_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(vectorization as u8, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&tile.handle, 16), - config, - ); + unsafe { + write_to_output_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, vectorization as u8), + ArrayArg::from_raw_parts(&tile.handle, 16, 1), + config, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -149,14 +155,16 @@ pub fn write_to_output_scalar_unit_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 8); - write_to_output_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(vectorization as u8, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&tile.handle, 16), - config, - ); + unsafe { + write_to_output_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, vectorization as u8), + ArrayArg::from_raw_parts(&tile.handle, 16, 1), + config, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -178,14 +186,16 @@ pub fn write_to_output_scalar_out_of_bounds_cube_test(device: &R::De let config = make_tiling2d_config(5, 8, 1); - write_results_to_output_out_of_bounds_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(vectorization, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&results.handle, 16), - config, - ); + unsafe { + write_results_to_output_out_of_bounds_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, vectorization), + ArrayArg::from_raw_parts(&results.handle, 16, 1), + config, + ); + }; let expected = &[0.0, 1.0, 2.0, 3.0, 0.0]; assert_equals::(&client, out.handle, expected); diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/base.rs b/crates/cubecl-linalg/src/matmul/tiling2d/base.rs index 728ca6cf9..4418a3d56 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/base.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/base.rs @@ -6,7 +6,7 @@ use super::{block_loop::block_loop, config::CubeTiling2dConfig}; /// Most common tile size, the one used in most tests. 
pub(crate) const TILE_SIZE: usize = 4; -#[cube(launch)] +#[cube(launch_unchecked)] #[allow(unused_mut)] pub fn tiling2d_cube_kernel( lhs: &Tensor, diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs index 5499d3e81..8f29adf60 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs @@ -125,13 +125,15 @@ fn matmul_tiling_2d_ref_no_check( let cube_dim = tiling2d_cube_dim(&config); let cube_config = CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed); - tiling2d_cube_kernel::launch::( - client, - cube_count, - cube_dim, - TensorArg::vectorized(lhs_vectorization, lhs.handle, lhs.strides, lhs.shape), - TensorArg::vectorized(rhs_vectorization, rhs.handle, rhs.strides, rhs.shape), - TensorArg::vectorized(out_vectorization, out.handle, out.strides, out.shape), - cube_config, - ); + unsafe { + tiling2d_cube_kernel::launch_unchecked::( + client, + cube_count, + cube_dim, + TensorArg::from_raw_parts(lhs.handle, lhs.strides, lhs.shape, lhs_vectorization), + TensorArg::from_raw_parts(rhs.handle, rhs.strides, rhs.shape, rhs_vectorization), + TensorArg::from_raw_parts(out.handle, out.strides, out.shape, out_vectorization), + cube_config, + ); + } } diff --git a/crates/cubecl-linalg/src/tensor/base.rs b/crates/cubecl-linalg/src/tensor/base.rs index 0eb5b1cfb..8d37e1bee 100644 --- a/crates/cubecl-linalg/src/tensor/base.rs +++ b/crates/cubecl-linalg/src/tensor/base.rs @@ -91,6 +91,15 @@ where } } + /// Return the reference to a tensor argument. + pub fn as_arg<'a>(&'a self, vectorisation: u8) -> TensorArg<'a, R> { + let handle: TensorHandleRef<'a, R> = self.as_ref(); + + unsafe { + TensorArg::from_raw_parts(handle.handle, handle.strides, handle.shape, vectorisation) + } + } + fn contiguous_strides(shape: &[usize]) -> Vec { let mut strides = Vec::with_capacity(shape.len()); @@ -124,12 +133,14 @@ where cube_dim, ); - init::zeros_array::launch::( - client, - cube_count, - cube_dim, - ArrayArg::vectorized(vectorization_factor, &handle, num_elements), - ); + unsafe { + init::zeros_array::launch_unchecked::( + client, + cube_count, + cube_dim, + ArrayArg::from_raw_parts(&handle, num_elements, vectorization_factor), + ) + }; Self::new(shape, strides, handle) } @@ -139,7 +150,7 @@ pub(crate) mod init { use cubecl::prelude::*; use cubecl_core as cubecl; - #[cube(launch)] + #[cube(launch_unchecked)] pub fn zeros_array(output: &mut Array) { if ABSOLUTE_POS < output.len() { output[ABSOLUTE_POS] = C::from_int(0); diff --git a/crates/cubecl-linalg/src/tensor/contiguous.rs b/crates/cubecl-linalg/src/tensor/contiguous.rs index 5f9513378..e26b3afa0 100644 --- a/crates/cubecl-linalg/src/tensor/contiguous.rs +++ b/crates/cubecl-linalg/src/tensor/contiguous.rs @@ -73,18 +73,8 @@ pub fn into_contiguous( client, cube_count, cube_dim, - TensorArg::vectorized( - vectorization_factor, - input.handle, - input.strides, - input.shape, - ), - TensorArg::vectorized( - vectorization_factor, - &output.handle, - &output.strides, - &output.shape, - ), + input.as_tensor_arg(vectorization_factor), + output.as_ref().as_tensor_arg(vectorization_factor), Some(UInt::new(rank as u32)), ); diff --git a/crates/cubecl-macros/src/codegen_function/launch.rs b/crates/cubecl-macros/src/codegen_function/launch.rs index 49ed4dcba..c4ec8647f 100644 --- a/crates/cubecl-macros/src/codegen_function/launch.rs +++ b/crates/cubecl-macros/src/codegen_function/launch.rs @@ -13,13 +13,15 @@ struct Codegen { 
state_args: Vec, state_inputs: Vec<(Ident, syn::Type)>, state_outputs: Vec<(Ident, syn::Type)>, + unchecked: bool, } impl Codegen { - fn from_sig(sig: &syn::Signature) -> Self { + fn from_sig(sig: &syn::Signature, unchecked: bool) -> Self { let mut codegen = Codegen { name: snake_to_pascal_case(&sig.ident.to_string()), generics: sig.generics.clone(), + unchecked, ..Codegen::default() }; @@ -425,20 +427,30 @@ impl Codegen { } }; - quote::quote! { + let mut tokens = quote::quote! { #settings let kernel = #kernel; #body + }; - launcher.launch(cube_count, kernel, client); + if self.unchecked { + tokens.extend(quote::quote! { + launcher.launch_unchecked(cube_count, kernel, client); + }); + } else { + tokens.extend(quote::quote! { + launcher.launch(cube_count, kernel, client); + }); } + + tokens } } -pub fn codegen_launch(sig: &syn::Signature) -> TokenStream { - let codegen = Codegen::from_sig(sig); +pub fn codegen_launch(sig: &syn::Signature, unchecked: bool) -> TokenStream { + let codegen = Codegen::from_sig(sig, unchecked); let ident = &sig.ident; @@ -453,13 +465,24 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream { let (inputs, output) = (codegen.fn_inputs, codegen.fn_output); let doc = format!("Launch the kernel [{ident}()] on the given runtime."); + let maybe_unsafe = if unchecked { + quote::quote! {unsafe} + } else { + quote::quote! {} + }; + let launch_name = if unchecked { + quote::quote! { launch_unchecked} + } else { + quote::quote! { launch} + }; + quote::quote! { #kernel #compile #[allow(clippy::too_many_arguments)] #[doc = #doc] - pub fn launch #generics ( + pub #maybe_unsafe fn #launch_name #generics ( client: &ComputeClient, cube_count: CubeCount, cube_dim: CubeDim, diff --git a/crates/cubecl-macros/src/lib.rs b/crates/cubecl-macros/src/lib.rs index 82630ace6..f30e4a95d 100644 --- a/crates/cubecl-macros/src/lib.rs +++ b/crates/cubecl-macros/src/lib.rs @@ -45,6 +45,7 @@ pub fn module_derive_cube_type(input: TokenStream) -> TokenStream { struct SupportedAttributes { mode: CubeMode, launch: bool, + launch_unchecked: bool, } /// Derive macro for the module. 
@@ -69,7 +70,12 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream { fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream { let mut variable_tracker = VariableAnalyzer::create_tracker(&func); - match codegen_cube(&func, &mut variable_tracker, attrs.launch) { + match codegen_cube( + &func, + &mut variable_tracker, + attrs.launch, + attrs.launch_unchecked, + ) { Ok(code) => code.into(), Err(err) => err.into(), } @@ -78,6 +84,7 @@ fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream { fn parse_attributes(args: &Punctuated) -> SupportedAttributes { let mut mode = CubeMode::Default; let mut launch = false; + let mut launch_unchecked = false; for arg in args.iter() { match arg { @@ -90,7 +97,12 @@ fn parse_attributes(args: &Punctuated) -> SupportedAttributes { "launch" => { launch = true; } - _ => panic!("Attribute {ident} is not supported"), + "launch_unchecked" => { + launch_unchecked = true; + } + _ => { + panic!("Attribute {ident} is not supported") + } } } else { panic!("Only ident attribute supported"); @@ -101,7 +113,11 @@ fn parse_attributes(args: &Punctuated) -> SupportedAttributes { } } - SupportedAttributes { mode, launch } + SupportedAttributes { + mode, + launch, + launch_unchecked, + } } /// Generate the expanded version of a function marked with the cube macro @@ -109,6 +125,7 @@ fn codegen_cube( func: &syn::ItemFn, variable_tracker: &mut VariableTracker, launch: bool, + launch_unchecked: bool, ) -> Result { let signature = expand_sig( &func.sig, @@ -149,12 +166,18 @@ fn codegen_cube( "function " }; - let launch = if launch { - codegen_launch(&func.sig) + let mut launch = if launch { + codegen_launch(&func.sig, false) } else { quote::quote! {} }; + launch.extend(if launch_unchecked { + codegen_launch(&func.sig, true) + } else { + quote::quote! {} + }); + let mod_name = &func.sig.ident; let vis = &func.vis; let doc = format!("Module containing the expand {launch_doc}of {mod_name}."); diff --git a/crates/cubecl-runtime/src/compute.rs b/crates/cubecl-runtime/src/base.rs similarity index 92% rename from crates/cubecl-runtime/src/compute.rs rename to crates/cubecl-runtime/src/base.rs index 9a35f5384..5ad3f1031 100644 --- a/crates/cubecl-runtime/src/compute.rs +++ b/crates/cubecl-runtime/src/base.rs @@ -8,6 +8,16 @@ pub struct ComputeRuntime { clients: spin::Mutex>>>, } +/// The kind of execution to be performed. +#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)] +pub enum ExecutionMode { + /// Checked kernels are safe. + #[default] + Checked, + /// Unchecked kernels are unsafe. + Unchecked, +} + impl Default for ComputeRuntime where Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug, diff --git a/crates/cubecl-runtime/src/channel/base.rs b/crates/cubecl-runtime/src/channel/base.rs index fcb1cdcb4..c1fb44a22 100644 --- a/crates/cubecl-runtime/src/channel/base.rs +++ b/crates/cubecl-runtime/src/channel/base.rs @@ -1,6 +1,7 @@ use crate::{ server::{Binding, ComputeServer, Handle}, storage::ComputeStorage, + ExecutionMode, }; use alloc::vec::Vec; use cubecl_common::{reader::Reader, sync_type::SyncType}; @@ -24,11 +25,16 @@ pub trait ComputeChannel: Clone + core::fmt::Debug + Send fn empty(&self, size: usize) -> Handle; /// Executes the `kernel` over the given `bindings`. - fn execute( + /// + /// # Safety + /// + /// When executing with mode [ExecutionMode::Unchecked], out-of-bound reads and writes can happen. 
+ unsafe fn execute( &self, kernel: Server::Kernel, count: Server::DispatchOptions, bindings: Vec>, + mode: ExecutionMode, ); /// Perform some synchronization of commands on the server. diff --git a/crates/cubecl-runtime/src/channel/cell.rs b/crates/cubecl-runtime/src/channel/cell.rs index aa9dcc25c..e9178ac30 100644 --- a/crates/cubecl-runtime/src/channel/cell.rs +++ b/crates/cubecl-runtime/src/channel/cell.rs @@ -1,6 +1,7 @@ use super::ComputeChannel; use crate::server::{Binding, ComputeServer, Handle}; use crate::storage::ComputeStorage; +use crate::ExecutionMode; use alloc::sync::Arc; use alloc::vec::Vec; use cubecl_common::reader::Reader; @@ -63,15 +64,16 @@ where self.server.borrow_mut().empty(size) } - fn execute( + unsafe fn execute( &self, kernel_description: Server::Kernel, count: Server::DispatchOptions, bindings: Vec>, + kind: ExecutionMode, ) { self.server .borrow_mut() - .execute(kernel_description, count, bindings) + .execute(kernel_description, count, bindings, kind) } fn sync(&self, sync_type: SyncType) { diff --git a/crates/cubecl-runtime/src/channel/mpsc.rs b/crates/cubecl-runtime/src/channel/mpsc.rs index 6d55a0302..3488635dc 100644 --- a/crates/cubecl-runtime/src/channel/mpsc.rs +++ b/crates/cubecl-runtime/src/channel/mpsc.rs @@ -5,6 +5,7 @@ use super::ComputeChannel; use crate::{ server::{Binding, ComputeServer, Handle}, storage::ComputeStorage, + ExecutionMode, }; /// Create a channel using a [multi-producer, single-consumer channel to communicate with @@ -40,7 +41,7 @@ where Create(Vec, Callback>), Empty(usize, Callback>), ExecuteKernel( - (Server::Kernel, Server::DispatchOptions), + (Server::Kernel, Server::DispatchOptions, ExecutionMode), Vec>, ), Sync(SyncType, Callback<()>), @@ -76,9 +77,9 @@ where let handle = server.empty(size); callback.send(handle).await.unwrap(); } - Message::ExecuteKernel(kernel, bindings) => { - server.execute(kernel.0, kernel.1, bindings); - } + Message::ExecuteKernel(kernel, bindings) => unsafe { + server.execute(kernel.0, kernel.1, bindings, kernel.2); + }, Message::Sync(sync_type, callback) => { server.sync(sync_type); callback.send(()).await.unwrap(); @@ -151,15 +152,16 @@ where handle_response(response.recv_blocking()) } - fn execute( + unsafe fn execute( &self, kernel: Server::Kernel, count: Server::DispatchOptions, bindings: Vec>, + kind: ExecutionMode, ) { self.state .sender - .send_blocking(Message::ExecuteKernel((kernel, count), bindings)) + .send_blocking(Message::ExecuteKernel((kernel, count, kind), bindings)) .unwrap() } diff --git a/crates/cubecl-runtime/src/channel/mutex.rs b/crates/cubecl-runtime/src/channel/mutex.rs index 13f2e12b3..f20fce89a 100644 --- a/crates/cubecl-runtime/src/channel/mutex.rs +++ b/crates/cubecl-runtime/src/channel/mutex.rs @@ -1,6 +1,7 @@ use super::ComputeChannel; use crate::server::{Binding, ComputeServer, Handle}; use crate::storage::ComputeStorage; +use crate::ExecutionMode; use alloc::sync::Arc; use alloc::vec::Vec; use cubecl_common::reader::Reader; @@ -56,13 +57,14 @@ where self.server.lock().empty(size) } - fn execute( + unsafe fn execute( &self, kernel: Server::Kernel, count: Server::DispatchOptions, handles: Vec>, + kind: ExecutionMode, ) { - self.server.lock().execute(kernel, count, handles) + self.server.lock().execute(kernel, count, handles, kind) } fn sync(&self, sync_type: SyncType) { diff --git a/crates/cubecl-runtime/src/client.rs b/crates/cubecl-runtime/src/client.rs index 4e2b6a396..25cf9b7b7 100644 --- a/crates/cubecl-runtime/src/client.rs +++ 
b/crates/cubecl-runtime/src/client.rs @@ -2,6 +2,7 @@ use crate::{ channel::ComputeChannel, server::{Binding, ComputeServer, Handle}, storage::ComputeStorage, + ExecutionMode, }; use alloc::sync::Arc; use alloc::vec::Vec; @@ -77,7 +78,25 @@ where count: Server::DispatchOptions, bindings: Vec>, ) { - self.channel.execute(kernel, count, bindings) + unsafe { + self.channel + .execute(kernel, count, bindings, ExecutionMode::Checked) + } + } + + /// Executes the `kernel` over the given `bindings` without performing any bound checks. + /// + /// # Safety + /// + /// Without checks, the out-of-bound reads and writes can happen. + pub unsafe fn execute_unchecked( + &self, + kernel: Server::Kernel, + count: Server::DispatchOptions, + bindings: Vec>, + ) { + self.channel + .execute(kernel, count, bindings, ExecutionMode::Unchecked) } /// Wait for the completion of every task in the server. diff --git a/crates/cubecl-runtime/src/lib.rs b/crates/cubecl-runtime/src/lib.rs index 307021bf5..dd278706a 100644 --- a/crates/cubecl-runtime/src/lib.rs +++ b/crates/cubecl-runtime/src/lib.rs @@ -25,8 +25,8 @@ pub mod server; /// Compute Storage module. pub mod storage; -mod compute; -pub use compute::*; +mod base; +pub use base::*; pub use cubecl_common::benchmark; /// Debugging utilities. diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs index 212461a11..dbc104f5e 100644 --- a/crates/cubecl-runtime/src/server.rs +++ b/crates/cubecl-runtime/src/server.rs @@ -1,6 +1,7 @@ use crate::{ memory_management::{MemoryHandle, MemoryManagement}, storage::ComputeStorage, + ExecutionMode, }; use alloc::vec::Vec; use core::fmt::Debug; @@ -44,11 +45,16 @@ where /// /// Kernels have mutable access to every resource they are given /// and are responsible of determining which should be read or written. - fn execute( + /// + /// # Safety + /// + /// When executing with mode [ExecutionMode::Unchecked], out-of-bound reads and writes can happen. + unsafe fn execute( &mut self, kernel: Self::Kernel, count: Self::DispatchOptions, bindings: Vec>, + kind: ExecutionMode, ); /// Wait for the completion of every task in the server. 
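The client-side split above keeps the existing safe `execute` path intact while exposing the new `execute_unchecked` one. As a rough illustration (not part of this patch), a caller that already owns prepared `kernel`, `count`, and `bindings` values could choose between the two entry points as sketched below; the `dispatch` helper and its `bounds_checks` flag are hypothetical, and the module paths and generic parameters are assumptions based on the signatures shown in the diff.

```rust
use cubecl_runtime::channel::ComputeChannel;
use cubecl_runtime::client::ComputeClient;
use cubecl_runtime::server::{Binding, ComputeServer};

/// Hypothetical helper: pick the checked or unchecked entry point at runtime.
fn dispatch<Server, Channel>(
    client: &ComputeClient<Server, Channel>,
    kernel: Server::Kernel,
    count: Server::DispatchOptions,
    bindings: Vec<Binding<Server>>,
    bounds_checks: bool,
) where
    Server: ComputeServer,
    Channel: ComputeChannel<Server>,
{
    if bounds_checks {
        // Safe path: the channel receives ExecutionMode::Checked.
        client.execute(kernel, count, bindings);
    } else {
        // SAFETY: the caller guarantees every read and write stays in bounds,
        // since the kernel is compiled without bound checks.
        unsafe { client.execute_unchecked(kernel, count, bindings) };
    }
}
```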
diff --git a/crates/cubecl-runtime/tests/dummy/server.rs b/crates/cubecl-runtime/tests/dummy/server.rs index 079ccb72b..2f5ade31f 100644 --- a/crates/cubecl-runtime/tests/dummy/server.rs +++ b/crates/cubecl-runtime/tests/dummy/server.rs @@ -5,6 +5,7 @@ use cubecl_runtime::{ memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement}, server::{Binding, ComputeServer, Handle}, storage::{BytesResource, BytesStorage}, + ExecutionMode, }; use derive_new::new; @@ -53,11 +54,12 @@ where Handle::new(self.memory_management.reserve(size, || {})) } - fn execute( + unsafe fn execute( &mut self, kernel: Self::Kernel, _count: Self::DispatchOptions, bindings: Vec>, + _mode: ExecutionMode, ) { let mut resources = bindings .into_iter() diff --git a/crates/cubecl-wgpu/Cargo.toml b/crates/cubecl-wgpu/Cargo.toml index be4ecfe2d..88bcbeb74 100644 --- a/crates/cubecl-wgpu/Cargo.toml +++ b/crates/cubecl-wgpu/Cargo.toml @@ -16,7 +16,6 @@ default = [ "cubecl-common/default", "cubecl-core/default", ] -autotune = [] std = ["cubecl-runtime/std", "cubecl-common/std", "cubecl-core/std"] [dependencies] diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index fc2f10cc5..c3547ce43 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -2,6 +2,7 @@ use super::{shader::ComputeShader, Item, SharedMemory}; use super::{LocalArray, Subgroup}; use crate::compiler::wgsl; use cubecl_core::ir as cube; +use cubecl_runtime::ExecutionMode; /// Wgsl Compiler. #[derive(Clone, Default)] @@ -33,7 +34,7 @@ impl core::fmt::Debug for WgslCompiler { impl cubecl_core::Compiler for WgslCompiler { type Representation = ComputeShader; - fn compile(shader: cube::KernelDefinition) -> Self::Representation { + fn compile(shader: cube::KernelDefinition, _mode: ExecutionMode) -> Self::Representation { let mut compiler = Self::default(); compiler.compile_shader(shader) } diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index f2dc72577..4d908d99b 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -8,6 +8,7 @@ use cubecl_runtime::{ debug::DebugLogger, memory_management::MemoryManagement, server::{self, ComputeServer}, + ExecutionMode, }; use hashbrown::HashMap; use wgpu::{ @@ -98,31 +99,45 @@ where self.tasks_count += 1; } - fn pipeline(&mut self, kernel: ::Kernel) -> Arc { - let kernel_id = kernel.id(); + fn pipeline( + &mut self, + kernel: ::Kernel, + mode: ExecutionMode, + ) -> Arc { + let mut kernel_id = kernel.id(); + kernel_id.mode(mode); if let Some(pipeline) = self.pipelines.get(&kernel_id) { return pipeline.clone(); } - let mut compile = kernel.compile(); + let mut compile = kernel.compile(mode); if self.logger.is_activated() { compile.debug_info = Some(DebugInformation::new("wgsl", kernel_id.clone())); } let compile = self.logger.debug(compile); - let pipeline = self.compile_source(&compile.source); + let pipeline = self.compile_source(&compile.source, mode); self.pipelines.insert(kernel_id.clone(), pipeline.clone()); pipeline } - fn compile_source(&self, source: &str) -> Arc { - let module = self.device.create_shader_module(ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }); + fn compile_source(&self, source: &str, mode: ExecutionMode) -> Arc { + let module = match mode { + ExecutionMode::Checked => 
self.device.create_shader_module(ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), + }), + ExecutionMode::Unchecked => unsafe { + self.device + .create_shader_module_unchecked(ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), + }) + }, + }; Arc::new( self.device @@ -283,13 +298,14 @@ where })) } - fn execute( + unsafe fn execute( &mut self, kernel: Self::Kernel, count: Self::DispatchOptions, bindings: Vec>, + mode: ExecutionMode, ) { - let pipeline = self.pipeline(kernel); + let pipeline = self.pipeline(kernel, mode); let group_layout = pipeline.get_bind_group_layout(0); let memory_handles = bindings diff --git a/crates/cubecl/benches/unary.rs b/crates/cubecl/benches/unary.rs index 17ed0148e..99ab027d2 100644 --- a/crates/cubecl/benches/unary.rs +++ b/crates/cubecl/benches/unary.rs @@ -44,9 +44,9 @@ impl Benchmark for UnaryBench { &self.client, cube_count, cube_dim, - TensorArg::vectorized(self.vectorization, &lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::vectorized(self.vectorization, &rhs.handle, &rhs.strides, &rhs.shape), - TensorArg::vectorized(self.vectorization, &out.handle, &out.strides, &out.shape), + lhs.as_arg(self.vectorization), + rhs.as_arg(self.vectorization), + out.as_arg(self.vectorization), ) } diff --git a/crates/cubecl/src/runtime.rs b/crates/cubecl/src/runtime.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/gelu/src/lib.rs b/examples/gelu/src/lib.rs index 79c834db4..80d76c949 100644 --- a/examples/gelu/src/lib.rs +++ b/examples/gelu/src/lib.rs @@ -1,6 +1,6 @@ use cubecl::prelude::*; -#[cube(launch)] +#[cube(launch_unchecked)] fn gelu_array(input: &Array, output: &mut Array) { if ABSOLUTE_POS < input.len() { output[ABSOLUTE_POS] = gelu_scalar::(input[ABSOLUTE_POS]); @@ -18,13 +18,15 @@ pub fn launch(device: &R::Device) { let output_handle = client.empty(input.len() * core::mem::size_of::()); let input_handle = client.create(f32::as_bytes(input)); - gelu_array::launch::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(input.len() as u32, 1, 1), - ArrayArg::new(&input_handle, input.len()), - ArrayArg::new(&output_handle, input.len()), - ); + unsafe { + gelu_array::launch_unchecked::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(input.len() as u32, 1, 1), + ArrayArg::from_raw_parts(&input_handle, input.len(), 1), + ArrayArg::from_raw_parts(&output_handle, input.len(), 1), + ) + }; let bytes = client.read(output_handle.binding()); let output = f32::from_bytes(&bytes);
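Since `parse_attributes` now accepts `launch` and `launch_unchecked` independently, a single kernel can expose both generated entry points at once. A minimal sketch under that assumption (the `double` kernel below is made up and not part of this patch):

```rust
use cubecl::prelude::*;

// Hypothetical kernel exposing both the safe `launch` and the
// `unsafe` `launch_unchecked` generated entry points.
#[cube(launch, launch_unchecked)]
fn double<F: Float>(io: &mut Array<F>) {
    if ABSOLUTE_POS < io.len() {
        io[ABSOLUTE_POS] = io[ABSOLUTE_POS] + io[ABSOLUTE_POS];
    }
}
```

Calling `double::launch` keeps the current behaviour and compiles the kernel with `ExecutionMode::Checked`, while `double::launch_unchecked` must be wrapped in an `unsafe` block and compiles the kernel with `ExecutionMode::Unchecked`, as in the gelu example above.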