From 4011cd99d63fe0c14d75fcc2e6ac903086a9bcfd Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 1 Aug 2024 09:42:31 -0400 Subject: [PATCH 1/9] launch-unchecked --- crates/cubecl-core/src/codegen/compiler.rs | 4 +- crates/cubecl-core/src/compute/kernel.rs | 19 ++++---- crates/cubecl-core/src/compute/launcher.rs | 14 ++++++ crates/cubecl-core/src/id.rs | 8 ++++ crates/cubecl-cuda/src/compiler/base.rs | 10 ++++- crates/cubecl-cuda/src/compute/server.rs | 12 +++-- crates/cubecl-linalg/src/matmul/cmma/base.rs | 2 +- .../cubecl-linalg/src/matmul/cmma/launch.rs | 20 +++++---- .../cubecl-linalg/src/matmul/tiling2d/base.rs | 2 +- .../src/matmul/tiling2d/launch.rs | 20 +++++---- .../src/codegen_function/launch.rs | 35 ++++++++++++--- crates/cubecl-macros/src/lib.rs | 33 +++++++++++--- crates/cubecl-runtime/src/channel/base.rs | 11 +++++ crates/cubecl-runtime/src/channel/cell.rs | 5 ++- crates/cubecl-runtime/src/channel/mpsc.rs | 13 ++++-- crates/cubecl-runtime/src/channel/mutex.rs | 5 ++- crates/cubecl-runtime/src/client.rs | 16 ++++++- crates/cubecl-runtime/src/server.rs | 2 + crates/cubecl-runtime/tests/dummy/server.rs | 2 + .../cubecl-wgpu/src/compiler/wgsl/compiler.rs | 6 ++- crates/cubecl-wgpu/src/compute/server.rs | 45 ++++++++++++++----- 21 files changed, 216 insertions(+), 68 deletions(-) diff --git a/crates/cubecl-core/src/codegen/compiler.rs b/crates/cubecl-core/src/codegen/compiler.rs index 4f1e45105..e7728f7cb 100644 --- a/crates/cubecl-core/src/codegen/compiler.rs +++ b/crates/cubecl-core/src/codegen/compiler.rs @@ -1,3 +1,5 @@ +use cubecl_runtime::channel::KernelExecutionStrategy; + use crate::ir::{Elem, KernelDefinition}; use std::fmt::Display; @@ -13,7 +15,7 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug { type Representation: CompilerRepresentation; /// Compiles the [kernel definition](KernelDefinition) into the compiler's representation. - fn compile(kernel: KernelDefinition) -> Self::Representation; + fn compile(kernel: KernelDefinition, kind: KernelExecutionStrategy) -> Self::Representation; /// The size of the given element in bytes. fn elem_size(elem: Elem) -> usize; /// The maximal size of a shared memory diff --git a/crates/cubecl-core/src/compute/kernel.rs b/crates/cubecl-core/src/compute/kernel.rs index 2658d25d9..be07a5a22 100644 --- a/crates/cubecl-core/src/compute/kernel.rs +++ b/crates/cubecl-core/src/compute/kernel.rs @@ -5,7 +5,10 @@ use std::{ use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel, KernelId}; use alloc::sync::Arc; -use cubecl_runtime::server::{Binding, ComputeServer}; +use cubecl_runtime::{ + channel::KernelExecutionStrategy, + server::{Binding, ComputeServer}, +}; /// A kernel, compiled in the target language pub struct CompiledKernel { @@ -157,7 +160,7 @@ pub trait CubeTask: Send + Sync { /// Identifier for the kernel, used for caching kernel compilation. fn id(&self) -> KernelId; /// Compile the kernel into source - fn compile(&self) -> CompiledKernel; + fn compile(&self, kind: KernelExecutionStrategy) -> CompiledKernel; } /// Wraps a [kernel](Kernel) to create a [cube task](CubeTask). 
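Note on the trait changes above: `Compiler::compile` and `CubeTask::compile` now thread the bound-check strategy into every backend. A minimal sketch of how a backend is expected to consume it, loosely mirroring the CUDA changes later in this series; `MyBackendCompiler` and `emit` are hypothetical placeholders, not items from this patch:

use cubecl_core::ir::KernelDefinition;
use cubecl_runtime::channel::KernelExecutionStrategy;

#[derive(Default)]
struct MyBackendCompiler {
    strategy: KernelExecutionStrategy,
}

impl MyBackendCompiler {
    // Mirrors the updated `Compiler::compile` signature: remember the strategy so
    // index / index-assign codegen can pick between guarded and raw accesses.
    fn compile(kernel: KernelDefinition, strategy: KernelExecutionStrategy) -> String {
        let mut compiler = Self::default();
        compiler.strategy = strategy;
        compiler.emit(kernel)
    }

    fn emit(&mut self, _kernel: KernelDefinition) -> String {
        match self.strategy {
            // Checked (the default): reads and writes are wrapped in a bounds test.
            KernelExecutionStrategy::Checked => String::from("/* guarded indexing */"),
            // Unchecked: raw indexing; safety rests on the caller of launch_unchecked.
            KernelExecutionStrategy::Unchecked => String::from("/* raw indexing */"),
        }
    }
}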
@@ -168,10 +171,10 @@ pub struct KernelTask { } impl CubeTask for KernelTask { - fn compile(&self) -> CompiledKernel { + fn compile(&self, strategy: KernelExecutionStrategy) -> CompiledKernel { let gpu_ir = self.kernel_definition.define(); let cube_dim = gpu_ir.cube_dim; - let lower_level_ir = C::compile(gpu_ir); + let lower_level_ir = C::compile(gpu_ir, strategy); let shared_mem_bytes = lower_level_ir.shared_memory_size(); let source = lower_level_ir.to_string(); @@ -190,8 +193,8 @@ impl CubeTask for KernelTask { } impl CubeTask for Arc { - fn compile(&self) -> CompiledKernel { - self.as_ref().compile() + fn compile(&self, kind: KernelExecutionStrategy) -> CompiledKernel { + self.as_ref().compile(kind) } fn id(&self) -> KernelId { @@ -200,8 +203,8 @@ impl CubeTask for Arc { } impl CubeTask for Box { - fn compile(&self) -> CompiledKernel { - self.as_ref().compile() + fn compile(&self, kind: KernelExecutionStrategy) -> CompiledKernel { + self.as_ref().compile(kind) } fn id(&self) -> KernelId { diff --git a/crates/cubecl-core/src/compute/launcher.rs b/crates/cubecl-core/src/compute/launcher.rs index 54e733b64..12f76d7a0 100644 --- a/crates/cubecl-core/src/compute/launcher.rs +++ b/crates/cubecl-core/src/compute/launcher.rs @@ -89,6 +89,20 @@ impl KernelLauncher { client.execute(kernel, cube_count, bindings); } + /// Launch the kernel without check bounds. + pub unsafe fn launch_unchecked( + self, + cube_count: CubeCount, + kernel: K, + client: &ComputeClient, + ) { + let bindings = self.into_bindings(client); + + let kernel = Box::new(KernelTask::::new(kernel)); + + client.execute_unchecked(kernel, cube_count, bindings); + } + /// We need to create the bindings in the same order they are defined in the compilation step. /// /// The function [crate::KernelIntegrator::integrate] stars by registering the input tensors followed diff --git a/crates/cubecl-core/src/id.rs b/crates/cubecl-core/src/id.rs index bfe9132f8..a96a4b10d 100644 --- a/crates/cubecl-core/src/id.rs +++ b/crates/cubecl-core/src/id.rs @@ -1,3 +1,4 @@ +use cubecl_runtime::channel::KernelExecutionStrategy; use std::any::{Any, TypeId}; use std::fmt::Display; use std::hash::{DefaultHasher, Hash, Hasher}; @@ -8,6 +9,7 @@ use std::sync::Arc; pub struct KernelId { type_id: core::any::TypeId, info: Option, + kind: Option, } impl Display for KernelId { @@ -25,6 +27,7 @@ impl KernelId { Self { type_id: core::any::TypeId::of::(), info: None, + kind: None, } } @@ -39,6 +42,11 @@ impl KernelId { self.info = Some(Info::new(info)); self } + + /// Set the kind of checking strategy. 
+ pub fn kind(&mut self, kind: KernelExecutionStrategy) { + self.kind = Some(kind); + } } /// Extra information diff --git a/crates/cubecl-cuda/src/compiler/base.rs b/crates/cubecl-cuda/src/compiler/base.rs index 5966a1ba7..39f5b1fd1 100644 --- a/crates/cubecl-cuda/src/compiler/base.rs +++ b/crates/cubecl-cuda/src/compiler/base.rs @@ -1,6 +1,7 @@ use std::collections::HashSet; use cubecl_core::{ + channel::KernelExecutionStrategy, ir::{self as gpu, ConstantScalarValue}, Compiler, }; @@ -25,13 +26,18 @@ pub struct CudaCompiler { num_inputs: usize, num_outputs: usize, items: HashSet, + strategy: KernelExecutionStrategy, } impl Compiler for CudaCompiler { type Representation = super::ComputeKernel; - fn compile(kernel: cubecl_core::ir::KernelDefinition) -> Self::Representation { - let compiler = Self::default(); + fn compile( + kernel: cubecl_core::ir::KernelDefinition, + strategy: KernelExecutionStrategy, + ) -> Self::Representation { + let mut compiler = Self::default(); + compiler.strategy = strategy; compiler.compile_shader(kernel) } diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs index 47f2621bf..0ab57a568 100644 --- a/crates/cubecl-cuda/src/compute/server.rs +++ b/crates/cubecl-cuda/src/compute/server.rs @@ -4,6 +4,7 @@ use super::storage::CudaStorage; use super::CudaResource; use cubecl_common::reader::{reader_from_concrete, Reader}; use cubecl_common::sync_type::SyncType; +use cubecl_core::channel::KernelExecutionStrategy; use cubecl_core::compute::DebugInformation; use cubecl_core::ir::CubeDim; use cubecl_core::FeatureSet; @@ -116,10 +117,12 @@ impl> ComputeServer for CudaServer { kernel: Self::Kernel, count: Self::DispatchOptions, bindings: Vec>, + strategy: KernelExecutionStrategy, ) { let arch = self.minimum_arch_version; - let kernel_id = kernel.id(); + let mut kernel_id = kernel.id(); + kernel_id.kind(strategy); let count = match count { CubeCount::Static(x, y, z) => (x, y, z), @@ -140,7 +143,7 @@ impl> ComputeServer for CudaServer { let (ctx, logger) = self.get_context_with_logger(); if !ctx.module_names.contains_key(&kernel_id) { - ctx.compile_kernel(&kernel_id, kernel, arch, logger); + ctx.compile_kernel(&kernel_id, kernel, arch, logger, strategy); } let resources = bindings @@ -198,8 +201,9 @@ impl> CudaContext { kernel: Box, arch: i32, logger: &mut DebugLogger, + strategy: KernelExecutionStrategy, ) { - let mut kernel_compiled = kernel.compile(); + let mut kernel_compiled = kernel.compile(strategy); if logger.is_activated() { kernel_compiled.debug_info = Some(DebugInformation::new("cpp", kernel_id.clone())); @@ -231,7 +235,7 @@ impl> CudaContext { message += format!("\n {line}").as_str(); } } - let source = kernel.compile().source; + let source = kernel.compile(strategy).source; panic!("{message}\n[Source] \n{source}"); }; cudarc::nvrtc::result::get_ptx(program).unwrap() diff --git a/crates/cubecl-linalg/src/matmul/cmma/base.rs b/crates/cubecl-linalg/src/matmul/cmma/base.rs index b876d197d..f534d820d 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/base.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/base.rs @@ -4,7 +4,7 @@ use cubecl_core::prelude::*; use super::block_loop::block_loop; use super::config::CmmaConfig; -#[cube(launch)] +#[cube(launch_unchecked)] #[allow(unused_mut)] pub fn cmma_kernel( lhs: &Tensor, diff --git a/crates/cubecl-linalg/src/matmul/cmma/launch.rs b/crates/cubecl-linalg/src/matmul/cmma/launch.rs index bbbe24852..e76fcaf40 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/launch.rs +++ 
b/crates/cubecl-linalg/src/matmul/cmma/launch.rs @@ -157,13 +157,15 @@ fn matmul_cmma_ref_no_check( let cube_dim = cmma_cube_dim(); let launch_config = CmmaLaunchConfig::default(); - cmma_kernel::launch::( - client, - cube_count, - cube_dim, - TensorArg::vectorized(lhs_vectorization, lhs.handle, lhs.strides, lhs.shape), - TensorArg::vectorized(rhs_vectorization, rhs.handle, rhs.strides, rhs.shape), - TensorArg::vectorized(out_vectorization, out.handle, out.strides, out.shape), - CmmaConfig::new(m, k, n, launch_config), - ); + unsafe { + cmma_kernel::launch_unchecked::( + client, + cube_count, + cube_dim, + TensorArg::vectorized(lhs_vectorization, lhs.handle, lhs.strides, lhs.shape), + TensorArg::vectorized(rhs_vectorization, rhs.handle, rhs.strides, rhs.shape), + TensorArg::vectorized(out_vectorization, out.handle, out.strides, out.shape), + CmmaConfig::new(m, k, n, launch_config), + ); + } } diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/base.rs b/crates/cubecl-linalg/src/matmul/tiling2d/base.rs index 728ca6cf9..4418a3d56 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/base.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/base.rs @@ -6,7 +6,7 @@ use super::{block_loop::block_loop, config::CubeTiling2dConfig}; /// Most common tile size, the one used in most tests. pub(crate) const TILE_SIZE: usize = 4; -#[cube(launch)] +#[cube(launch_unchecked)] #[allow(unused_mut)] pub fn tiling2d_cube_kernel( lhs: &Tensor, diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs index 5499d3e81..3f966d249 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs @@ -125,13 +125,15 @@ fn matmul_tiling_2d_ref_no_check( let cube_dim = tiling2d_cube_dim(&config); let cube_config = CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed); - tiling2d_cube_kernel::launch::( - client, - cube_count, - cube_dim, - TensorArg::vectorized(lhs_vectorization, lhs.handle, lhs.strides, lhs.shape), - TensorArg::vectorized(rhs_vectorization, rhs.handle, rhs.strides, rhs.shape), - TensorArg::vectorized(out_vectorization, out.handle, out.strides, out.shape), - cube_config, - ); + unsafe { + tiling2d_cube_kernel::launch_unchecked::( + client, + cube_count, + cube_dim, + TensorArg::vectorized(lhs_vectorization, lhs.handle, lhs.strides, lhs.shape), + TensorArg::vectorized(rhs_vectorization, rhs.handle, rhs.strides, rhs.shape), + TensorArg::vectorized(out_vectorization, out.handle, out.strides, out.shape), + cube_config, + ); + } } diff --git a/crates/cubecl-macros/src/codegen_function/launch.rs b/crates/cubecl-macros/src/codegen_function/launch.rs index 49ed4dcba..c4ec8647f 100644 --- a/crates/cubecl-macros/src/codegen_function/launch.rs +++ b/crates/cubecl-macros/src/codegen_function/launch.rs @@ -13,13 +13,15 @@ struct Codegen { state_args: Vec, state_inputs: Vec<(Ident, syn::Type)>, state_outputs: Vec<(Ident, syn::Type)>, + unchecked: bool, } impl Codegen { - fn from_sig(sig: &syn::Signature) -> Self { + fn from_sig(sig: &syn::Signature, unchecked: bool) -> Self { let mut codegen = Codegen { name: snake_to_pascal_case(&sig.ident.to_string()), generics: sig.generics.clone(), + unchecked, ..Codegen::default() }; @@ -425,20 +427,30 @@ impl Codegen { } }; - quote::quote! { + let mut tokens = quote::quote! { #settings let kernel = #kernel; #body + }; - launcher.launch(cube_count, kernel, client); + if self.unchecked { + tokens.extend(quote::quote! 
{ + launcher.launch_unchecked(cube_count, kernel, client); + }); + } else { + tokens.extend(quote::quote! { + launcher.launch(cube_count, kernel, client); + }); } + + tokens } } -pub fn codegen_launch(sig: &syn::Signature) -> TokenStream { - let codegen = Codegen::from_sig(sig); +pub fn codegen_launch(sig: &syn::Signature, unchecked: bool) -> TokenStream { + let codegen = Codegen::from_sig(sig, unchecked); let ident = &sig.ident; @@ -453,13 +465,24 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream { let (inputs, output) = (codegen.fn_inputs, codegen.fn_output); let doc = format!("Launch the kernel [{ident}()] on the given runtime."); + let maybe_unsafe = if unchecked { + quote::quote! {unsafe} + } else { + quote::quote! {} + }; + let launch_name = if unchecked { + quote::quote! { launch_unchecked} + } else { + quote::quote! { launch} + }; + quote::quote! { #kernel #compile #[allow(clippy::too_many_arguments)] #[doc = #doc] - pub fn launch #generics ( + pub #maybe_unsafe fn #launch_name #generics ( client: &ComputeClient, cube_count: CubeCount, cube_dim: CubeDim, diff --git a/crates/cubecl-macros/src/lib.rs b/crates/cubecl-macros/src/lib.rs index 82630ace6..f30e4a95d 100644 --- a/crates/cubecl-macros/src/lib.rs +++ b/crates/cubecl-macros/src/lib.rs @@ -45,6 +45,7 @@ pub fn module_derive_cube_type(input: TokenStream) -> TokenStream { struct SupportedAttributes { mode: CubeMode, launch: bool, + launch_unchecked: bool, } /// Derive macro for the module. @@ -69,7 +70,12 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream { fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream { let mut variable_tracker = VariableAnalyzer::create_tracker(&func); - match codegen_cube(&func, &mut variable_tracker, attrs.launch) { + match codegen_cube( + &func, + &mut variable_tracker, + attrs.launch, + attrs.launch_unchecked, + ) { Ok(code) => code.into(), Err(err) => err.into(), } @@ -78,6 +84,7 @@ fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream { fn parse_attributes(args: &Punctuated) -> SupportedAttributes { let mut mode = CubeMode::Default; let mut launch = false; + let mut launch_unchecked = false; for arg in args.iter() { match arg { @@ -90,7 +97,12 @@ fn parse_attributes(args: &Punctuated) -> SupportedAttributes { "launch" => { launch = true; } - _ => panic!("Attribute {ident} is not supported"), + "launch_unchecked" => { + launch_unchecked = true; + } + _ => { + panic!("Attribute {ident} is not supported") + } } } else { panic!("Only ident attribute supported"); @@ -101,7 +113,11 @@ fn parse_attributes(args: &Punctuated) -> SupportedAttributes { } } - SupportedAttributes { mode, launch } + SupportedAttributes { + mode, + launch, + launch_unchecked, + } } /// Generate the expanded version of a function marked with the cube macro @@ -109,6 +125,7 @@ fn codegen_cube( func: &syn::ItemFn, variable_tracker: &mut VariableTracker, launch: bool, + launch_unchecked: bool, ) -> Result { let signature = expand_sig( &func.sig, @@ -149,12 +166,18 @@ fn codegen_cube( "function " }; - let launch = if launch { - codegen_launch(&func.sig) + let mut launch = if launch { + codegen_launch(&func.sig, false) } else { quote::quote! {} }; + launch.extend(if launch_unchecked { + codegen_launch(&func.sig, true) + } else { + quote::quote! 
{} + }); + let mod_name = &func.sig.ident; let vis = &func.vis; let doc = format!("Module containing the expand {launch_doc}of {mod_name}."); diff --git a/crates/cubecl-runtime/src/channel/base.rs b/crates/cubecl-runtime/src/channel/base.rs index fcb1cdcb4..b7510784e 100644 --- a/crates/cubecl-runtime/src/channel/base.rs +++ b/crates/cubecl-runtime/src/channel/base.rs @@ -5,6 +5,16 @@ use crate::{ use alloc::vec::Vec; use cubecl_common::{reader::Reader, sync_type::SyncType}; +/// The kind of execution to be performed. +#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)] +pub enum KernelExecutionStrategy { + /// Checked kernels are safe. + #[default] + Checked, + /// Unchecked kernels are unsafe. + Unchecked, +} + /// The ComputeChannel trait links the ComputeClient to the ComputeServer /// while ensuring thread-safety pub trait ComputeChannel: Clone + core::fmt::Debug + Send + Sync { @@ -29,6 +39,7 @@ pub trait ComputeChannel: Clone + core::fmt::Debug + Send kernel: Server::Kernel, count: Server::DispatchOptions, bindings: Vec>, + strategy: KernelExecutionStrategy, ); /// Perform some synchronization of commands on the server. diff --git a/crates/cubecl-runtime/src/channel/cell.rs b/crates/cubecl-runtime/src/channel/cell.rs index aa9dcc25c..897408c31 100644 --- a/crates/cubecl-runtime/src/channel/cell.rs +++ b/crates/cubecl-runtime/src/channel/cell.rs @@ -1,4 +1,4 @@ -use super::ComputeChannel; +use super::{ComputeChannel, KernelExecutionStrategy}; use crate::server::{Binding, ComputeServer, Handle}; use crate::storage::ComputeStorage; use alloc::sync::Arc; @@ -68,10 +68,11 @@ where kernel_description: Server::Kernel, count: Server::DispatchOptions, bindings: Vec>, + kind: KernelExecutionStrategy, ) { self.server .borrow_mut() - .execute(kernel_description, count, bindings) + .execute(kernel_description, count, bindings, kind) } fn sync(&self, sync_type: SyncType) { diff --git a/crates/cubecl-runtime/src/channel/mpsc.rs b/crates/cubecl-runtime/src/channel/mpsc.rs index 6d55a0302..a131e1d9e 100644 --- a/crates/cubecl-runtime/src/channel/mpsc.rs +++ b/crates/cubecl-runtime/src/channel/mpsc.rs @@ -1,7 +1,7 @@ use cubecl_common::{reader::Reader, sync_type::SyncType}; use std::{sync::Arc, thread}; -use super::ComputeChannel; +use super::{ComputeChannel, KernelExecutionStrategy}; use crate::{ server::{Binding, ComputeServer, Handle}, storage::ComputeStorage, @@ -40,7 +40,11 @@ where Create(Vec, Callback>), Empty(usize, Callback>), ExecuteKernel( - (Server::Kernel, Server::DispatchOptions), + ( + Server::Kernel, + Server::DispatchOptions, + KernelExecutionStrategy, + ), Vec>, ), Sync(SyncType, Callback<()>), @@ -77,7 +81,7 @@ where callback.send(handle).await.unwrap(); } Message::ExecuteKernel(kernel, bindings) => { - server.execute(kernel.0, kernel.1, bindings); + server.execute(kernel.0, kernel.1, bindings, kernel.2); } Message::Sync(sync_type, callback) => { server.sync(sync_type); @@ -156,10 +160,11 @@ where kernel: Server::Kernel, count: Server::DispatchOptions, bindings: Vec>, + kind: KernelExecutionStrategy, ) { self.state .sender - .send_blocking(Message::ExecuteKernel((kernel, count), bindings)) + .send_blocking(Message::ExecuteKernel((kernel, count, kind), bindings)) .unwrap() } diff --git a/crates/cubecl-runtime/src/channel/mutex.rs b/crates/cubecl-runtime/src/channel/mutex.rs index 13f2e12b3..e63fd9a53 100644 --- a/crates/cubecl-runtime/src/channel/mutex.rs +++ b/crates/cubecl-runtime/src/channel/mutex.rs @@ -1,4 +1,4 @@ -use super::ComputeChannel; +use super::{ComputeChannel, 
KernelExecutionStrategy}; use crate::server::{Binding, ComputeServer, Handle}; use crate::storage::ComputeStorage; use alloc::sync::Arc; @@ -61,8 +61,9 @@ where kernel: Server::Kernel, count: Server::DispatchOptions, handles: Vec>, + kind: KernelExecutionStrategy, ) { - self.server.lock().execute(kernel, count, handles) + self.server.lock().execute(kernel, count, handles, kind) } fn sync(&self, sync_type: SyncType) { diff --git a/crates/cubecl-runtime/src/client.rs b/crates/cubecl-runtime/src/client.rs index 4e2b6a396..db3693564 100644 --- a/crates/cubecl-runtime/src/client.rs +++ b/crates/cubecl-runtime/src/client.rs @@ -1,5 +1,5 @@ use crate::{ - channel::ComputeChannel, + channel::{ComputeChannel, KernelExecutionStrategy}, server::{Binding, ComputeServer, Handle}, storage::ComputeStorage, }; @@ -77,7 +77,19 @@ where count: Server::DispatchOptions, bindings: Vec>, ) { - self.channel.execute(kernel, count, bindings) + self.channel + .execute(kernel, count, bindings, KernelExecutionStrategy::Checked) + } + + /// Executes the `kernel` over the given `bindings` without performing any bound checks. + pub unsafe fn execute_unchecked( + &self, + kernel: Server::Kernel, + count: Server::DispatchOptions, + bindings: Vec>, + ) { + self.channel + .execute(kernel, count, bindings, KernelExecutionStrategy::Unchecked) } /// Wait for the completion of every task in the server. diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs index 212461a11..5597c48ef 100644 --- a/crates/cubecl-runtime/src/server.rs +++ b/crates/cubecl-runtime/src/server.rs @@ -1,4 +1,5 @@ use crate::{ + channel::KernelExecutionStrategy, memory_management::{MemoryHandle, MemoryManagement}, storage::ComputeStorage, }; @@ -49,6 +50,7 @@ where kernel: Self::Kernel, count: Self::DispatchOptions, bindings: Vec>, + kind: KernelExecutionStrategy, ); /// Wait for the completion of every task in the server. diff --git a/crates/cubecl-runtime/tests/dummy/server.rs b/crates/cubecl-runtime/tests/dummy/server.rs index 079ccb72b..5daf5280e 100644 --- a/crates/cubecl-runtime/tests/dummy/server.rs +++ b/crates/cubecl-runtime/tests/dummy/server.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use cubecl_common::{reader::reader_from_concrete, sync_type::SyncType}; use cubecl_runtime::{ + channel::KernelExecutionStrategy, memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement}, server::{Binding, ComputeServer, Handle}, storage::{BytesResource, BytesStorage}, @@ -58,6 +59,7 @@ where kernel: Self::Kernel, _count: Self::DispatchOptions, bindings: Vec>, + _strategy: KernelExecutionStrategy, ) { let mut resources = bindings .into_iter() diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index fc2f10cc5..c50e1964f 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -1,6 +1,7 @@ use super::{shader::ComputeShader, Item, SharedMemory}; use super::{LocalArray, Subgroup}; use crate::compiler::wgsl; +use cubecl_core::channel::KernelExecutionStrategy; use cubecl_core::ir as cube; /// Wgsl Compiler. 
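For the client-side split added in `client.rs` above, a minimal usage sketch; the `dispatch` helper and the `bounds_proven` flag are hypothetical, and the module paths assume the usual `cubecl_runtime` layout rather than anything introduced by this patch:

use cubecl_runtime::{
    channel::ComputeChannel,
    client::ComputeClient,
    server::{Binding, ComputeServer},
};

// Hypothetical helper: opt out of bound checks only when the caller can prove
// every access stays inside the bound buffers.
fn dispatch<S: ComputeServer, C: ComputeChannel<S>>(
    client: &ComputeClient<S, C>,
    kernel: S::Kernel,
    count: S::DispatchOptions,
    bindings: Vec<Binding<S>>,
    bounds_proven: bool,
) {
    if bounds_proven {
        // SAFETY: all indexing is in range, so the checks removed by
        // KernelExecutionStrategy::Unchecked are not needed for soundness.
        unsafe { client.execute_unchecked(kernel, count, bindings) };
    } else {
        // Default path: the kernel is compiled with KernelExecutionStrategy::Checked.
        client.execute(kernel, count, bindings);
    }
}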
@@ -33,7 +34,10 @@ impl core::fmt::Debug for WgslCompiler { impl cubecl_core::Compiler for WgslCompiler { type Representation = ComputeShader; - fn compile(shader: cube::KernelDefinition) -> Self::Representation { + fn compile( + shader: cube::KernelDefinition, + _strategy: KernelExecutionStrategy, + ) -> Self::Representation { let mut compiler = Self::default(); compiler.compile_shader(shader) } diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index f2dc72577..e4a639d22 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -3,7 +3,9 @@ use std::num::NonZeroU64; use super::WgpuStorage; use alloc::{borrow::Cow, sync::Arc}; use cubecl_common::{reader::Reader, sync_type::SyncType}; -use cubecl_core::{compute::DebugInformation, prelude::*, FeatureSet, KernelId}; +use cubecl_core::{ + channel::KernelExecutionStrategy, compute::DebugInformation, prelude::*, FeatureSet, KernelId, +}; use cubecl_runtime::{ debug::DebugLogger, memory_management::MemoryManagement, @@ -98,31 +100,51 @@ where self.tasks_count += 1; } - fn pipeline(&mut self, kernel: ::Kernel) -> Arc { - let kernel_id = kernel.id(); + fn pipeline( + &mut self, + kernel: ::Kernel, + strategy: KernelExecutionStrategy, + ) -> Arc { + let mut kernel_id = kernel.id(); + kernel_id.kind(strategy); if let Some(pipeline) = self.pipelines.get(&kernel_id) { return pipeline.clone(); } - let mut compile = kernel.compile(); + let mut compile = kernel.compile(strategy); if self.logger.is_activated() { compile.debug_info = Some(DebugInformation::new("wgsl", kernel_id.clone())); } let compile = self.logger.debug(compile); - let pipeline = self.compile_source(&compile.source); + let pipeline = self.compile_source(&compile.source, strategy); self.pipelines.insert(kernel_id.clone(), pipeline.clone()); pipeline } - fn compile_source(&self, source: &str) -> Arc { - let module = self.device.create_shader_module(ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }); + fn compile_source( + &self, + source: &str, + strategy: KernelExecutionStrategy, + ) -> Arc { + let module = match strategy { + KernelExecutionStrategy::Checked => { + self.device.create_shader_module(ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), + }) + } + KernelExecutionStrategy::Unchecked => unsafe { + self.device + .create_shader_module_unchecked(ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), + }) + }, + }; Arc::new( self.device @@ -288,8 +310,9 @@ where kernel: Self::Kernel, count: Self::DispatchOptions, bindings: Vec>, + kind: KernelExecutionStrategy, ) { - let pipeline = self.pipeline(kernel); + let pipeline = self.pipeline(kernel, kind); let group_layout = pipeline.get_bind_group_layout(0); let memory_handles = bindings From d901acdef8073785380e4ea8717ceac8b71975bb Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 1 Aug 2024 17:29:54 -0400 Subject: [PATCH 2/9] Some fixes --- crates/cubecl-core/src/compute/launcher.rs | 43 +- crates/cubecl-core/src/ir/processing.rs | 12 +- crates/cubecl-cuda/src/compiler/base.rs | 194 +++++-- .../src/matmul/tests/cmma/compute_loop.rs | 80 +-- .../matmul/tests/cmma/load_shared_memory.rs | 472 +++++++++--------- .../src/matmul/tests/cmma/write_output.rs | 156 +++--- .../src/matmul/tests/tiling2d/compute_loop.rs | 100 ++-- .../tests/tiling2d/load_shared_memory.rs | 282 ++++++----- 
.../src/matmul/tests/tiling2d/write_output.rs | 94 ++-- 9 files changed, 800 insertions(+), 633 deletions(-) diff --git a/crates/cubecl-core/src/compute/launcher.rs b/crates/cubecl-core/src/compute/launcher.rs index 12f76d7a0..66ed56b71 100644 --- a/crates/cubecl-core/src/compute/launcher.rs +++ b/crates/cubecl-core/src/compute/launcher.rs @@ -188,21 +188,16 @@ impl TensorState { bindings.push(tensor.handle.clone().binding()); - let old_rank = if metadata.is_empty() { + if metadata.is_empty() { let rank = tensor.strides.len() as u32; metadata.push(rank); - None } else if tensor.strides.len() > metadata[0] as usize { - let old_rank = metadata[0]; let rank = tensor.strides.len() as u32; - Self::adjust_rank(metadata, bindings.len(), rank); - Some(old_rank) - } else { - None - }; + Self::adjust_rank(metadata, bindings.len() - 1, rank); + } - Self::register_strides(tensor.strides, tensor.shape, old_rank, metadata); - Self::register_shape(tensor.shape, old_rank, metadata); + Self::register_strides(tensor.strides, tensor.shape, None, metadata); + Self::register_shape(tensor.shape, None, metadata); if R::require_array_lengths() { let len = calculate_num_elems_dyn_rank(tensor.shape); @@ -214,6 +209,7 @@ impl TensorState { let old_rank = metadata[0] as usize; let rank_diff = rank as usize - old_rank; let mut updated_metadata = Vec::with_capacity(2 * rank_diff * num_registered); + updated_metadata.push(rank); for pos in 0..num_registered { let stride_index = (pos * old_rank * 2) + 1; @@ -242,19 +238,14 @@ impl TensorState { ) { let old_rank = if let Some(old_rank) = old_rank { let rank = output[0]; - let rank_diff = old_rank - rank; - let padded_strides = if rank_diff > 0 { - shape - .iter() - .take(old_rank as usize) - .map(|a| a.to_u32().unwrap()) - .sum::() - } else { - 0 - }; + let rank_diff = i32::abs(old_rank as i32 - rank as i32) as usize; - for _ in 0..rank_diff { - output.push(padded_strides.to_u32().unwrap()); + if rank_diff > 0 { + let padded_strides = shape.iter().map(|a| a.to_u32().unwrap()).sum::(); + + for _ in 0..rank_diff { + output.push(padded_strides); + } } old_rank as usize @@ -270,10 +261,12 @@ impl TensorState { fn register_shape(shape: &[T], old_rank: Option, output: &mut Vec) { let old_rank = if let Some(old_rank) = old_rank { let rank = output[0]; - let rank_diff = rank - old_rank; + let rank_diff = i32::abs(old_rank as i32 - rank as i32) as usize; - for _ in 0..rank_diff { - output.push(1); + if rank_diff > 0 { + for _ in 0..rank_diff { + output.push(1); + } } old_rank as usize diff --git a/crates/cubecl-core/src/ir/processing.rs b/crates/cubecl-core/src/ir/processing.rs index 12351820b..6a4a8e165 100644 --- a/crates/cubecl-core/src/ir/processing.rs +++ b/crates/cubecl-core/src/ir/processing.rs @@ -124,15 +124,15 @@ impl ScopeProcessing { sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out); } - Operator::Index(op) => { - sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); - sanitize_constant_scalar_ref_elem(&mut op.rhs, Elem::UInt); - } Operator::Slice(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &op.out); sanitize_constant_scalar_ref_elem(&mut op.start, Elem::UInt); sanitize_constant_scalar_ref_elem(&mut op.end, Elem::UInt); } + Operator::Index(op) => { + sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); + sanitize_constant_scalar_ref_elem(&mut op.rhs, Elem::UInt); + } Operator::UncheckedIndex(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); 
sanitize_constant_scalar_ref_elem(&mut op.rhs, Elem::UInt); @@ -142,8 +142,8 @@ impl ScopeProcessing { sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out); } Operator::UncheckedIndexAssign(op) => { - sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); - sanitize_constant_scalar_ref_elem(&mut op.rhs, Elem::UInt); + sanitize_constant_scalar_ref_elem(&mut op.lhs, Elem::UInt); + sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out); } Operator::And(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs); diff --git a/crates/cubecl-cuda/src/compiler/base.rs b/crates/cubecl-cuda/src/compiler/base.rs index 39f5b1fd1..73e162e7c 100644 --- a/crates/cubecl-cuda/src/compiler/base.rs +++ b/crates/cubecl-cuda/src/compiler/base.rs @@ -99,9 +99,9 @@ impl CudaCompiler { } } - fn compile_scope(&mut self, value: &mut gpu::Scope) -> Vec { + fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec { let mut instructions = Vec::new(); - let processing = value.process(); + let processing = scope.process(); for var in processing.variables { if let gpu::Variable::Slice { .. } = var { @@ -115,7 +115,7 @@ impl CudaCompiler { processing .operations .into_iter() - .for_each(|op| self.compile_operation(&mut instructions, op, value)); + .for_each(|op| self.compile_operation(&mut instructions, op, scope)); instructions } @@ -127,7 +127,7 @@ impl CudaCompiler { scope: &mut gpu::Scope, ) { match operation { - gpu::Operation::Operator(op) => instructions.push(self.compile_instruction(op)), + gpu::Operation::Operator(op) => self.compile_instruction(op, instructions, scope), gpu::Operation::Procedure(proc) => self.compile_procedure(instructions, proc, scope), gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op)), gpu::Operation::Branch(val) => self.compile_branch(instructions, val), @@ -327,57 +327,140 @@ impl CudaCompiler { } } - fn compile_instruction(&mut self, value: gpu::Operator) -> Instruction { + fn compile_instruction( + &mut self, + value: gpu::Operator, + instructions: &mut Vec, + scope: &mut gpu::Scope, + ) { match value { - gpu::Operator::Add(op) => Instruction::Add(self.compile_binary(op)), - gpu::Operator::Mul(op) => Instruction::Mul(self.compile_binary(op)), - gpu::Operator::Div(op) => Instruction::Div(self.compile_binary(op)), - gpu::Operator::Sub(op) => Instruction::Sub(self.compile_binary(op)), - gpu::Operator::Assign(op) => Instruction::Assign(self.compile_unary(op)), - gpu::Operator::Slice(op) => Instruction::Slice { + gpu::Operator::Add(op) => instructions.push(Instruction::Add(self.compile_binary(op))), + gpu::Operator::Mul(op) => instructions.push(Instruction::Mul(self.compile_binary(op))), + gpu::Operator::Div(op) => instructions.push(Instruction::Div(self.compile_binary(op))), + gpu::Operator::Sub(op) => instructions.push(Instruction::Sub(self.compile_binary(op))), + gpu::Operator::Assign(op) => { + instructions.push(Instruction::Assign(self.compile_unary(op))) + } + gpu::Operator::Slice(op) => instructions.push(Instruction::Slice { input: self.compile_variable(op.input), start: self.compile_variable(op.start), end: self.compile_variable(op.end), out: self.compile_variable(op.out), - }, - gpu::Operator::Index(op) => Instruction::Index(self.compile_binary(op)), - gpu::Operator::UncheckedIndex(op) => Instruction::Index(self.compile_binary(op)), - gpu::Operator::IndexAssign(op) => Instruction::IndexAssign(self.compile_binary(op)), + }), + gpu::Operator::Index(op) => { + if let KernelExecutionStrategy::Checked = self.strategy { + let has_len = match op.lhs { + 
gpu::Variable::GlobalInputArray { .. } => true, + gpu::Variable::GlobalOutputArray { .. } => true, + gpu::Variable::Slice { .. } => true, + _ => false, + }; + if has_len { + self.compile_procedure( + instructions, + gpu::Procedure::CheckedIndex(gpu::CheckedIndex { + lhs: op.lhs, + rhs: op.rhs, + out: op.out, + }), + scope, + ); + + return; + } + }; + + instructions.push(Instruction::Index(self.compile_binary(op))); + } + gpu::Operator::UncheckedIndex(op) => { + instructions.push(Instruction::Index(self.compile_binary(op))) + } + gpu::Operator::IndexAssign(op) => { + if let KernelExecutionStrategy::Checked = self.strategy { + let has_len = match op.out { + gpu::Variable::GlobalInputArray { .. } => true, + gpu::Variable::GlobalOutputArray { .. } => true, + gpu::Variable::Slice { .. } => true, + _ => false, + }; + + if has_len { + self.compile_procedure( + instructions, + gpu::Procedure::CheckedIndexAssign(gpu::CheckedIndexAssign { + lhs: op.lhs, + rhs: op.rhs, + out: op.out, + }), + scope, + ); + return; + } + }; + + instructions.push(Instruction::IndexAssign(self.compile_binary(op))); + } gpu::Operator::UncheckedIndexAssign(op) => { - Instruction::IndexAssign(self.compile_binary(op)) - } - gpu::Operator::Modulo(op) => Instruction::Modulo(self.compile_binary(op)), - gpu::Operator::Equal(op) => Instruction::Equal(self.compile_binary(op)), - gpu::Operator::Lower(op) => Instruction::Lower(self.compile_binary(op)), - gpu::Operator::Greater(op) => Instruction::Greater(self.compile_binary(op)), - gpu::Operator::LowerEqual(op) => Instruction::LowerEqual(self.compile_binary(op)), - gpu::Operator::GreaterEqual(op) => Instruction::GreaterEqual(self.compile_binary(op)), - gpu::Operator::Abs(op) => Instruction::Abs(self.compile_unary(op)), - gpu::Operator::Exp(op) => Instruction::Exp(self.compile_unary(op)), - gpu::Operator::Log(op) => Instruction::Log(self.compile_unary(op)), - gpu::Operator::Log1p(op) => Instruction::Log1p(self.compile_unary(op)), - gpu::Operator::Cos(op) => Instruction::Cos(self.compile_unary(op)), - gpu::Operator::Sin(op) => Instruction::Sin(self.compile_unary(op)), - gpu::Operator::Tanh(op) => Instruction::Tanh(self.compile_unary(op)), - gpu::Operator::Powf(op) => Instruction::Powf(self.compile_binary(op)), - gpu::Operator::Sqrt(op) => Instruction::Sqrt(self.compile_unary(op)), - gpu::Operator::Erf(op) => Instruction::Erf(self.compile_unary(op)), - gpu::Operator::And(op) => Instruction::And(self.compile_binary(op)), - gpu::Operator::Or(op) => Instruction::Or(self.compile_binary(op)), - gpu::Operator::Not(op) => Instruction::Not(self.compile_unary(op)), - gpu::Operator::Max(op) => Instruction::Max(self.compile_binary(op)), - gpu::Operator::Min(op) => Instruction::Min(self.compile_binary(op)), - gpu::Operator::NotEqual(op) => Instruction::NotEqual(self.compile_binary(op)), - gpu::Operator::BitwiseAnd(op) => Instruction::BitwiseAnd(self.compile_binary(op)), - gpu::Operator::BitwiseXor(op) => Instruction::BitwiseXor(self.compile_binary(op)), - gpu::Operator::ShiftLeft(op) => Instruction::ShiftLeft(self.compile_binary(op)), - gpu::Operator::ShiftRight(op) => Instruction::ShiftRight(self.compile_binary(op)), - gpu::Operator::Clamp(op) => Instruction::Clamp { + instructions.push(Instruction::IndexAssign(self.compile_binary(op))) + } + gpu::Operator::Modulo(op) => { + instructions.push(Instruction::Modulo(self.compile_binary(op))) + } + gpu::Operator::Equal(op) => { + instructions.push(Instruction::Equal(self.compile_binary(op))) + } + gpu::Operator::Lower(op) => { + 
instructions.push(Instruction::Lower(self.compile_binary(op))) + } + gpu::Operator::Greater(op) => { + instructions.push(Instruction::Greater(self.compile_binary(op))) + } + gpu::Operator::LowerEqual(op) => { + instructions.push(Instruction::LowerEqual(self.compile_binary(op))) + } + gpu::Operator::GreaterEqual(op) => { + instructions.push(Instruction::GreaterEqual(self.compile_binary(op))) + } + gpu::Operator::Abs(op) => instructions.push(Instruction::Abs(self.compile_unary(op))), + gpu::Operator::Exp(op) => instructions.push(Instruction::Exp(self.compile_unary(op))), + gpu::Operator::Log(op) => instructions.push(Instruction::Log(self.compile_unary(op))), + gpu::Operator::Log1p(op) => { + instructions.push(Instruction::Log1p(self.compile_unary(op))) + } + gpu::Operator::Cos(op) => instructions.push(Instruction::Cos(self.compile_unary(op))), + gpu::Operator::Sin(op) => instructions.push(Instruction::Sin(self.compile_unary(op))), + gpu::Operator::Tanh(op) => instructions.push(Instruction::Tanh(self.compile_unary(op))), + gpu::Operator::Powf(op) => { + instructions.push(Instruction::Powf(self.compile_binary(op))) + } + gpu::Operator::Sqrt(op) => instructions.push(Instruction::Sqrt(self.compile_unary(op))), + gpu::Operator::Erf(op) => instructions.push(Instruction::Erf(self.compile_unary(op))), + gpu::Operator::And(op) => instructions.push(Instruction::And(self.compile_binary(op))), + gpu::Operator::Or(op) => instructions.push(Instruction::Or(self.compile_binary(op))), + gpu::Operator::Not(op) => instructions.push(Instruction::Not(self.compile_unary(op))), + gpu::Operator::Max(op) => instructions.push(Instruction::Max(self.compile_binary(op))), + gpu::Operator::Min(op) => instructions.push(Instruction::Min(self.compile_binary(op))), + gpu::Operator::NotEqual(op) => { + instructions.push(Instruction::NotEqual(self.compile_binary(op))) + } + gpu::Operator::BitwiseAnd(op) => { + instructions.push(Instruction::BitwiseAnd(self.compile_binary(op))) + } + gpu::Operator::BitwiseXor(op) => { + instructions.push(Instruction::BitwiseXor(self.compile_binary(op))) + } + gpu::Operator::ShiftLeft(op) => { + instructions.push(Instruction::ShiftLeft(self.compile_binary(op))) + } + gpu::Operator::ShiftRight(op) => { + instructions.push(Instruction::ShiftRight(self.compile_binary(op))) + } + gpu::Operator::Clamp(op) => instructions.push(Instruction::Clamp { input: self.compile_variable(op.input), min_value: self.compile_variable(op.min_value), max_value: self.compile_variable(op.max_value), out: self.compile_variable(op.out), - }, + }), gpu::Operator::Recip(op) => { let elem = op.input.item().elem(); let lhs = match elem { @@ -386,22 +469,27 @@ impl CudaCompiler { gpu::Elem::UInt => ConstantScalarValue::UInt(1), gpu::Elem::Bool => ConstantScalarValue::Bool(true), }; - Instruction::Div(super::BinaryInstruction { + + instructions.push(Instruction::Div(super::BinaryInstruction { lhs: super::Variable::ConstantScalar(lhs, self.compile_elem(elem)), rhs: self.compile_variable(op.input), out: self.compile_variable(op.out), - }) + })) + } + gpu::Operator::Floor(op) => { + instructions.push(Instruction::Floor(self.compile_unary(op))) } - gpu::Operator::Floor(op) => Instruction::Floor(self.compile_unary(op)), - gpu::Operator::Ceil(op) => Instruction::Ceil(self.compile_unary(op)), - gpu::Operator::Remainder(_op) => todo!(), - gpu::Operator::Fma(op) => Instruction::Fma { + gpu::Operator::Ceil(op) => instructions.push(Instruction::Ceil(self.compile_unary(op))), + gpu::Operator::Remainder(op) => { + 
instructions.push(Instruction::Modulo(self.compile_binary(op))) + } + gpu::Operator::Fma(op) => instructions.push(Instruction::Fma { a: self.compile_variable(op.a), b: self.compile_variable(op.b), c: self.compile_variable(op.c), out: self.compile_variable(op.out), - }, - } + }), + }; } fn compile_binary(&mut self, value: gpu::BinaryOperator) -> super::BinaryInstruction { diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs index de836ee8d..a0ed7ef21 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs @@ -10,7 +10,7 @@ use crate::matmul::tests::test_utils::{ assert_equals, cmma_available, create_empty, range_tensor_f16, }; -#[cube(launch)] +#[cube(launch_unchecked)] fn compute_loop_test( lhs_tensor: &Tensor, rhs_tensor: &Tensor, @@ -84,18 +84,20 @@ pub fn compute_loop_k_test(device: &R::Device) { unroll: false, }; - compute_loop_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::new(&results, m * n), - UInt::new(m as u32), - UInt::new(k as u32), - UInt::new(n as u32), - config, - ); + unsafe { + compute_loop_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), + TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), + ArrayArg::new(&results, m * n), + UInt::new(m as u32), + UInt::new(k as u32), + UInt::new(n as u32), + config, + ); + }; let expected = &[ 1610496., 1614832., 1619168., 1623504., 1627840., 1632176., 1636512., 1640848., 1645184., @@ -160,18 +162,20 @@ pub fn compute_loop_warp_test(device: &R::Device) { unroll: false, }; - compute_loop_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::new(&results, m * n), - UInt::new(m as u32), - UInt::new(k as u32), - UInt::new(n as u32), - config, - ); + unsafe { + compute_loop_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), + TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), + ArrayArg::new(&results, m * n), + UInt::new(m as u32), + UInt::new(k as u32), + UInt::new(n as u32), + config, + ); + }; let expected = &[ 1610496., 1614832., 1619168., 1623504., 1627840., 1632176., 1636512., 1640848., 1645184., @@ -265,18 +269,20 @@ pub fn cmma_compute_loop_two_warps_same_tile_row_test(device: &R::De unroll: false, }; - compute_loop_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::new(&results, m * n), - UInt::new(m as u32), - UInt::new(k as u32), - UInt::new(n as u32), - config, - ); + unsafe { + compute_loop_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), + TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), + ArrayArg::new(&results, m * n), + UInt::new(m as u32), + UInt::new(k as u32), + UInt::new(n as u32), + config, + ); + }; let expected = &[ 1610496.0, 1614832.0, 1619168.0, 1623504.0, 1627840.0, 1632176.0, 1636512.0, 1640848.0, diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs 
b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs index 08bed0ed0..4894bca1e 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs @@ -8,7 +8,7 @@ use crate::matmul::{ tests::test_utils::range_tensor, }; -#[cube(launch)] +#[cube(launch_unchecked)] fn load_lhs_test( lhs_tensor: &Tensor, lhs_sm_arr: &mut Array, @@ -41,7 +41,7 @@ fn load_lhs_test( } } -#[cube(launch)] +#[cube(launch_unchecked)] fn load_rhs_test( rhs_tensor: &Tensor, rhs_sm_arr: &mut Array, @@ -93,23 +93,25 @@ pub fn load_shared_memory_lhs_unit_test(device: &R::Device) { unroll: false, }; - load_lhs_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::vectorized( + 4, + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + ), + ArrayArg::new(&lhs_sm, 64 * 32), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -150,23 +152,25 @@ pub fn load_shared_memory_rhs_unit_test(device: &R::Device) { unroll: false, }; - load_rhs_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &rhs_tensor.handle, - &rhs_tensor.strides, - &rhs_tensor.shape, - ), - ArrayArg::new(&rhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_rhs_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::vectorized( + 4, + &rhs_tensor.handle, + &rhs_tensor.strides, + &rhs_tensor.shape, + ), + ArrayArg::new(&rhs_sm, 64 * 32), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -207,23 +211,25 @@ pub fn load_shared_memory_lhs_warp_test(device: &R::Device) { unroll: false, }; - load_lhs_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::vectorized( + 4, + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + ), + ArrayArg::new(&lhs_sm, 64 * 32), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 64.0, @@ -269,23 +275,25 @@ pub fn load_shared_memory_lhs_vertical_out_of_bound_warp_test(device unroll: false, }; - load_lhs_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(12), - ScalarArg::new(64), - 
ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::vectorized( + 4, + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + ), + ArrayArg::new(&lhs_sm, 64 * 32), + ScalarArg::new(0), + ScalarArg::new(12), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 64.0, @@ -329,23 +337,25 @@ pub fn load_shared_memory_lhs_horizontal_out_of_bound_warp_test(devi unroll: false, }; - load_lhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(12), - ScalarArg::new(12), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized( + 4, + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + ), + ArrayArg::new(&lhs_sm, 64 * 32), + ScalarArg::new(0), + ScalarArg::new(12), + ScalarArg::new(12), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, @@ -389,23 +399,25 @@ pub fn load_shared_memory_lhs_whole_out_of_bound_warp_test(device: & unroll: false, }; - load_lhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(12), - ScalarArg::new(12), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized( + 4, + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + ), + ArrayArg::new(&lhs_sm, 64 * 32), + ScalarArg::new(0), + ScalarArg::new(12), + ScalarArg::new(12), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, @@ -448,23 +460,25 @@ pub fn load_shared_memory_rhs_warp_test(device: &R::Device) { unroll: false, }; - load_rhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &rhs_tensor.handle, - &rhs_tensor.strides, - &rhs_tensor.shape, - ), - ArrayArg::new(&rhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_rhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized( + 4, + &rhs_tensor.handle, + &rhs_tensor.strides, + &rhs_tensor.shape, + ), + ArrayArg::new(&rhs_sm, 64 * 32), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 64.0, @@ -510,23 +524,25 @@ pub fn load_shared_memory_lhs_second_warp_test(device: &R::Device) { unroll: false, }; - load_lhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized( + 4, + 
&lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + ), + ArrayArg::new(&lhs_sm, 64 * 32), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 80., 81., @@ -571,23 +587,25 @@ pub fn load_shared_memory_rhs_second_warp_test(device: &R::Device) { unroll: false, }; - load_rhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &rhs_tensor.handle, - &rhs_tensor.strides, - &rhs_tensor.shape, - ), - ArrayArg::new(&rhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_rhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized( + 4, + &rhs_tensor.handle, + &rhs_tensor.strides, + &rhs_tensor.shape, + ), + ArrayArg::new(&rhs_sm, 64 * 32), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 1024., 1025., 1026., 1027., 1028., 1029., 1030., 1031., 1032., 1033., 1034., 1035., 1036., @@ -635,23 +653,25 @@ pub fn load_shared_memory_lhs_third_warp_test(device: &R::Device) { unroll: false, }; - load_lhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized( + 4, + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + ), + ArrayArg::new(&lhs_sm, 64 * 32), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 1024., 1025., 1026., 1027., 1028., 1029., 1030., 1031., 1032., 1033., 1034., 1035., 1036., @@ -699,23 +719,25 @@ pub fn load_shared_memory_rhs_third_warp_test(device: &R::Device) { unroll: false, }; - load_rhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &rhs_tensor.handle, - &rhs_tensor.strides, - &rhs_tensor.shape, - ), - ArrayArg::new(&rhs_sm, 64 * 32), - ScalarArg::new(0), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_rhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized( + 4, + &rhs_tensor.handle, + &rhs_tensor.strides, + &rhs_tensor.shape, + ), + ArrayArg::new(&rhs_sm, 64 * 32), + ScalarArg::new(0), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 80., 81., @@ -760,23 +782,25 @@ pub fn load_shared_memory_lhs_k_offset_test(device: &R::Device) { unroll: false, }; - load_lhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &lhs_tensor.handle, - &lhs_tensor.strides, - &lhs_tensor.shape, - ), - ArrayArg::new(&lhs_sm, 64 * 32), - ScalarArg::new(32), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_lhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized( + 4, + &lhs_tensor.handle, + &lhs_tensor.strides, + &lhs_tensor.shape, + ), + ArrayArg::new(&lhs_sm, 64 * 32), + ScalarArg::new(32), + ScalarArg::new(64), + 
ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 96., 97., @@ -821,23 +845,25 @@ pub fn load_shared_memory_rhs_k_offset_test(device: &R::Device) { unroll: false, }; - load_rhs_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - 4, - &rhs_tensor.handle, - &rhs_tensor.strides, - &rhs_tensor.shape, - ), - ArrayArg::new(&rhs_sm, 64 * 32), - ScalarArg::new(32), - ScalarArg::new(64), - ScalarArg::new(64), - ScalarArg::new(64), - config, - ); + unsafe { + load_rhs_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized( + 4, + &rhs_tensor.handle, + &rhs_tensor.strides, + &rhs_tensor.shape, + ), + ArrayArg::new(&rhs_sm, 64 * 32), + ScalarArg::new(32), + ScalarArg::new(64), + ScalarArg::new(64), + ScalarArg::new(64), + config, + ); + }; let expected = &[ 2048., 2049., 2050., 2051., 2052., 2053., 2054., 2055., 2056., 2057., 2058., 2059., 2060., diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs index d593eb013..597075172 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs @@ -8,7 +8,7 @@ use crate::matmul::{ tests::test_utils::range_tensor, }; -#[cube(launch)] +#[cube(launch_unchecked)] fn write_output_test( out: &mut Tensor, acc_sm_arr: &mut Array, @@ -60,16 +60,18 @@ pub fn cmma_write_output_unit_test(device: &R::Device) { unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + ArrayArg::new(&acc_sm.handle, 64 * 64), + ScalarArg::new(m as u32), + ScalarArg::new(n as u32), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 256.0, @@ -126,16 +128,18 @@ pub fn cmma_write_output_warp_test(device: &R::Device) { unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + ArrayArg::new(&acc_sm.handle, 64 * 64), + ScalarArg::new(m as u32), + ScalarArg::new(n as u32), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, @@ -202,16 +206,18 @@ pub fn cmma_write_output_warp_horizontal_out_of_bounds_test(device: unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + ArrayArg::new(&acc_sm.handle, 64 * 64), + ScalarArg::new(m as u32), + ScalarArg::new(n as u32), + 
config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, @@ -273,16 +279,18 @@ pub fn cmma_write_output_warp_vertical_out_of_bounds_test(device: &R unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + ArrayArg::new(&acc_sm.handle, 64 * 64), + ScalarArg::new(m as u32), + ScalarArg::new(n as u32), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, @@ -344,16 +352,18 @@ pub fn cmma_write_output_warp_whole_out_of_bounds_test(device: &R::D unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + ArrayArg::new(&acc_sm.handle, 64 * 64), + ScalarArg::new(m as u32), + ScalarArg::new(n as u32), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, @@ -411,16 +421,18 @@ pub fn cmma_write_output_second_warp_test(device: &R::Device) { unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + ArrayArg::new(&acc_sm.handle, 64 * 64), + ScalarArg::new(m as u32), + ScalarArg::new(n as u32), + config, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, @@ -527,16 +539,18 @@ pub fn cmma_write_output_third_fourth_warps_test(device: &R::Device) unroll: false, }; - write_output_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&acc_sm.handle, 64 * 64), - ScalarArg::new(m as u32), - ScalarArg::new(n as u32), - config, - ); + unsafe { + write_output_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + ArrayArg::new(&acc_sm.handle, 64 * 64), + ScalarArg::new(m as u32), + ScalarArg::new(n as u32), + config, + ); + }; let expected = &[ 1024., 1025., 1026., 1027., 1028., 1029., 1030., 1031., 1032., 1033., 1034., 1035., 1036., diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs index 21447b0cd..6754ada1f 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs @@ -13,7 +13,7 @@ use crate::matmul::{ }, }; -#[cube(launch)] +#[cube(launch_unchecked)] #[allow(unused_mut)] fn tile_outer_product_test( register_m: Array, @@ -50,15 +50,17 @@ 
pub fn tile_outer_product_vectorized_unit_test_2(device: &R::Device) const SOME_DIM: usize = 12; let config = make_tiling2d_config(SOME_DIM, SOME_DIM, SOME_DIM); - tile_outer_product_test::launch::( - &client, - cube_count, - cube_dim, - ArrayArg::new(®ister_m, 4), - ArrayArg::new(®ister_n, 4), - ArrayArg::new(&results, 16), - config, - ); + unsafe { + tile_outer_product_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + ArrayArg::new(®ister_m, 4), + ArrayArg::new(®ister_n, 4), + ArrayArg::new(&results, 16), + config, + ); + }; let expected = &[ 64.0, 80.0, 96.0, 112.0, 80.0, 100.0, 120.0, 140.0, 96.0, 120.0, 144.0, 168.0, 112.0, @@ -67,7 +69,7 @@ pub fn tile_outer_product_vectorized_unit_test_2(device: &R::Device) assert_equals::(&client, results, expected); } -#[cube(launch)] +#[cube(launch_unchecked)] fn compute_loop_test( lhs: &Tensor, rhs: &Tensor, @@ -124,15 +126,17 @@ pub fn tile_outer_product_vectorized_unit_test(device: &R::Device) { const SOME_DIM: usize = 12; let config = make_tiling2d_config(SOME_DIM, SOME_DIM, SOME_DIM); - tile_outer_product_test::launch::( - &client, - cube_count, - cube_dim, - ArrayArg::new(®ister_m, 4), - ArrayArg::new(®ister_n, 4), - ArrayArg::new(&results, 16), - config, - ); + unsafe { + tile_outer_product_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + ArrayArg::new(®ister_m, 4), + ArrayArg::new(®ister_n, 4), + ArrayArg::new(&results, 16), + config, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0, 3.0, 6.0, 9.0, 12.0, @@ -152,19 +156,21 @@ pub fn compute_loop_unit_test(device: &R::Device) { const SOME_DIM: usize = 12; let config = make_tiling2d_config(SOME_DIM, SOME_DIM, SOME_DIM); - compute_loop_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), - ScalarArg::new(0), - ScalarArg::new(0), - ArrayArg::new(&results, 16), - UInt::new(16), - UInt::new(16), - config, - ); + unsafe { + compute_loop_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape), + TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), + ScalarArg::new(0), + ScalarArg::new(0), + ArrayArg::new(&results, 16), + UInt::new(16), + UInt::new(16), + config, + ); + }; let expected = &[ 8960.0, 9184.0, 9408.0, 9632.0, 9184.0, 9416.0, 9648.0, 9880.0, 9408.0, 9648.0, 9888.0, @@ -184,19 +190,21 @@ pub fn compute_loop_unit_offset_test(device: &R::Device) { let config = make_tiling2d_config(4, 8, 4); - compute_loop_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), - ScalarArg::new(4), - ScalarArg::new(4), - ArrayArg::new(&results, 16), - UInt::new(8), - UInt::new(8), - config, - ); + unsafe { + compute_loop_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape), + TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), + ScalarArg::new(4), + ScalarArg::new(4), + ArrayArg::new(&results, 16), + UInt::new(8), + UInt::new(8), + config, + ); + }; let expected = &[ 1160.0, 1230.0, 1300.0, 1370.0, 1416.0, 1502.0, 1588.0, 1674.0, 1672.0, 1774.0, 1876.0, 
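Note on the pattern above: every test hunk in this patch makes the same mechanical change. The kernel attribute moves from #[cube(launch)] to #[cube(launch_unchecked)], and the generated launch call is swapped for launch_unchecked inside an unsafe block, since the unchecked variant emits no bounds checks and the call site takes responsibility for in-bounds indexing. A minimal sketch of the pattern follows; the kernel, element type, and buffer sizes are illustrative placeholders rather than code from this patch, and it assumes the usual cubecl prelude items (ABSOLUTE_POS, ArrayArg, CubeCount, CubeDim).

    use cubecl_core::prelude::*;

    // Hypothetical kernel, only to show the attribute change.
    #[cube(launch_unchecked)] // was: #[cube(launch)]
    fn double_test<F: Float>(input: &Array<F>, output: &mut Array<F>) {
        output[ABSOLUTE_POS] = input[ABSOLUTE_POS] + input[ABSOLUTE_POS];
    }

    pub fn run_double_test<R: Runtime>(device: &R::Device) {
        let client = R::client(device);
        let input = client.create(f32::as_bytes(&[1.0, 2.0, 3.0, 4.0]));
        let output = client.empty(4 * core::mem::size_of::<f32>());

        // Unchecked launches are unsafe: an out-of-bounds index is undefined behavior.
        unsafe {
            double_test::launch_unchecked::<F32, R>(
                &client,
                CubeCount::Static(1, 1, 1),
                CubeDim::new(4, 1, 1),
                ArrayArg::new(&input, 4),
                ArrayArg::new(&output, 4),
            );
        };
        // Read back with client.read(output.binding()) and assert as in the tests above.
    }
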
diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs index b8059b522..f91d0d9b0 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs @@ -15,7 +15,7 @@ use crate::matmul::{ }, }; -#[cube(launch)] +#[cube(launch_unchecked)] fn load_tensor_test( tensor: &Tensor, sm_out: &mut Array, @@ -80,7 +80,7 @@ fn load_tensor_test( } } -#[cube(launch)] +#[cube(launch_unchecked)] fn load_tensor_permuted_test( tensor: &Tensor, sm_out: &mut Array, @@ -147,7 +147,7 @@ fn load_tensor_permuted_test( } } -#[cube(launch)] +#[cube(launch_unchecked)] fn load_tensor_multiple_tiles_test( tensor: &Tensor, sm_out: &mut Array, @@ -222,18 +222,20 @@ pub fn load_lhs_transposed_unit_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); - load_tensor_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - true, - ); + unsafe { + load_tensor_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + true, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -255,21 +257,23 @@ pub fn load_lhs_transposed_out_of_bounds_cube_test(device: &R::Devic let config = make_tiling2d_config(5, 1, 1); - load_tensor_multiple_tiles_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized( - vectorization_factor as u8, - &lhs.handle, - &lhs.strides, - &lhs.shape, - ), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(0), - config, - true, - ); + unsafe { + load_tensor_multiple_tiles_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized( + vectorization_factor as u8, + &lhs.handle, + &lhs.strides, + &lhs.shape, + ), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(0), + config, + true, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -290,16 +294,18 @@ pub fn load_lhs_transposed_cube_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 8); - load_tensor_multiple_tiles_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(0), - config, - true, - ); + unsafe { + load_tensor_multiple_tiles_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(0), + config, + true, + ); + }; let expected = &[ 0.0, 8.0, 16.0, 24.0, 32.0, 40.0, 48.0, 56.0, 1.0, 9.0, 17.0, 25.0, 33.0, 41.0, 49.0, 57.0, @@ -321,16 +327,18 @@ pub fn load_lhs_transposed_offset_cube_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 16); - load_tensor_multiple_tiles_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(8), - config, - true, - ); + unsafe { + 
load_tensor_multiple_tiles_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(8), + config, + true, + ); + }; let expected = &[ 8.0, 24.0, 40.0, 56.0, 72.0, 88.0, 104.0, 120.0, 9.0, 25.0, 41.0, 57.0, 73.0, 89.0, 105.0, @@ -352,18 +360,20 @@ pub fn load_rhs_plain_unit_test(device: &R::Device) { let config = make_tiling2d_config(8, 16, 16); - load_tensor_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - false, - ); + unsafe { + load_tensor_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + false, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -384,16 +394,18 @@ pub fn load_rhs_plain_cube_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 8); - load_tensor_multiple_tiles_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(0), - config, - false, - ); + unsafe { + load_tensor_multiple_tiles_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(0), + config, + false, + ); + }; let expected = &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, @@ -415,16 +427,18 @@ pub fn load_rhs_plain_cube_offset_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); - load_tensor_multiple_tiles_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(8), - config, - false, - ); + unsafe { + load_tensor_multiple_tiles_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(8), + config, + false, + ); + }; let expected = &[ 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, @@ -446,18 +460,20 @@ pub fn load_lhs_plain_unit_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); - load_tensor_permuted_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - true, - ); + unsafe { + load_tensor_permuted_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + true, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, @@ -479,18 +495,20 @@ pub fn load_lhs_plain_out_of_bounds_unit_test(device: &R::Device) { let config = make_tiling2d_config(m, k, 8); - load_tensor_permuted_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - true, - ); + unsafe { + load_tensor_permuted_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + true, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -511,18 +529,20 @@ pub fn load_rhs_transposed_unit_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); - load_tensor_permuted_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - false, - ); + unsafe { + load_tensor_permuted_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + false, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -544,18 +564,20 @@ pub fn load_rhs_transposed_out_of_bounds_unit_test(device: &R::Devic let config = make_tiling2d_config(8, k, n); - load_tensor_permuted_test::launch::( - &client, - cube_count, - cube_dim, - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - false, - ); + unsafe { + load_tensor_permuted_test::launch_unchecked::( + &client, + cube_count, + cube_dim, + TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + false, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs index f0a9e5939..852d29c22 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs @@ -14,7 +14,7 @@ use crate::matmul::{ }, }; -#[cube(launch)] +#[cube(launch_unchecked)] fn write_to_output_test( out: &mut Tensor, results: &mut Array, @@ -35,7 +35,7 @@ fn write_to_output_test( write_to_output::>(out, results, coordinates, UInt::new(0), dims, config); } -#[cube(launch)] +#[cube(launch_unchecked)] fn write_results_to_output_out_of_bounds_test( out: &mut Tensor, results: &mut Array, @@ -66,14 +66,16 @@ pub fn write_to_output_over_height_unit_test(device: &R::Device) { let config = make_tiling2d_config(6, 8, 8); - write_to_output_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&tile.handle, 16), - config, - ); + unsafe { + 
write_to_output_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape), + ArrayArg::new(&tile.handle, 16), + config, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -93,14 +95,16 @@ pub fn write_to_output_over_width_unit_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 4); - write_to_output_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&tile.handle, 16), - config, - ); + unsafe { + write_to_output_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape), + ArrayArg::new(&tile.handle, 16), + config, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -120,14 +124,16 @@ pub fn write_to_output_vectorized_less_than_tile_unit_test(device: & let config = make_tiling2d_config(8, 8, 8); - write_to_output_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(vectorization as u8, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&tile.handle, 16), - config, - ); + unsafe { + write_to_output_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::vectorized(vectorization as u8, &out.handle, &out.strides, &out.shape), + ArrayArg::new(&tile.handle, 16), + config, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -149,14 +155,16 @@ pub fn write_to_output_scalar_unit_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 8); - write_to_output_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(vectorization as u8, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&tile.handle, 16), - config, - ); + unsafe { + write_to_output_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::vectorized(vectorization as u8, &out.handle, &out.strides, &out.shape), + ArrayArg::new(&tile.handle, 16), + config, + ); + }; let expected = &[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -178,14 +186,16 @@ pub fn write_to_output_scalar_out_of_bounds_cube_test(device: &R::De let config = make_tiling2d_config(5, 8, 1); - write_results_to_output_out_of_bounds_test::launch::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::vectorized(vectorization, &out.handle, &out.strides, &out.shape), - ArrayArg::new(&results.handle, 16), - config, - ); + unsafe { + write_results_to_output_out_of_bounds_test::launch_unchecked::( + &R::client(device), + cube_count, + cube_dim, + TensorArg::vectorized(vectorization, &out.handle, &out.strides, &out.shape), + ArrayArg::new(&results.handle, 16), + config, + ); + }; let expected = &[0.0, 1.0, 2.0, 3.0, 0.0]; assert_equals::(&client, out.handle, expected); From 4414af91b6f3ed481320175e0ffce77b46d09813 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Fri, 2 Aug 2024 13:54:01 -0400 Subject: [PATCH 3/9] Fix boolean cast in cuda --- crates/cubecl-cuda/src/compiler/binary.rs | 96 +++++++++++++++++----- crates/cubecl-cuda/src/compiler/element.rs | 9 ++ 2 files changed, 85 insertions(+), 20 deletions(-) diff --git a/crates/cubecl-cuda/src/compiler/binary.rs 
b/crates/cubecl-cuda/src/compiler/binary.rs index b16f24ab4..bd73c2bc2 100644 --- a/crates/cubecl-cuda/src/compiler/binary.rs +++ b/crates/cubecl-cuda/src/compiler/binary.rs @@ -1,9 +1,9 @@ use super::{Component, Elem, Variable}; -use std::fmt::Display; +use std::fmt::{Display, Formatter}; pub trait Binary { fn format( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, @@ -13,7 +13,7 @@ pub trait Binary { } fn format_scalar( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: Lhs, rhs: Rhs, out: Out, @@ -25,7 +25,7 @@ pub trait Binary { Out: Component; fn unroll_vec( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, @@ -140,11 +140,11 @@ pub struct Index; impl Binary for IndexAssign { fn format_scalar( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: Lhs, rhs: Rhs, out: Out, - _elem: Elem, + elem: Elem, ) -> std::fmt::Result where Lhs: Component, @@ -154,10 +154,10 @@ impl Binary for IndexAssign { let item_out = out.item(); let item_rhs = rhs.item(); - if item_out.vectorization != item_rhs.vectorization { + let format_vec = |f: &mut Formatter<'_>, cast: bool| { let is_vec_native = item_out.is_vec_native(); f.write_str("{\n")?; - let var = "scalar_broadcasted"; + let var = "broadcasted"; f.write_fmt(format_args!("{item_out} {var};\n"))?; for i in 0..item_out.vectorization { if is_vec_native { @@ -168,14 +168,33 @@ impl Binary for IndexAssign { 3 => 'w', _ => panic!("Invalid"), }; - f.write_fmt(format_args!("{var}.{char} = {rhs};\n"))?; + if cast { + f.write_fmt(format_args!("{var}.{char} = ({}){};\n", elem, rhs.index(i)))?; + } else { + f.write_fmt(format_args!("{var}.{char} = {};\n", rhs.index(i)))?; + } } else { - f.write_fmt(format_args!("{var}.i_{i} = {rhs};\n"))?; + if cast { + f.write_fmt(format_args!("{var}.i_{i} = ({}){};\n", elem, rhs.index(i)))?; + } else { + f.write_fmt(format_args!("{var}.i_{i} = {};\n", rhs.index(i)))?; + } } } f.write_fmt(format_args!("{out}[{lhs}] = {var};\n"))?; f.write_str("}")?; + Ok(()) + }; + + if item_out.vectorization != item_rhs.vectorization { + format_vec(f, item_out != item_rhs) + } else if elem != item_rhs.elem { + if item_out.vectorization > 1 { + format_vec(f, true)?; + } else { + f.write_fmt(format_args!("{out}[{lhs}] = ({elem}){rhs};\n"))?; + } Ok(()) } else { f.write_fmt(format_args!("{out}[{lhs}] = {rhs};\n")) @@ -183,7 +202,7 @@ impl Binary for IndexAssign { } fn unroll_vec( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, @@ -195,8 +214,8 @@ impl Binary for IndexAssign { } for i in 0..index { - let lhsi = lhs.index(i, false); - let rhsi = rhs.index(i, false); + let lhsi = lhs.index(i, lhs.item().is_optimized()); + let rhsi = rhs.index(i, rhs.item().is_optimized()); Self::format_scalar(f, lhsi, rhsi, *out, elem)?; } @@ -204,7 +223,7 @@ impl Binary for IndexAssign { } fn format( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, @@ -227,7 +246,7 @@ impl Binary for IndexAssign { impl Binary for Index { fn format( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, @@ -245,18 +264,55 @@ impl Binary for Index { } fn format_scalar( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: Lhs, rhs: Rhs, out: Out, - _elem: Elem, + elem: Elem, ) -> std::fmt::Result where Lhs: Component, Rhs: Component, Out: Component, { - 
f.write_fmt(format_args!("{out} = {lhs}[{rhs}];\n")) + let item_out = out.item(); + let item_lhs = lhs.item(); + + let format_vec = |f: &mut Formatter<'_>| { + let is_vec_native = item_out.is_vec_native(); + f.write_str("{\n")?; + let var = "broadcasted"; + f.write_fmt(format_args!("{item_out} {var};\n"))?; + for i in 0..item_out.vectorization { + if is_vec_native { + let char = match i { + 0 => 'x', + 1 => 'y', + 2 => 'z', + 3 => 'w', + _ => panic!("Invalid"), + }; + f.write_fmt(format_args!("{var}.{char} = {elem}({lhs}[{rhs}].i_{i});\n",))?; + } else { + f.write_fmt(format_args!("{var}.i_{i} = {elem}({lhs}[{rhs}].i_{i});\n",))?; + } + } + f.write_fmt(format_args!("{out} = {var};\n"))?; + f.write_str("}")?; + + Ok(()) + }; + + if elem != item_lhs.elem { + if item_out.vectorization > 1 { + format_vec(f)?; + } else { + f.write_fmt(format_args!("{out} = ({elem}){lhs}[{rhs}];\n"))?; + } + Ok(()) + } else { + f.write_fmt(format_args!("{out} = {lhs}[{rhs}];\n")) + } } } @@ -285,7 +341,7 @@ struct IndexAssignVector; impl IndexVector { fn format( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, @@ -307,7 +363,7 @@ impl IndexVector { impl IndexAssignVector { fn format( - f: &mut std::fmt::Formatter<'_>, + f: &mut Formatter<'_>, lhs: &Variable, rhs: &Variable, out: &Variable, diff --git a/crates/cubecl-cuda/src/compiler/element.rs b/crates/cubecl-cuda/src/compiler/element.rs index 22ce1ec16..f1570105b 100644 --- a/crates/cubecl-cuda/src/compiler/element.rs +++ b/crates/cubecl-cuda/src/compiler/element.rs @@ -54,6 +54,7 @@ impl Display for Item { pub trait Component: Display { fn item(&self) -> Item; + fn index(&self, index: usize) -> IndexedVariable; fn elem(&self) -> Elem { *self.item().elem() } @@ -63,8 +64,16 @@ impl Component for IndexedVariable { fn item(&self) -> Item { self.var.item() } + + fn index(&self, index: usize) -> IndexedVariable { + self.var.index(index, self.var.is_optimized()) + } } impl Component for Variable { + fn index(&self, index: usize) -> IndexedVariable { + self.index(index, self.is_optimized()) + } + fn item(&self) -> Item { match self { Variable::GlobalInputArray(_, e) => *e, From f1fbc5b4e97b418b209c3a74ea493d10b6c0bb0d Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sat, 3 Aug 2024 12:39:07 -0400 Subject: [PATCH 4/9] Remove feature flag --- crates/cubecl-wgpu/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/cubecl-wgpu/Cargo.toml b/crates/cubecl-wgpu/Cargo.toml index be4ecfe2d..88bcbeb74 100644 --- a/crates/cubecl-wgpu/Cargo.toml +++ b/crates/cubecl-wgpu/Cargo.toml @@ -16,7 +16,6 @@ default = [ "cubecl-common/default", "cubecl-core/default", ] -autotune = [] std = ["cubecl-runtime/std", "cubecl-common/std", "cubecl-core/std"] [dependencies] From 077272c6b31e48528c127bf2a6f6a706a4db9744 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 5 Aug 2024 11:57:30 -0400 Subject: [PATCH 5/9] Update new API --- crates/cubecl-core/src/codegen/compiler.rs | 5 +- crates/cubecl-core/src/compute/kernel.rs | 12 ++-- .../src/frontend/element/tensor.rs | 56 ++++++++++++------- crates/cubecl-core/src/id.rs | 12 ++-- crates/cubecl-core/src/runtime.rs | 1 + .../cubecl-core/src/runtime_tests/subcube.rs | 12 ++-- crates/cubecl-cuda/src/compiler/base.rs | 10 ++-- crates/cubecl-cuda/src/compute/server.rs | 16 +++--- .../cubecl-linalg/src/matmul/cmma/launch.rs | 6 +- .../src/matmul/tests/cmma/compute_loop.rs | 12 ++-- .../matmul/tests/cmma/load_shared_memory.rs | 52 ++++++++--------- 
.../src/matmul/tests/cmma/write_output.rs | 14 ++--- .../src/matmul/tests/tiling2d/compute_loop.rs | 8 +-- .../tests/tiling2d/load_shared_memory.rs | 24 ++++---- .../src/matmul/tests/tiling2d/write_output.rs | 10 ++-- .../src/matmul/tiling2d/launch.rs | 6 +- crates/cubecl-linalg/src/tensor/contiguous.rs | 14 +---- .../src/{compute.rs => base.rs} | 10 ++++ crates/cubecl-runtime/src/channel/base.rs | 15 +---- crates/cubecl-runtime/src/channel/cell.rs | 7 ++- crates/cubecl-runtime/src/channel/mpsc.rs | 17 +++--- crates/cubecl-runtime/src/channel/mutex.rs | 7 ++- crates/cubecl-runtime/src/client.rs | 11 ++-- crates/cubecl-runtime/src/lib.rs | 4 +- crates/cubecl-runtime/src/server.rs | 6 +- crates/cubecl-runtime/tests/dummy/server.rs | 6 +- .../cubecl-wgpu/src/compiler/wgsl/compiler.rs | 7 +-- crates/cubecl-wgpu/src/compute/server.rs | 39 ++++++------- crates/cubecl/benches/unary.rs | 6 +- crates/cubecl/src/runtime.rs | 0 30 files changed, 204 insertions(+), 201 deletions(-) rename crates/cubecl-runtime/src/{compute.rs => base.rs} (92%) delete mode 100644 crates/cubecl/src/runtime.rs diff --git a/crates/cubecl-core/src/codegen/compiler.rs b/crates/cubecl-core/src/codegen/compiler.rs index e7728f7cb..2370dd79c 100644 --- a/crates/cubecl-core/src/codegen/compiler.rs +++ b/crates/cubecl-core/src/codegen/compiler.rs @@ -1,6 +1,5 @@ -use cubecl_runtime::channel::KernelExecutionStrategy; - use crate::ir::{Elem, KernelDefinition}; +use cubecl_runtime::ExecutionMode; use std::fmt::Display; /// Trait for compiled code representation @@ -15,7 +14,7 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug { type Representation: CompilerRepresentation; /// Compiles the [kernel definition](KernelDefinition) into the compiler's representation. - fn compile(kernel: KernelDefinition, kind: KernelExecutionStrategy) -> Self::Representation; + fn compile(kernel: KernelDefinition, mode: ExecutionMode) -> Self::Representation; /// The size of the given element in bytes. fn elem_size(elem: Elem) -> usize; /// The maximal size of a shared memory diff --git a/crates/cubecl-core/src/compute/kernel.rs b/crates/cubecl-core/src/compute/kernel.rs index be07a5a22..d0031b3a8 100644 --- a/crates/cubecl-core/src/compute/kernel.rs +++ b/crates/cubecl-core/src/compute/kernel.rs @@ -6,8 +6,8 @@ use std::{ use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel, KernelId}; use alloc::sync::Arc; use cubecl_runtime::{ - channel::KernelExecutionStrategy, server::{Binding, ComputeServer}, + ExecutionMode, }; /// A kernel, compiled in the target language @@ -160,7 +160,7 @@ pub trait CubeTask: Send + Sync { /// Identifier for the kernel, used for caching kernel compilation. fn id(&self) -> KernelId; /// Compile the kernel into source - fn compile(&self, kind: KernelExecutionStrategy) -> CompiledKernel; + fn compile(&self, mode: ExecutionMode) -> CompiledKernel; } /// Wraps a [kernel](Kernel) to create a [cube task](CubeTask). 
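For orientation before the remaining hunks: ExecutionMode is the renamed KernelExecutionStrategy, defined further down in this patch in cubecl-runtime/src/base.rs. It now travels from the client (execute passes Checked, execute_unchecked passes Unchecked) through the now-unsafe channel and server execute methods down to Compiler::compile, which only emits bounds guards in the checked mode. Because the two modes yield different source for the same kernel, the mode is also folded into KernelId so each variant is cached separately. The sketch below is illustrative only; CacheKey is a made-up stand-in for the real KernelId.

    // The enum essentially as added by this patch, plus a hypothetical cache key
    // showing why the mode has to be part of the kernel identity.
    #[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)]
    pub enum ExecutionMode {
        #[default]
        Checked,   // compiler emits index bounds guards
        Unchecked, // no guards; callers must keep every index in bounds
    }

    #[derive(Hash, PartialEq, Eq)]
    struct CacheKey {
        kernel_type: core::any::TypeId,
        mode: ExecutionMode,
    }
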
@@ -171,10 +171,10 @@ pub struct KernelTask { } impl CubeTask for KernelTask { - fn compile(&self, strategy: KernelExecutionStrategy) -> CompiledKernel { + fn compile(&self, mode: ExecutionMode) -> CompiledKernel { let gpu_ir = self.kernel_definition.define(); let cube_dim = gpu_ir.cube_dim; - let lower_level_ir = C::compile(gpu_ir, strategy); + let lower_level_ir = C::compile(gpu_ir, mode); let shared_mem_bytes = lower_level_ir.shared_memory_size(); let source = lower_level_ir.to_string(); @@ -193,7 +193,7 @@ impl CubeTask for KernelTask { } impl CubeTask for Arc { - fn compile(&self, kind: KernelExecutionStrategy) -> CompiledKernel { + fn compile(&self, kind: ExecutionMode) -> CompiledKernel { self.as_ref().compile(kind) } @@ -203,7 +203,7 @@ impl CubeTask for Arc { } impl CubeTask for Box { - fn compile(&self, kind: KernelExecutionStrategy) -> CompiledKernel { + fn compile(&self, kind: ExecutionMode) -> CompiledKernel { self.as_ref().compile(kind) } diff --git a/crates/cubecl-core/src/frontend/element/tensor.rs b/crates/cubecl-core/src/frontend/element/tensor.rs index cb6e0dfeb..02600310e 100644 --- a/crates/cubecl-core/src/frontend/element/tensor.rs +++ b/crates/cubecl-core/src/frontend/element/tensor.rs @@ -52,13 +52,36 @@ impl LaunchArg for Tensor { /// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle), /// the strides and the shape. -#[derive(new)] pub struct TensorHandleRef<'a, R: Runtime> { pub handle: &'a cubecl_runtime::server::Handle, pub strides: &'a [usize], pub shape: &'a [usize], } +impl<'a, R: Runtime> TensorHandleRef<'a, R> { + /// Convert the handle into a [tensor argument](TensorArg). + pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> { + unsafe { TensorArg::from_raw_parts(self.handle, self.strides, self.shape, vectorisation) } + } + /// Create a handle from raw parts. + /// + /// # Safety + /// + /// If you provide wrong strides or shapes, it might create undefined behavior caused by + /// out of bound reads and writes. + pub unsafe fn from_raw_parts( + handle: &'a cubecl_runtime::server::Handle, + strides: &'a [usize], + shape: &'a [usize], + ) -> Self { + Self { + handle, + strides, + shape, + } + } +} + /// Argument to be used for [tensors](Tensor) passed as arguments to kernels. pub enum TensorArg<'a, R: Runtime> { /// The tensor is passed with a tensor handle. @@ -76,32 +99,27 @@ pub enum TensorArg<'a, R: Runtime> { } impl<'a, R: Runtime> TensorArg<'a, R> { - /// Create a new tensor argument. + /// Create a new tensor argument specified with its vectorization factor. + /// + /// # Safety /// - /// Equivalent to using the [vectorized constructor](Self::vectorized) with a vectorization - /// factor of 1. - pub fn new( + /// If you provide wrong strides or shapes, it might create undefined behavior caused by + /// out of bound reads and writes. + pub unsafe fn from_raw_parts( handle: &'a cubecl_runtime::server::Handle, strides: &'a [usize], shape: &'a [usize], - ) -> Self { - Self::Handle { - handle: TensorHandleRef::new(handle, strides, shape), - vectorization_factor: 1, - } - } - /// Create a new tensor argument specified with its vectorization factor. 
- pub fn vectorized( factor: u8, - handle: &'a cubecl_runtime::server::Handle, - strides: &'a [usize], - shape: &'a [usize], ) -> Self { - Self::Handle { - handle: TensorHandleRef::new(handle, strides, shape), - vectorization_factor: factor, + unsafe { + Self::Handle { + handle: TensorHandleRef::from_raw_parts(handle, strides, shape), + vectorization_factor: factor, + } } } + + /// Create an alias argument. pub fn alias(position: usize) -> Self { Self::Alias { input_pos: position, diff --git a/crates/cubecl-core/src/id.rs b/crates/cubecl-core/src/id.rs index a96a4b10d..b8607dcbe 100644 --- a/crates/cubecl-core/src/id.rs +++ b/crates/cubecl-core/src/id.rs @@ -1,4 +1,4 @@ -use cubecl_runtime::channel::KernelExecutionStrategy; +use cubecl_runtime::ExecutionMode; use std::any::{Any, TypeId}; use std::fmt::Display; use std::hash::{DefaultHasher, Hash, Hasher}; @@ -9,7 +9,7 @@ use std::sync::Arc; pub struct KernelId { type_id: core::any::TypeId, info: Option, - kind: Option, + mode: Option, } impl Display for KernelId { @@ -27,7 +27,7 @@ impl KernelId { Self { type_id: core::any::TypeId::of::(), info: None, - kind: None, + mode: None, } } @@ -43,9 +43,9 @@ impl KernelId { self } - /// Set the kind of checking strategy. - pub fn kind(&mut self, kind: KernelExecutionStrategy) { - self.kind = Some(kind); + /// Set the [execution mode](ExecutionMode). + pub fn mode(&mut self, mode: ExecutionMode) { + self.mode = Some(mode); } } diff --git a/crates/cubecl-core/src/runtime.rs b/crates/cubecl-core/src/runtime.rs index 21683dea9..b9a42cbf5 100644 --- a/crates/cubecl-core/src/runtime.rs +++ b/crates/cubecl-core/src/runtime.rs @@ -9,6 +9,7 @@ pub use cubecl_runtime::channel; pub use cubecl_runtime::client; pub use cubecl_runtime::server; pub use cubecl_runtime::tune; +pub use cubecl_runtime::ExecutionMode; /// Runtime for the CubeCL. 
pub trait Runtime: Send + Sync + 'static + core::fmt::Debug { diff --git a/crates/cubecl-core/src/runtime_tests/subcube.rs b/crates/cubecl-core/src/runtime_tests/subcube.rs index 7fc50b2d8..f9bbc0578 100644 --- a/crates/cubecl-core/src/runtime_tests/subcube.rs +++ b/crates/cubecl-core/src/runtime_tests/subcube.rs @@ -109,11 +109,13 @@ fn test_subcube_operation( let handle = client.create(f32::as_bytes(input)); let (shape, strides) = ([input.len()], [1]); - launch( - CubeCount::Static(1, 1, 1), - CubeDim::new(input.len() as u32, 1, 1), - TensorArg::new(&handle, &strides, &shape), - ); + unsafe { + launch( + CubeCount::Static(1, 1, 1), + CubeDim::new(input.len() as u32, 1, 1), + TensorArg::from_raw_parts(&handle, &strides, &shape, 1), + ); + } let actual = client.read(handle.binding()); let actual = f32::from_bytes(&actual); diff --git a/crates/cubecl-cuda/src/compiler/base.rs b/crates/cubecl-cuda/src/compiler/base.rs index 73e162e7c..403a0a0e3 100644 --- a/crates/cubecl-cuda/src/compiler/base.rs +++ b/crates/cubecl-cuda/src/compiler/base.rs @@ -1,10 +1,10 @@ use std::collections::HashSet; use cubecl_core::{ - channel::KernelExecutionStrategy, ir::{self as gpu, ConstantScalarValue}, Compiler, }; +use cubecl_runtime::ExecutionMode; use super::{Instruction, WarpInstruction}; @@ -26,7 +26,7 @@ pub struct CudaCompiler { num_inputs: usize, num_outputs: usize, items: HashSet, - strategy: KernelExecutionStrategy, + strategy: ExecutionMode, } impl Compiler for CudaCompiler { @@ -34,7 +34,7 @@ impl Compiler for CudaCompiler { fn compile( kernel: cubecl_core::ir::KernelDefinition, - strategy: KernelExecutionStrategy, + strategy: ExecutionMode, ) -> Self::Representation { let mut compiler = Self::default(); compiler.strategy = strategy; @@ -348,7 +348,7 @@ impl CudaCompiler { out: self.compile_variable(op.out), }), gpu::Operator::Index(op) => { - if let KernelExecutionStrategy::Checked = self.strategy { + if let ExecutionMode::Checked = self.strategy { let has_len = match op.lhs { gpu::Variable::GlobalInputArray { .. } => true, gpu::Variable::GlobalOutputArray { .. } => true, @@ -376,7 +376,7 @@ impl CudaCompiler { instructions.push(Instruction::Index(self.compile_binary(op))) } gpu::Operator::IndexAssign(op) => { - if let KernelExecutionStrategy::Checked = self.strategy { + if let ExecutionMode::Checked = self.strategy { let has_len = match op.out { gpu::Variable::GlobalInputArray { .. } => true, gpu::Variable::GlobalOutputArray { .. 
} => true, diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs index 0ab57a568..729b47c75 100644 --- a/crates/cubecl-cuda/src/compute/server.rs +++ b/crates/cubecl-cuda/src/compute/server.rs @@ -4,12 +4,12 @@ use super::storage::CudaStorage; use super::CudaResource; use cubecl_common::reader::{reader_from_concrete, Reader}; use cubecl_common::sync_type::SyncType; -use cubecl_core::channel::KernelExecutionStrategy; use cubecl_core::compute::DebugInformation; use cubecl_core::ir::CubeDim; use cubecl_core::FeatureSet; use cubecl_core::{prelude::*, KernelId}; use cubecl_runtime::debug::DebugLogger; +use cubecl_runtime::ExecutionMode; use cubecl_runtime::{ memory_management::MemoryManagement, server::{self, ComputeServer}, @@ -112,17 +112,17 @@ impl> ComputeServer for CudaServer { server::Handle::new(handle) } - fn execute( + unsafe fn execute( &mut self, kernel: Self::Kernel, count: Self::DispatchOptions, bindings: Vec>, - strategy: KernelExecutionStrategy, + mode: ExecutionMode, ) { let arch = self.minimum_arch_version; let mut kernel_id = kernel.id(); - kernel_id.kind(strategy); + kernel_id.mode(mode); let count = match count { CubeCount::Static(x, y, z) => (x, y, z), @@ -143,7 +143,7 @@ impl> ComputeServer for CudaServer { let (ctx, logger) = self.get_context_with_logger(); if !ctx.module_names.contains_key(&kernel_id) { - ctx.compile_kernel(&kernel_id, kernel, arch, logger, strategy); + ctx.compile_kernel(&kernel_id, kernel, arch, logger, mode); } let resources = bindings @@ -201,9 +201,9 @@ impl> CudaContext { kernel: Box, arch: i32, logger: &mut DebugLogger, - strategy: KernelExecutionStrategy, + mode: ExecutionMode, ) { - let mut kernel_compiled = kernel.compile(strategy); + let mut kernel_compiled = kernel.compile(mode); if logger.is_activated() { kernel_compiled.debug_info = Some(DebugInformation::new("cpp", kernel_id.clone())); @@ -235,7 +235,7 @@ impl> CudaContext { message += format!("\n {line}").as_str(); } } - let source = kernel.compile(strategy).source; + let source = kernel.compile(mode).source; panic!("{message}\n[Source] \n{source}"); }; cudarc::nvrtc::result::get_ptx(program).unwrap() diff --git a/crates/cubecl-linalg/src/matmul/cmma/launch.rs b/crates/cubecl-linalg/src/matmul/cmma/launch.rs index e76fcaf40..396d68715 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/launch.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/launch.rs @@ -162,9 +162,9 @@ fn matmul_cmma_ref_no_check( client, cube_count, cube_dim, - TensorArg::vectorized(lhs_vectorization, lhs.handle, lhs.strides, lhs.shape), - TensorArg::vectorized(rhs_vectorization, rhs.handle, rhs.strides, rhs.shape), - TensorArg::vectorized(out_vectorization, out.handle, out.strides, out.shape), + TensorArg::from_raw_parts(lhs.handle, lhs.strides, lhs.shape, lhs_vectorization), + TensorArg::from_raw_parts(rhs.handle, rhs.strides, rhs.shape, rhs_vectorization), + TensorArg::from_raw_parts(out.handle, out.strides, out.shape, out_vectorization), CmmaConfig::new(m, k, n, launch_config), ); } diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs index a0ed7ef21..40adcd380 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs @@ -89,8 +89,8 @@ pub fn compute_loop_k_test(device: &R::Device) { &R::client(device), cube_count, cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::new(&rhs.handle, 
&rhs.strides, &rhs.shape), + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), ArrayArg::new(&results, m * n), UInt::new(m as u32), UInt::new(k as u32), @@ -167,8 +167,8 @@ pub fn compute_loop_warp_test(device: &R::Device) { &R::client(device), cube_count, cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), ArrayArg::new(&results, m * n), UInt::new(m as u32), UInt::new(k as u32), @@ -274,8 +274,8 @@ pub fn cmma_compute_loop_two_warps_same_tile_row_test(device: &R::De &client, cube_count, cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), ArrayArg::new(&results, m * n), UInt::new(m as u32), UInt::new(k as u32), diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs index 4894bca1e..cf90bccda 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs @@ -98,11 +98,11 @@ pub fn load_shared_memory_lhs_unit_test(device: &R::Device) { &R::client(device), cube_count, cube_dim, - TensorArg::vectorized( - 4, + TensorArg::from_raw_parts( &lhs_tensor.handle, &lhs_tensor.strides, &lhs_tensor.shape, + 4, ), ArrayArg::new(&lhs_sm, 64 * 32), ScalarArg::new(0), @@ -157,11 +157,11 @@ pub fn load_shared_memory_rhs_unit_test(device: &R::Device) { &R::client(device), cube_count, cube_dim, - TensorArg::vectorized( - 4, + TensorArg::from_raw_parts( &rhs_tensor.handle, &rhs_tensor.strides, &rhs_tensor.shape, + 4, ), ArrayArg::new(&rhs_sm, 64 * 32), ScalarArg::new(0), @@ -216,11 +216,11 @@ pub fn load_shared_memory_lhs_warp_test(device: &R::Device) { &R::client(device), cube_count, cube_dim, - TensorArg::vectorized( - 4, + TensorArg::from_raw_parts( &lhs_tensor.handle, &lhs_tensor.strides, &lhs_tensor.shape, + 4, ), ArrayArg::new(&lhs_sm, 64 * 32), ScalarArg::new(0), @@ -280,11 +280,11 @@ pub fn load_shared_memory_lhs_vertical_out_of_bound_warp_test(device &R::client(device), cube_count, cube_dim, - TensorArg::vectorized( - 4, + TensorArg::from_raw_parts( &lhs_tensor.handle, &lhs_tensor.strides, &lhs_tensor.shape, + 4, ), ArrayArg::new(&lhs_sm, 64 * 32), ScalarArg::new(0), @@ -342,11 +342,11 @@ pub fn load_shared_memory_lhs_horizontal_out_of_bound_warp_test(devi &client, cube_count, cube_dim, - TensorArg::vectorized( - 4, + TensorArg::from_raw_parts( &lhs_tensor.handle, &lhs_tensor.strides, &lhs_tensor.shape, + 4, ), ArrayArg::new(&lhs_sm, 64 * 32), ScalarArg::new(0), @@ -404,11 +404,11 @@ pub fn load_shared_memory_lhs_whole_out_of_bound_warp_test(device: & &client, cube_count, cube_dim, - TensorArg::vectorized( - 4, + TensorArg::from_raw_parts( &lhs_tensor.handle, &lhs_tensor.strides, &lhs_tensor.shape, + 4, ), ArrayArg::new(&lhs_sm, 64 * 32), ScalarArg::new(0), @@ -465,11 +465,11 @@ pub fn load_shared_memory_rhs_warp_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::vectorized( - 4, + TensorArg::from_raw_parts( &rhs_tensor.handle, &rhs_tensor.strides, &rhs_tensor.shape, + 4, ), ArrayArg::new(&rhs_sm, 64 * 
32), ScalarArg::new(0), @@ -529,11 +529,11 @@ pub fn load_shared_memory_lhs_second_warp_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::vectorized( - 4, + TensorArg::from_raw_parts( &lhs_tensor.handle, &lhs_tensor.strides, &lhs_tensor.shape, + 4, ), ArrayArg::new(&lhs_sm, 64 * 32), ScalarArg::new(0), @@ -592,11 +592,11 @@ pub fn load_shared_memory_rhs_second_warp_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::vectorized( - 4, + TensorArg::from_raw_parts( &rhs_tensor.handle, &rhs_tensor.strides, &rhs_tensor.shape, + 4, ), ArrayArg::new(&rhs_sm, 64 * 32), ScalarArg::new(0), @@ -658,11 +658,11 @@ pub fn load_shared_memory_lhs_third_warp_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::vectorized( - 4, + TensorArg::from_raw_parts( &lhs_tensor.handle, &lhs_tensor.strides, &lhs_tensor.shape, + 4, ), ArrayArg::new(&lhs_sm, 64 * 32), ScalarArg::new(0), @@ -724,11 +724,11 @@ pub fn load_shared_memory_rhs_third_warp_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::vectorized( - 4, + TensorArg::from_raw_parts( &rhs_tensor.handle, &rhs_tensor.strides, &rhs_tensor.shape, + 4, ), ArrayArg::new(&rhs_sm, 64 * 32), ScalarArg::new(0), @@ -787,11 +787,11 @@ pub fn load_shared_memory_lhs_k_offset_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::vectorized( - 4, + TensorArg::from_raw_parts( &lhs_tensor.handle, &lhs_tensor.strides, &lhs_tensor.shape, + 4, ), ArrayArg::new(&lhs_sm, 64 * 32), ScalarArg::new(32), @@ -850,11 +850,11 @@ pub fn load_shared_memory_rhs_k_offset_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::vectorized( - 4, + TensorArg::from_raw_parts( &rhs_tensor.handle, &rhs_tensor.strides, &rhs_tensor.shape, + 4, ), ArrayArg::new(&rhs_sm, 64 * 32), ScalarArg::new(32), diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs index 597075172..2f090ac27 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs @@ -65,7 +65,7 @@ pub fn cmma_write_output_unit_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), ArrayArg::new(&acc_sm.handle, 64 * 64), ScalarArg::new(m as u32), ScalarArg::new(n as u32), @@ -133,7 +133,7 @@ pub fn cmma_write_output_warp_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), ArrayArg::new(&acc_sm.handle, 64 * 64), ScalarArg::new(m as u32), ScalarArg::new(n as u32), @@ -211,7 +211,7 @@ pub fn cmma_write_output_warp_horizontal_out_of_bounds_test(device: &client, cube_count, cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), ArrayArg::new(&acc_sm.handle, 64 * 64), ScalarArg::new(m as u32), ScalarArg::new(n as u32), @@ -284,7 +284,7 @@ pub fn cmma_write_output_warp_vertical_out_of_bounds_test(device: &R &client, cube_count, cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), ArrayArg::new(&acc_sm.handle, 64 * 64), ScalarArg::new(m as u32), ScalarArg::new(n as u32), @@ -357,7 +357,7 @@ pub fn 
cmma_write_output_warp_whole_out_of_bounds_test(device: &R::D &client, cube_count, cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), ArrayArg::new(&acc_sm.handle, 64 * 64), ScalarArg::new(m as u32), ScalarArg::new(n as u32), @@ -426,7 +426,7 @@ pub fn cmma_write_output_second_warp_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), ArrayArg::new(&acc_sm.handle, 64 * 64), ScalarArg::new(m as u32), ScalarArg::new(n as u32), @@ -544,7 +544,7 @@ pub fn cmma_write_output_third_fourth_warps_test(device: &R::Device) &client, cube_count, cube_dim, - TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), ArrayArg::new(&acc_sm.handle, 64 * 64), ScalarArg::new(m as u32), ScalarArg::new(n as u32), diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs index 6754ada1f..bcd7fda7e 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs @@ -161,8 +161,8 @@ pub fn compute_loop_unit_test(device: &R::Device) { &R::client(device), cube_count, cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, TILE_SIZE as u8), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), ScalarArg::new(0), ScalarArg::new(0), ArrayArg::new(&results, 16), @@ -195,8 +195,8 @@ pub fn compute_loop_unit_offset_test(device: &R::Device) { &R::client(device), cube_count, cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, TILE_SIZE as u8), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), ScalarArg::new(4), ScalarArg::new(4), ArrayArg::new(&results, 16), diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs index f91d0d9b0..a9a3e07dd 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs @@ -227,7 +227,7 @@ pub fn load_lhs_transposed_unit_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), ScalarArg::new(4), ScalarArg::new(4), @@ -262,11 +262,11 @@ pub fn load_lhs_transposed_out_of_bounds_cube_test(device: &R::Devic &client, cube_count, cube_dim, - TensorArg::vectorized( - vectorization_factor as u8, + TensorArg::from_raw_parts( &lhs.handle, &lhs.strides, &lhs.shape, + vectorization_factor as u8, ), ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), ScalarArg::new(0), @@ -299,7 +299,7 @@ pub fn load_lhs_transposed_cube_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), + 
TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), ScalarArg::new(0), config, @@ -332,7 +332,7 @@ pub fn load_lhs_transposed_offset_cube_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), ScalarArg::new(8), config, @@ -365,7 +365,7 @@ pub fn load_rhs_plain_unit_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), ScalarArg::new(4), ScalarArg::new(4), @@ -399,7 +399,7 @@ pub fn load_rhs_plain_cube_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), ScalarArg::new(0), config, @@ -432,7 +432,7 @@ pub fn load_rhs_plain_cube_offset_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), ScalarArg::new(8), config, @@ -465,7 +465,7 @@ pub fn load_lhs_plain_unit_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), ScalarArg::new(4), ScalarArg::new(4), @@ -500,7 +500,7 @@ pub fn load_lhs_plain_out_of_bounds_unit_test(device: &R::Device) { &R::client(device), cube_count, cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), + TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), ScalarArg::new(4), ScalarArg::new(4), @@ -534,7 +534,7 @@ pub fn load_rhs_transposed_unit_test(device: &R::Device) { &client, cube_count, cube_dim, - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), ScalarArg::new(4), ScalarArg::new(4), @@ -569,7 +569,7 @@ pub fn load_rhs_transposed_out_of_bounds_unit_test(device: &R::Devic &client, cube_count, cube_dim, - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), + TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), ScalarArg::new(4), ScalarArg::new(4), diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs index 852d29c22..6d83ed888 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs @@ -71,7 +71,7 @@ pub fn write_to_output_over_height_unit_test(device: &R::Device) { &R::client(device), cube_count, cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape), + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, TILE_SIZE as u8), ArrayArg::new(&tile.handle, 16), config, ); @@ -100,7 +100,7 @@ pub fn 
write_to_output_over_width_unit_test(device: &R::Device) { &R::client(device), cube_count, cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape), + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, TILE_SIZE as u8), ArrayArg::new(&tile.handle, 16), config, ); @@ -129,7 +129,7 @@ pub fn write_to_output_vectorized_less_than_tile_unit_test(device: & &R::client(device), cube_count, cube_dim, - TensorArg::vectorized(vectorization as u8, &out.handle, &out.strides, &out.shape), + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, vectorization as u8), ArrayArg::new(&tile.handle, 16), config, ); @@ -160,7 +160,7 @@ pub fn write_to_output_scalar_unit_test(device: &R::Device) { &R::client(device), cube_count, cube_dim, - TensorArg::vectorized(vectorization as u8, &out.handle, &out.strides, &out.shape), + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, vectorization as u8), ArrayArg::new(&tile.handle, 16), config, ); @@ -191,7 +191,7 @@ pub fn write_to_output_scalar_out_of_bounds_cube_test(device: &R::De &R::client(device), cube_count, cube_dim, - TensorArg::vectorized(vectorization, &out.handle, &out.strides, &out.shape), + TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, vectorization), ArrayArg::new(&results.handle, 16), config, ); diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs index 3f966d249..8f29adf60 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs @@ -130,9 +130,9 @@ fn matmul_tiling_2d_ref_no_check( client, cube_count, cube_dim, - TensorArg::vectorized(lhs_vectorization, lhs.handle, lhs.strides, lhs.shape), - TensorArg::vectorized(rhs_vectorization, rhs.handle, rhs.strides, rhs.shape), - TensorArg::vectorized(out_vectorization, out.handle, out.strides, out.shape), + TensorArg::from_raw_parts(lhs.handle, lhs.strides, lhs.shape, lhs_vectorization), + TensorArg::from_raw_parts(rhs.handle, rhs.strides, rhs.shape, rhs_vectorization), + TensorArg::from_raw_parts(out.handle, out.strides, out.shape, out_vectorization), cube_config, ); } diff --git a/crates/cubecl-linalg/src/tensor/contiguous.rs b/crates/cubecl-linalg/src/tensor/contiguous.rs index 5f9513378..e26b3afa0 100644 --- a/crates/cubecl-linalg/src/tensor/contiguous.rs +++ b/crates/cubecl-linalg/src/tensor/contiguous.rs @@ -73,18 +73,8 @@ pub fn into_contiguous( client, cube_count, cube_dim, - TensorArg::vectorized( - vectorization_factor, - input.handle, - input.strides, - input.shape, - ), - TensorArg::vectorized( - vectorization_factor, - &output.handle, - &output.strides, - &output.shape, - ), + input.as_tensor_arg(vectorization_factor), + output.as_ref().as_tensor_arg(vectorization_factor), Some(UInt::new(rank as u32)), ); diff --git a/crates/cubecl-runtime/src/compute.rs b/crates/cubecl-runtime/src/base.rs similarity index 92% rename from crates/cubecl-runtime/src/compute.rs rename to crates/cubecl-runtime/src/base.rs index 9a35f5384..5ad3f1031 100644 --- a/crates/cubecl-runtime/src/compute.rs +++ b/crates/cubecl-runtime/src/base.rs @@ -8,6 +8,16 @@ pub struct ComputeRuntime { clients: spin::Mutex>>>, } +/// The kind of execution to be performed. +#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)] +pub enum ExecutionMode { + /// Checked kernels are safe. + #[default] + Checked, + /// Unchecked kernels are unsafe. 
+ Unchecked, +} + impl Default for ComputeRuntime where Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug, diff --git a/crates/cubecl-runtime/src/channel/base.rs b/crates/cubecl-runtime/src/channel/base.rs index b7510784e..ff3782bb3 100644 --- a/crates/cubecl-runtime/src/channel/base.rs +++ b/crates/cubecl-runtime/src/channel/base.rs @@ -1,20 +1,11 @@ use crate::{ server::{Binding, ComputeServer, Handle}, storage::ComputeStorage, + ExecutionMode, }; use alloc::vec::Vec; use cubecl_common::{reader::Reader, sync_type::SyncType}; -/// The kind of execution to be performed. -#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)] -pub enum KernelExecutionStrategy { - /// Checked kernels are safe. - #[default] - Checked, - /// Unchecked kernels are unsafe. - Unchecked, -} - /// The ComputeChannel trait links the ComputeClient to the ComputeServer /// while ensuring thread-safety pub trait ComputeChannel: Clone + core::fmt::Debug + Send + Sync { @@ -34,12 +25,12 @@ pub trait ComputeChannel: Clone + core::fmt::Debug + Send fn empty(&self, size: usize) -> Handle; /// Executes the `kernel` over the given `bindings`. - fn execute( + unsafe fn execute( &self, kernel: Server::Kernel, count: Server::DispatchOptions, bindings: Vec>, - strategy: KernelExecutionStrategy, + mode: ExecutionMode, ); /// Perform some synchronization of commands on the server. diff --git a/crates/cubecl-runtime/src/channel/cell.rs b/crates/cubecl-runtime/src/channel/cell.rs index 897408c31..e9178ac30 100644 --- a/crates/cubecl-runtime/src/channel/cell.rs +++ b/crates/cubecl-runtime/src/channel/cell.rs @@ -1,6 +1,7 @@ -use super::{ComputeChannel, KernelExecutionStrategy}; +use super::ComputeChannel; use crate::server::{Binding, ComputeServer, Handle}; use crate::storage::ComputeStorage; +use crate::ExecutionMode; use alloc::sync::Arc; use alloc::vec::Vec; use cubecl_common::reader::Reader; @@ -63,12 +64,12 @@ where self.server.borrow_mut().empty(size) } - fn execute( + unsafe fn execute( &self, kernel_description: Server::Kernel, count: Server::DispatchOptions, bindings: Vec>, - kind: KernelExecutionStrategy, + kind: ExecutionMode, ) { self.server .borrow_mut() diff --git a/crates/cubecl-runtime/src/channel/mpsc.rs b/crates/cubecl-runtime/src/channel/mpsc.rs index a131e1d9e..3488635dc 100644 --- a/crates/cubecl-runtime/src/channel/mpsc.rs +++ b/crates/cubecl-runtime/src/channel/mpsc.rs @@ -1,10 +1,11 @@ use cubecl_common::{reader::Reader, sync_type::SyncType}; use std::{sync::Arc, thread}; -use super::{ComputeChannel, KernelExecutionStrategy}; +use super::ComputeChannel; use crate::{ server::{Binding, ComputeServer, Handle}, storage::ComputeStorage, + ExecutionMode, }; /// Create a channel using a [multi-producer, single-consumer channel to communicate with @@ -40,11 +41,7 @@ where Create(Vec, Callback>), Empty(usize, Callback>), ExecuteKernel( - ( - Server::Kernel, - Server::DispatchOptions, - KernelExecutionStrategy, - ), + (Server::Kernel, Server::DispatchOptions, ExecutionMode), Vec>, ), Sync(SyncType, Callback<()>), @@ -80,9 +77,9 @@ where let handle = server.empty(size); callback.send(handle).await.unwrap(); } - Message::ExecuteKernel(kernel, bindings) => { + Message::ExecuteKernel(kernel, bindings) => unsafe { server.execute(kernel.0, kernel.1, bindings, kernel.2); - } + }, Message::Sync(sync_type, callback) => { server.sync(sync_type); callback.send(()).await.unwrap(); @@ -155,12 +152,12 @@ where handle_response(response.recv_blocking()) } - fn execute( + unsafe fn execute( &self, kernel: 
Server::Kernel, count: Server::DispatchOptions, bindings: Vec>, - kind: KernelExecutionStrategy, + kind: ExecutionMode, ) { self.state .sender diff --git a/crates/cubecl-runtime/src/channel/mutex.rs b/crates/cubecl-runtime/src/channel/mutex.rs index e63fd9a53..f20fce89a 100644 --- a/crates/cubecl-runtime/src/channel/mutex.rs +++ b/crates/cubecl-runtime/src/channel/mutex.rs @@ -1,6 +1,7 @@ -use super::{ComputeChannel, KernelExecutionStrategy}; +use super::ComputeChannel; use crate::server::{Binding, ComputeServer, Handle}; use crate::storage::ComputeStorage; +use crate::ExecutionMode; use alloc::sync::Arc; use alloc::vec::Vec; use cubecl_common::reader::Reader; @@ -56,12 +57,12 @@ where self.server.lock().empty(size) } - fn execute( + unsafe fn execute( &self, kernel: Server::Kernel, count: Server::DispatchOptions, handles: Vec>, - kind: KernelExecutionStrategy, + kind: ExecutionMode, ) { self.server.lock().execute(kernel, count, handles, kind) } diff --git a/crates/cubecl-runtime/src/client.rs b/crates/cubecl-runtime/src/client.rs index db3693564..5ff6a92d3 100644 --- a/crates/cubecl-runtime/src/client.rs +++ b/crates/cubecl-runtime/src/client.rs @@ -1,7 +1,8 @@ use crate::{ - channel::{ComputeChannel, KernelExecutionStrategy}, + channel::ComputeChannel, server::{Binding, ComputeServer, Handle}, storage::ComputeStorage, + ExecutionMode, }; use alloc::sync::Arc; use alloc::vec::Vec; @@ -77,8 +78,10 @@ where count: Server::DispatchOptions, bindings: Vec>, ) { - self.channel - .execute(kernel, count, bindings, KernelExecutionStrategy::Checked) + unsafe { + self.channel + .execute(kernel, count, bindings, ExecutionMode::Checked) + } } /// Executes the `kernel` over the given `bindings` without performing any bound checks. @@ -89,7 +92,7 @@ where bindings: Vec>, ) { self.channel - .execute(kernel, count, bindings, KernelExecutionStrategy::Unchecked) + .execute(kernel, count, bindings, ExecutionMode::Unchecked) } /// Wait for the completion of every task in the server. diff --git a/crates/cubecl-runtime/src/lib.rs b/crates/cubecl-runtime/src/lib.rs index 307021bf5..dd278706a 100644 --- a/crates/cubecl-runtime/src/lib.rs +++ b/crates/cubecl-runtime/src/lib.rs @@ -25,8 +25,8 @@ pub mod server; /// Compute Storage module. pub mod storage; -mod compute; -pub use compute::*; +mod base; +pub use base::*; pub use cubecl_common::benchmark; /// Debugging utilities. diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs index 5597c48ef..45f78b664 100644 --- a/crates/cubecl-runtime/src/server.rs +++ b/crates/cubecl-runtime/src/server.rs @@ -1,7 +1,7 @@ use crate::{ - channel::KernelExecutionStrategy, memory_management::{MemoryHandle, MemoryManagement}, storage::ComputeStorage, + ExecutionMode, }; use alloc::vec::Vec; use core::fmt::Debug; @@ -45,12 +45,12 @@ where /// /// Kernels have mutable access to every resource they are given /// and are responsible of determining which should be read or written. - fn execute( + unsafe fn execute( &mut self, kernel: Self::Kernel, count: Self::DispatchOptions, bindings: Vec>, - kind: KernelExecutionStrategy, + kind: ExecutionMode, ); /// Wait for the completion of every task in the server. 
diff --git a/crates/cubecl-runtime/tests/dummy/server.rs b/crates/cubecl-runtime/tests/dummy/server.rs index 5daf5280e..2f5ade31f 100644 --- a/crates/cubecl-runtime/tests/dummy/server.rs +++ b/crates/cubecl-runtime/tests/dummy/server.rs @@ -2,10 +2,10 @@ use std::sync::Arc; use cubecl_common::{reader::reader_from_concrete, sync_type::SyncType}; use cubecl_runtime::{ - channel::KernelExecutionStrategy, memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement}, server::{Binding, ComputeServer, Handle}, storage::{BytesResource, BytesStorage}, + ExecutionMode, }; use derive_new::new; @@ -54,12 +54,12 @@ where Handle::new(self.memory_management.reserve(size, || {})) } - fn execute( + unsafe fn execute( &mut self, kernel: Self::Kernel, _count: Self::DispatchOptions, bindings: Vec>, - _strategy: KernelExecutionStrategy, + _mode: ExecutionMode, ) { let mut resources = bindings .into_iter() diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index c50e1964f..c3547ce43 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -1,8 +1,8 @@ use super::{shader::ComputeShader, Item, SharedMemory}; use super::{LocalArray, Subgroup}; use crate::compiler::wgsl; -use cubecl_core::channel::KernelExecutionStrategy; use cubecl_core::ir as cube; +use cubecl_runtime::ExecutionMode; /// Wgsl Compiler. #[derive(Clone, Default)] @@ -34,10 +34,7 @@ impl core::fmt::Debug for WgslCompiler { impl cubecl_core::Compiler for WgslCompiler { type Representation = ComputeShader; - fn compile( - shader: cube::KernelDefinition, - _strategy: KernelExecutionStrategy, - ) -> Self::Representation { + fn compile(shader: cube::KernelDefinition, _mode: ExecutionMode) -> Self::Representation { let mut compiler = Self::default(); compiler.compile_shader(shader) } diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index e4a639d22..4d908d99b 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -3,13 +3,12 @@ use std::num::NonZeroU64; use super::WgpuStorage; use alloc::{borrow::Cow, sync::Arc}; use cubecl_common::{reader::Reader, sync_type::SyncType}; -use cubecl_core::{ - channel::KernelExecutionStrategy, compute::DebugInformation, prelude::*, FeatureSet, KernelId, -}; +use cubecl_core::{compute::DebugInformation, prelude::*, FeatureSet, KernelId}; use cubecl_runtime::{ debug::DebugLogger, memory_management::MemoryManagement, server::{self, ComputeServer}, + ExecutionMode, }; use hashbrown::HashMap; use wgpu::{ @@ -103,41 +102,35 @@ where fn pipeline( &mut self, kernel: ::Kernel, - strategy: KernelExecutionStrategy, + mode: ExecutionMode, ) -> Arc { let mut kernel_id = kernel.id(); - kernel_id.kind(strategy); + kernel_id.mode(mode); if let Some(pipeline) = self.pipelines.get(&kernel_id) { return pipeline.clone(); } - let mut compile = kernel.compile(strategy); + let mut compile = kernel.compile(mode); if self.logger.is_activated() { compile.debug_info = Some(DebugInformation::new("wgsl", kernel_id.clone())); } let compile = self.logger.debug(compile); - let pipeline = self.compile_source(&compile.source, strategy); + let pipeline = self.compile_source(&compile.source, mode); self.pipelines.insert(kernel_id.clone(), pipeline.clone()); pipeline } - fn compile_source( - &self, - source: &str, - strategy: KernelExecutionStrategy, - ) -> Arc { - let module = match strategy { - 
KernelExecutionStrategy::Checked => { - self.device.create_shader_module(ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }) - } - KernelExecutionStrategy::Unchecked => unsafe { + fn compile_source(&self, source: &str, mode: ExecutionMode) -> Arc { + let module = match mode { + ExecutionMode::Checked => self.device.create_shader_module(ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), + }), + ExecutionMode::Unchecked => unsafe { self.device .create_shader_module_unchecked(ShaderModuleDescriptor { label: None, @@ -305,14 +298,14 @@ where })) } - fn execute( + unsafe fn execute( &mut self, kernel: Self::Kernel, count: Self::DispatchOptions, bindings: Vec>, - kind: KernelExecutionStrategy, + mode: ExecutionMode, ) { - let pipeline = self.pipeline(kernel, kind); + let pipeline = self.pipeline(kernel, mode); let group_layout = pipeline.get_bind_group_layout(0); let memory_handles = bindings diff --git a/crates/cubecl/benches/unary.rs b/crates/cubecl/benches/unary.rs index 17ed0148e..869632cd9 100644 --- a/crates/cubecl/benches/unary.rs +++ b/crates/cubecl/benches/unary.rs @@ -44,9 +44,9 @@ impl Benchmark for UnaryBench { &self.client, cube_count, cube_dim, - TensorArg::vectorized(self.vectorization, &lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::vectorized(self.vectorization, &rhs.handle, &rhs.strides, &rhs.shape), - TensorArg::vectorized(self.vectorization, &out.handle, &out.strides, &out.shape), + TensorArg::from_raw_parts(self.vectorization, &lhs.handle, &lhs.strides, &lhs.shape), + TensorArg::from_raw_parts(self.vectorization, &rhs.handle, &rhs.strides, &rhs.shape), + TensorArg::from_raw_parts(self.vectorization, &out.handle, &out.strides, &out.shape), ) } diff --git a/crates/cubecl/src/runtime.rs b/crates/cubecl/src/runtime.rs deleted file mode 100644 index e69de29bb..000000000 From 03a9134a264593849837d03ba31084531f3ba6c9 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 5 Aug 2024 17:41:52 -0400 Subject: [PATCH 6/9] Fix --- crates/cubecl-linalg/src/tensor/base.rs | 9 +++++++++ crates/cubecl/benches/unary.rs | 6 +++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/crates/cubecl-linalg/src/tensor/base.rs b/crates/cubecl-linalg/src/tensor/base.rs index 0eb5b1cfb..75dd038f3 100644 --- a/crates/cubecl-linalg/src/tensor/base.rs +++ b/crates/cubecl-linalg/src/tensor/base.rs @@ -91,6 +91,15 @@ where } } + /// Return the reference to a tensor argument. 
+ pub fn as_arg<'a>(&'a self, vectorisation: u8) -> TensorArg<'a, R> { + let handle: TensorHandleRef<'a, R> = self.as_ref(); + + unsafe { + TensorArg::from_raw_parts(handle.handle, handle.strides, handle.shape, vectorisation) + } + } + fn contiguous_strides(shape: &[usize]) -> Vec { let mut strides = Vec::with_capacity(shape.len()); diff --git a/crates/cubecl/benches/unary.rs b/crates/cubecl/benches/unary.rs index 869632cd9..99ab027d2 100644 --- a/crates/cubecl/benches/unary.rs +++ b/crates/cubecl/benches/unary.rs @@ -44,9 +44,9 @@ impl Benchmark for UnaryBench { &self.client, cube_count, cube_dim, - TensorArg::from_raw_parts(self.vectorization, &lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::from_raw_parts(self.vectorization, &rhs.handle, &rhs.strides, &rhs.shape), - TensorArg::from_raw_parts(self.vectorization, &out.handle, &out.strides, &out.shape), + lhs.as_arg(self.vectorization), + rhs.as_arg(self.vectorization), + out.as_arg(self.vectorization), ) } From 3f04aeaa8f4a21d57fef65c3852b3b66f0bfda85 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 5 Aug 2024 19:10:05 -0400 Subject: [PATCH 7/9] Other cast --- crates/cubecl-cuda/src/compiler/binary.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/cubecl-cuda/src/compiler/binary.rs b/crates/cubecl-cuda/src/compiler/binary.rs index bd73c2bc2..34f0fb3c4 100644 --- a/crates/cubecl-cuda/src/compiler/binary.rs +++ b/crates/cubecl-cuda/src/compiler/binary.rs @@ -169,13 +169,13 @@ impl Binary for IndexAssign { _ => panic!("Invalid"), }; if cast { - f.write_fmt(format_args!("{var}.{char} = ({}){};\n", elem, rhs.index(i)))?; + f.write_fmt(format_args!("{var}.{char} = {}({});\n", elem, rhs.index(i)))?; } else { f.write_fmt(format_args!("{var}.{char} = {};\n", rhs.index(i)))?; } } else { if cast { - f.write_fmt(format_args!("{var}.i_{i} = ({}){};\n", elem, rhs.index(i)))?; + f.write_fmt(format_args!("{var}.i_{i} = {}({});\n", elem, rhs.index(i)))?; } else { f.write_fmt(format_args!("{var}.i_{i} = {};\n", rhs.index(i)))?; } @@ -193,7 +193,7 @@ impl Binary for IndexAssign { if item_out.vectorization > 1 { format_vec(f, true)?; } else { - f.write_fmt(format_args!("{out}[{lhs}] = ({elem}){rhs};\n"))?; + f.write_fmt(format_args!("{out}[{lhs}] = {elem}({rhs});\n"))?; } Ok(()) } else { @@ -307,7 +307,7 @@ impl Binary for Index { if item_out.vectorization > 1 { format_vec(f)?; } else { - f.write_fmt(format_args!("{out} = ({elem}){lhs}[{rhs}];\n"))?; + f.write_fmt(format_args!("{out} = {elem}({lhs}[{rhs}]);\n"))?; } Ok(()) } else { From fdc3aefeb1bd04f8e3f302c617157af81f57c54a Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 7 Aug 2024 09:29:16 -0400 Subject: [PATCH 8/9] Cleanup --- crates/cubecl-core/src/compute/launcher.rs | 4 +++ .../src/frontend/element/tensor.rs | 4 +-- crates/cubecl-cuda/src/compiler/base.rs | 32 +++++++++---------- crates/cubecl-cuda/src/compiler/binary.rs | 8 ++--- crates/cubecl-runtime/src/channel/base.rs | 4 +++ crates/cubecl-runtime/src/client.rs | 4 +++ crates/cubecl-runtime/src/server.rs | 4 +++ 7 files changed, 36 insertions(+), 24 deletions(-) diff --git a/crates/cubecl-core/src/compute/launcher.rs b/crates/cubecl-core/src/compute/launcher.rs index 66ed56b71..3038a007a 100644 --- a/crates/cubecl-core/src/compute/launcher.rs +++ b/crates/cubecl-core/src/compute/launcher.rs @@ -90,6 +90,10 @@ impl KernelLauncher { } /// Launch the kernel without check bounds. + /// + /// # Safety + /// + /// Out-of-bounds reads and writes can happen. 
pub unsafe fn launch_unchecked( self, cube_count: CubeCount, diff --git a/crates/cubecl-core/src/frontend/element/tensor.rs b/crates/cubecl-core/src/frontend/element/tensor.rs index 02600310e..9ffce8e6d 100644 --- a/crates/cubecl-core/src/frontend/element/tensor.rs +++ b/crates/cubecl-core/src/frontend/element/tensor.rs @@ -68,7 +68,7 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> { /// # Safety /// /// If you provide wrong strides or shapes, it might create undefined behavior caused by - /// out of bound reads and writes. + /// out-of-bounds reads and writes. pub unsafe fn from_raw_parts( handle: &'a cubecl_runtime::server::Handle, strides: &'a [usize], @@ -104,7 +104,7 @@ impl<'a, R: Runtime> TensorArg<'a, R> { /// # Safety /// /// If you provide wrong strides or shapes, it might create undefined behavior caused by - /// out of bound reads and writes. + /// out-of-bound reads and writes. pub unsafe fn from_raw_parts( handle: &'a cubecl_runtime::server::Handle, strides: &'a [usize], diff --git a/crates/cubecl-cuda/src/compiler/base.rs b/crates/cubecl-cuda/src/compiler/base.rs index 403a0a0e3..eada5bddc 100644 --- a/crates/cubecl-cuda/src/compiler/base.rs +++ b/crates/cubecl-cuda/src/compiler/base.rs @@ -36,8 +36,10 @@ impl Compiler for CudaCompiler { kernel: cubecl_core::ir::KernelDefinition, strategy: ExecutionMode, ) -> Self::Representation { - let mut compiler = Self::default(); - compiler.strategy = strategy; + let compiler = Self { + strategy, + ..Self::default() + }; compiler.compile_shader(kernel) } @@ -349,13 +351,7 @@ impl CudaCompiler { }), gpu::Operator::Index(op) => { if let ExecutionMode::Checked = self.strategy { - let has_len = match op.lhs { - gpu::Variable::GlobalInputArray { .. } => true, - gpu::Variable::GlobalOutputArray { .. } => true, - gpu::Variable::Slice { .. } => true, - _ => false, - }; - if has_len { + if has_length(&op.lhs) { self.compile_procedure( instructions, gpu::Procedure::CheckedIndex(gpu::CheckedIndex { @@ -377,14 +373,7 @@ impl CudaCompiler { } gpu::Operator::IndexAssign(op) => { if let ExecutionMode::Checked = self.strategy { - let has_len = match op.out { - gpu::Variable::GlobalInputArray { .. } => true, - gpu::Variable::GlobalOutputArray { .. } => true, - gpu::Variable::Slice { .. } => true, - _ => false, - }; - - if has_len { + if has_length(&op.out) { self.compile_procedure( instructions, gpu::Procedure::CheckedIndexAssign(gpu::CheckedIndexAssign { @@ -679,3 +668,12 @@ impl CudaCompiler { } } } + +fn has_length(var: &gpu::Variable) -> bool { + matches!( + var, + gpu::Variable::GlobalInputArray { .. } + | gpu::Variable::GlobalOutputArray { .. } + | gpu::Variable::Slice { .. 
} + ) +} diff --git a/crates/cubecl-cuda/src/compiler/binary.rs b/crates/cubecl-cuda/src/compiler/binary.rs index 34f0fb3c4..44d5fae63 100644 --- a/crates/cubecl-cuda/src/compiler/binary.rs +++ b/crates/cubecl-cuda/src/compiler/binary.rs @@ -173,12 +173,10 @@ impl Binary for IndexAssign { } else { f.write_fmt(format_args!("{var}.{char} = {};\n", rhs.index(i)))?; } + } else if cast { + f.write_fmt(format_args!("{var}.i_{i} = {}({});\n", elem, rhs.index(i)))?; } else { - if cast { - f.write_fmt(format_args!("{var}.i_{i} = {}({});\n", elem, rhs.index(i)))?; - } else { - f.write_fmt(format_args!("{var}.i_{i} = {};\n", rhs.index(i)))?; - } + f.write_fmt(format_args!("{var}.i_{i} = {};\n", rhs.index(i)))?; } } f.write_fmt(format_args!("{out}[{lhs}] = {var};\n"))?; diff --git a/crates/cubecl-runtime/src/channel/base.rs b/crates/cubecl-runtime/src/channel/base.rs index ff3782bb3..c1fb44a22 100644 --- a/crates/cubecl-runtime/src/channel/base.rs +++ b/crates/cubecl-runtime/src/channel/base.rs @@ -25,6 +25,10 @@ pub trait ComputeChannel: Clone + core::fmt::Debug + Send fn empty(&self, size: usize) -> Handle; /// Executes the `kernel` over the given `bindings`. + /// + /// # Safety + /// + /// When executing with mode [ExecutionMode::Unchecked], out-of-bounds reads and writes can happen. unsafe fn execute( &self, kernel: Server::Kernel, diff --git a/crates/cubecl-runtime/src/client.rs b/crates/cubecl-runtime/src/client.rs index 5ff6a92d3..25cf9b7b7 100644 --- a/crates/cubecl-runtime/src/client.rs +++ b/crates/cubecl-runtime/src/client.rs @@ -85,6 +85,10 @@ where } /// Executes the `kernel` over the given `bindings` without performing any bound checks. + /// + /// # Safety + /// + /// Without bound checks, out-of-bounds reads and writes can happen. pub unsafe fn execute_unchecked( &self, kernel: Server::Kernel, diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs index 45f78b664..dbc104f5e 100644 --- a/crates/cubecl-runtime/src/server.rs +++ b/crates/cubecl-runtime/src/server.rs @@ -45,6 +45,10 @@ where /// /// Kernels have mutable access to every resource they are given /// and are responsible of determining which should be read or written. + /// + /// # Safety + /// + /// When executing with mode [ExecutionMode::Unchecked], out-of-bounds reads and writes can happen. 
unsafe fn execute( &mut self, kernel: Self::Kernel, From 0e5a6bb4c811bf6ead853ec84c4445e85f795204 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 8 Aug 2024 14:41:51 -0400 Subject: [PATCH 9/9] Migrate Array --- README.md | 24 +++++----- crates/cubecl-core/src/compute/kernel.rs | 8 ++-- .../cubecl-core/src/frontend/element/array.rs | 31 ++++++------ .../cubecl-core/src/runtime_tests/assign.rs | 2 +- crates/cubecl-core/src/runtime_tests/cmma.rs | 18 +++---- .../cubecl-core/src/runtime_tests/launch.rs | 4 +- .../cubecl-core/src/runtime_tests/sequence.rs | 4 +- crates/cubecl-core/src/runtime_tests/slice.rs | 48 +++++++++++-------- .../cubecl-core/src/runtime_tests/topology.rs | 16 ++++--- .../src/matmul/tests/cmma/compute_loop.rs | 6 +-- .../matmul/tests/cmma/load_shared_memory.rs | 26 +++++----- .../src/matmul/tests/cmma/write_output.rs | 14 +++--- .../src/matmul/tests/tiling2d/compute_loop.rs | 16 +++---- .../tests/tiling2d/load_shared_memory.rs | 22 ++++----- .../src/matmul/tests/tiling2d/write_output.rs | 10 ++-- crates/cubecl-linalg/src/tensor/base.rs | 16 ++++--- examples/gelu/src/lib.rs | 18 +++---- 17 files changed, 151 insertions(+), 132 deletions(-) diff --git a/README.md b/README.md index dcab44293..3ea54fbce 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Simply annotate functions with the `cube` attribute to indicate that they should ```rust use cubecl::prelude::*; -#[cube(launch)] +#[cube(launch_unchecked)] fn gelu_array(input: &Array, output: &mut Array) { if ABSOLUTE_POS < input.len() { output[ABSOLUTE_POS] = gelu_scalar::(input[ABSOLUTE_POS]); @@ -38,24 +38,26 @@ fn gelu_array(input: &Array, output: &mut Array) { fn gelu_scalar(x: F) -> F { x * (F::erf(x / F::sqrt(2.0.into())) + 1.0) / 2.0 } - ``` -You can then launch the kernel using the autogenerated `gelu_array::launch` function. +You can then launch the kernel using the autogenerated `gelu_array::launch_unchecked` function. 
```rust fn launch(device: &R::Device) { let client = R::client(device); let input = &[-1., 0., 1., 5.]; let output_handle = client.empty(input.len() * core::mem::size_of::()); - - gelu_array::launch::( - client.clone(), - CubeCount::Static(1, 1, 1), - CubeDim::new(input.len() as u32, 1, 1), - ArrayArg::new(&client.create(f32::as_bytes(input)), input.len()), - ArrayArg::new(&output_handle, input.len()), - ); + let input_handle = client.create(f32::as_bytes(input)); + + unsafe { + gelu_array::launch_unchecked::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(input.len() as u32, 1, 1), + ArrayArg::from_raw_parts(&input_handle, input.len(), 1), + ArrayArg::from_raw_parts(&output_handle, input.len(), 1), + ) + }; let bytes = client.read(output_handle.binding()); let output = f32::from_bytes(&bytes); diff --git a/crates/cubecl-core/src/compute/kernel.rs b/crates/cubecl-core/src/compute/kernel.rs index d0031b3a8..3e3175631 100644 --- a/crates/cubecl-core/src/compute/kernel.rs +++ b/crates/cubecl-core/src/compute/kernel.rs @@ -193,8 +193,8 @@ impl CubeTask for KernelTask { } impl CubeTask for Arc { - fn compile(&self, kind: ExecutionMode) -> CompiledKernel { - self.as_ref().compile(kind) + fn compile(&self, mode: ExecutionMode) -> CompiledKernel { + self.as_ref().compile(mode) } fn id(&self) -> KernelId { @@ -203,8 +203,8 @@ impl CubeTask for Arc { } impl CubeTask for Box { - fn compile(&self, kind: ExecutionMode) -> CompiledKernel { - self.as_ref().compile(kind) + fn compile(&self, mode: ExecutionMode) -> CompiledKernel { + self.as_ref().compile(mode) } fn id(&self) -> KernelId { diff --git a/crates/cubecl-core/src/frontend/element/array.rs b/crates/cubecl-core/src/frontend/element/array.rs index b5e1a36fb..d3cad4bde 100644 --- a/crates/cubecl-core/src/frontend/element/array.rs +++ b/crates/cubecl-core/src/frontend/element/array.rs @@ -144,7 +144,7 @@ impl LaunchArgExpand for Array { /// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle). pub struct ArrayHandleRef<'a, R: Runtime> { pub handle: &'a cubecl_runtime::server::Handle, - pub length: [usize; 1], + pub(crate) length: [usize; 1], } pub enum ArrayArg<'a, R: Runtime> { @@ -205,35 +205,38 @@ impl<'a, R: Runtime> ArgSettings for ArrayArg<'a, R> { impl<'a, R: Runtime> ArrayArg<'a, R> { /// Create a new array argument. /// - /// Equivalent to using the [vectorized constructor](Self::vectorized) with a vectorization - /// factor of 1. - pub fn new(handle: &'a cubecl_runtime::server::Handle, length: usize) -> Self { - ArrayArg::Handle { - handle: ArrayHandleRef::new(handle, length), - vectorization_factor: 1, - } - } - /// Create a new array argument specified with its vectorization factor. - pub fn vectorized( - vectorization_factor: u8, + /// # Safety + /// + /// Specifying the wrong length may lead to out-of-bounds reads and writes. + pub unsafe fn from_raw_parts( handle: &'a cubecl_runtime::server::Handle, length: usize, + vectorization_factor: u8, ) -> Self { ArrayArg::Handle { - handle: ArrayHandleRef::new(handle, length), + handle: ArrayHandleRef::from_raw_parts(handle, length), vectorization_factor, } } } impl<'a, R: Runtime> ArrayHandleRef<'a, R> { - pub fn new(handle: &'a cubecl_runtime::server::Handle, length: usize) -> Self { + /// Create a new array handle reference. + /// + /// # Safety + /// + /// Specifying the wrong length may lead to out-of-bounds reads and writes. 
+ pub unsafe fn from_raw_parts( + handle: &'a cubecl_runtime::server::Handle, + length: usize, + ) -> Self { Self { handle, length: [length], } } + /// Return the handle as a tensor instead of an array. pub fn as_tensor(&self) -> TensorHandleRef<'_, R> { let shape = &self.length; diff --git a/crates/cubecl-core/src/runtime_tests/assign.rs b/crates/cubecl-core/src/runtime_tests/assign.rs index 08dfd77f2..f9c81aae5 100644 --- a/crates/cubecl-core/src/runtime_tests/assign.rs +++ b/crates/cubecl-core/src/runtime_tests/assign.rs @@ -19,7 +19,7 @@ pub fn test_kernel_assign_scalar(client: ComputeClient(client: ComputeClient) { let rhs = client.create(f16::as_bytes(&rhs)); let out = client.empty(core::mem::size_of::() * 256); - kernel_simple_1::launch::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(16, 16, 1), - ArrayArg::new(&lhs, 256), - ArrayArg::new(&rhs, 256), - ArrayArg::new(&out, 256), - ); + unsafe { + kernel_simple_1::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(16, 16, 1), + ArrayArg::from_raw_parts(&lhs, 256, 1), + ArrayArg::from_raw_parts(&rhs, 256, 1), + ArrayArg::from_raw_parts(&out, 256, 1), + ) + }; let actual = client.read(out.binding()); let actual = f32::from_bytes(&actual); diff --git a/crates/cubecl-core/src/runtime_tests/launch.rs b/crates/cubecl-core/src/runtime_tests/launch.rs index 62bc50e2f..38c7d204c 100644 --- a/crates/cubecl-core/src/runtime_tests/launch.rs +++ b/crates/cubecl-core/src/runtime_tests/launch.rs @@ -23,7 +23,7 @@ pub fn test_kernel_with_generics(client: ComputeClient(client: ComputeClient(client: ComputeClient(client: ComputeClient(client: ComputeClient()); - slice_select::launch::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(1, 1, 1), - ArrayArg::new(&input, 5), - ArrayArg::new(&output, 1), - ); + unsafe { + slice_select::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(1, 1, 1), + ArrayArg::from_raw_parts(&input, 5, 1), + ArrayArg::from_raw_parts(&output, 1, 1), + ) + }; let actual = client.read(output.binding()); let actual = f32::from_bytes(&actual); @@ -48,13 +50,15 @@ pub fn test_slice_len(client: ComputeClient) let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0])); let output = client.empty(core::mem::size_of::()); - slice_len::launch::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(1, 1, 1), - ArrayArg::new(&input, 5), - ArrayArg::new(&output, 1), - ); + unsafe { + slice_len::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(1, 1, 1), + ArrayArg::from_raw_parts(&input, 5, 1), + ArrayArg::from_raw_parts(&output, 1, 1), + ) + }; let actual = client.read(output.binding()); let actual = u32::from_bytes(&actual); @@ -66,13 +70,15 @@ pub fn test_slice_assign(client: ComputeClient( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(1, 1, 1), - ArrayArg::new(&input, 5), - ArrayArg::new(&output, 1), - ); + unsafe { + slice_assign::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(1, 1, 1), + ArrayArg::from_raw_parts(&input, 5, 1), + ArrayArg::from_raw_parts(&output, 1, 1), + ) + }; let actual = client.read(output.binding()); let actual = f32::from_bytes(&actual); diff --git a/crates/cubecl-core/src/runtime_tests/topology.rs b/crates/cubecl-core/src/runtime_tests/topology.rs index 1b2df08e3..cc9d687e7 100644 --- a/crates/cubecl-core/src/runtime_tests/topology.rs +++ b/crates/cubecl-core/src/runtime_tests/topology.rs @@ -22,13 +22,15 @@ pub fn test_kernel_topology_absolute_pos(client: ComputeClient()); let handle2 = client.empty(length 
as usize * core::mem::size_of::()); - kernel_absolute_pos::launch::( - &client, - CubeCount::Static(cube_count.0, cube_count.1, cube_count.2), - CubeDim::new(cube_dim.0, cube_dim.1, cube_dim.2), - ArrayArg::new(&handle1, length as usize), - ArrayArg::new(&handle2, length as usize), - ); + unsafe { + kernel_absolute_pos::launch::( + &client, + CubeCount::Static(cube_count.0, cube_count.1, cube_count.2), + CubeDim::new(cube_dim.0, cube_dim.1, cube_dim.2), + ArrayArg::from_raw_parts(&handle1, length as usize, 1), + ArrayArg::from_raw_parts(&handle2, length as usize, 1), + ) + }; let actual = client.read(handle1.binding()); let actual = u32::from_bytes(&actual); diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs index 40adcd380..5c208cc0b 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs @@ -91,7 +91,7 @@ pub fn compute_loop_k_test(device: &R::Device) { cube_dim, TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), - ArrayArg::new(&results, m * n), + ArrayArg::from_raw_parts(&results, m * n, 1), UInt::new(m as u32), UInt::new(k as u32), UInt::new(n as u32), @@ -169,7 +169,7 @@ pub fn compute_loop_warp_test(device: &R::Device) { cube_dim, TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), - ArrayArg::new(&results, m * n), + ArrayArg::from_raw_parts(&results, m * n, 1), UInt::new(m as u32), UInt::new(k as u32), UInt::new(n as u32), @@ -276,7 +276,7 @@ pub fn cmma_compute_loop_two_warps_same_tile_row_test(device: &R::De cube_dim, TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), - ArrayArg::new(&results, m * n), + ArrayArg::from_raw_parts(&results, m * n, 1), UInt::new(m as u32), UInt::new(k as u32), UInt::new(n as u32), diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs index cf90bccda..33521c561 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs @@ -104,7 +104,7 @@ pub fn load_shared_memory_lhs_unit_test(device: &R::Device) { &lhs_tensor.shape, 4, ), - ArrayArg::new(&lhs_sm, 64 * 32), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), ScalarArg::new(0), ScalarArg::new(64), ScalarArg::new(64), @@ -163,7 +163,7 @@ pub fn load_shared_memory_rhs_unit_test(device: &R::Device) { &rhs_tensor.shape, 4, ), - ArrayArg::new(&rhs_sm, 64 * 32), + ArrayArg::from_raw_parts(&rhs_sm, 64 * 32, 1), ScalarArg::new(0), ScalarArg::new(64), ScalarArg::new(64), @@ -222,7 +222,7 @@ pub fn load_shared_memory_lhs_warp_test(device: &R::Device) { &lhs_tensor.shape, 4, ), - ArrayArg::new(&lhs_sm, 64 * 32), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), ScalarArg::new(0), ScalarArg::new(64), ScalarArg::new(64), @@ -286,7 +286,7 @@ pub fn load_shared_memory_lhs_vertical_out_of_bound_warp_test(device &lhs_tensor.shape, 4, ), - ArrayArg::new(&lhs_sm, 64 * 32), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), ScalarArg::new(0), ScalarArg::new(12), ScalarArg::new(64), @@ -348,7 +348,7 @@ pub fn load_shared_memory_lhs_horizontal_out_of_bound_warp_test(devi &lhs_tensor.shape, 4, ), - ArrayArg::new(&lhs_sm, 64 * 32), + 
ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), ScalarArg::new(0), ScalarArg::new(12), ScalarArg::new(12), @@ -410,7 +410,7 @@ pub fn load_shared_memory_lhs_whole_out_of_bound_warp_test(device: & &lhs_tensor.shape, 4, ), - ArrayArg::new(&lhs_sm, 64 * 32), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), ScalarArg::new(0), ScalarArg::new(12), ScalarArg::new(12), @@ -471,7 +471,7 @@ pub fn load_shared_memory_rhs_warp_test(device: &R::Device) { &rhs_tensor.shape, 4, ), - ArrayArg::new(&rhs_sm, 64 * 32), + ArrayArg::from_raw_parts(&rhs_sm, 64 * 32, 1), ScalarArg::new(0), ScalarArg::new(64), ScalarArg::new(64), @@ -535,7 +535,7 @@ pub fn load_shared_memory_lhs_second_warp_test(device: &R::Device) { &lhs_tensor.shape, 4, ), - ArrayArg::new(&lhs_sm, 64 * 32), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), ScalarArg::new(0), ScalarArg::new(64), ScalarArg::new(64), @@ -598,7 +598,7 @@ pub fn load_shared_memory_rhs_second_warp_test(device: &R::Device) { &rhs_tensor.shape, 4, ), - ArrayArg::new(&rhs_sm, 64 * 32), + ArrayArg::from_raw_parts(&rhs_sm, 64 * 32, 1), ScalarArg::new(0), ScalarArg::new(64), ScalarArg::new(64), @@ -664,7 +664,7 @@ pub fn load_shared_memory_lhs_third_warp_test(device: &R::Device) { &lhs_tensor.shape, 4, ), - ArrayArg::new(&lhs_sm, 64 * 32), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), ScalarArg::new(0), ScalarArg::new(64), ScalarArg::new(64), @@ -730,7 +730,7 @@ pub fn load_shared_memory_rhs_third_warp_test(device: &R::Device) { &rhs_tensor.shape, 4, ), - ArrayArg::new(&rhs_sm, 64 * 32), + ArrayArg::from_raw_parts(&rhs_sm, 64 * 32, 1), ScalarArg::new(0), ScalarArg::new(64), ScalarArg::new(64), @@ -793,7 +793,7 @@ pub fn load_shared_memory_lhs_k_offset_test(device: &R::Device) { &lhs_tensor.shape, 4, ), - ArrayArg::new(&lhs_sm, 64 * 32), + ArrayArg::from_raw_parts(&lhs_sm, 64 * 32, 1), ScalarArg::new(32), ScalarArg::new(64), ScalarArg::new(64), @@ -856,7 +856,7 @@ pub fn load_shared_memory_rhs_k_offset_test(device: &R::Device) { &rhs_tensor.shape, 4, ), - ArrayArg::new(&rhs_sm, 64 * 32), + ArrayArg::from_raw_parts(&rhs_sm, 64 * 32, 1), ScalarArg::new(32), ScalarArg::new(64), ScalarArg::new(64), diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs index 2f090ac27..c9133eca6 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs @@ -66,7 +66,7 @@ pub fn cmma_write_output_unit_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), - ArrayArg::new(&acc_sm.handle, 64 * 64), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), ScalarArg::new(m as u32), ScalarArg::new(n as u32), config, @@ -134,7 +134,7 @@ pub fn cmma_write_output_warp_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), - ArrayArg::new(&acc_sm.handle, 64 * 64), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), ScalarArg::new(m as u32), ScalarArg::new(n as u32), config, @@ -212,7 +212,7 @@ pub fn cmma_write_output_warp_horizontal_out_of_bounds_test(device: cube_count, cube_dim, TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), - ArrayArg::new(&acc_sm.handle, 64 * 64), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), ScalarArg::new(m as u32), ScalarArg::new(n as u32), config, @@ -285,7 +285,7 @@ pub fn cmma_write_output_warp_vertical_out_of_bounds_test(device: &R cube_count, cube_dim, 
TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), - ArrayArg::new(&acc_sm.handle, 64 * 64), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), ScalarArg::new(m as u32), ScalarArg::new(n as u32), config, @@ -358,7 +358,7 @@ pub fn cmma_write_output_warp_whole_out_of_bounds_test(device: &R::D cube_count, cube_dim, TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), - ArrayArg::new(&acc_sm.handle, 64 * 64), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), ScalarArg::new(m as u32), ScalarArg::new(n as u32), config, @@ -427,7 +427,7 @@ pub fn cmma_write_output_second_warp_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), - ArrayArg::new(&acc_sm.handle, 64 * 64), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), ScalarArg::new(m as u32), ScalarArg::new(n as u32), config, @@ -545,7 +545,7 @@ pub fn cmma_write_output_third_fourth_warps_test(device: &R::Device) cube_count, cube_dim, TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, 4), - ArrayArg::new(&acc_sm.handle, 64 * 64), + ArrayArg::from_raw_parts(&acc_sm.handle, 64 * 64, 1), ScalarArg::new(m as u32), ScalarArg::new(n as u32), config, diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs index bcd7fda7e..7a3db32bd 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs @@ -55,9 +55,9 @@ pub fn tile_outer_product_vectorized_unit_test_2(device: &R::Device) &client, cube_count, cube_dim, - ArrayArg::new(®ister_m, 4), - ArrayArg::new(®ister_n, 4), - ArrayArg::new(&results, 16), + ArrayArg::from_raw_parts(®ister_m, 4, 1), + ArrayArg::from_raw_parts(®ister_n, 4, 1), + ArrayArg::from_raw_parts(&results, 16, 1), config, ); }; @@ -131,9 +131,9 @@ pub fn tile_outer_product_vectorized_unit_test(device: &R::Device) { &client, cube_count, cube_dim, - ArrayArg::new(®ister_m, 4), - ArrayArg::new(®ister_n, 4), - ArrayArg::new(&results, 16), + ArrayArg::from_raw_parts(®ister_m, 4, 1), + ArrayArg::from_raw_parts(®ister_n, 4, 1), + ArrayArg::from_raw_parts(&results, 16, 1), config, ); }; @@ -165,7 +165,7 @@ pub fn compute_loop_unit_test(device: &R::Device) { TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), ScalarArg::new(0), ScalarArg::new(0), - ArrayArg::new(&results, 16), + ArrayArg::from_raw_parts(&results, 16, 1), UInt::new(16), UInt::new(16), config, @@ -199,7 +199,7 @@ pub fn compute_loop_unit_offset_test(device: &R::Device) { TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), ScalarArg::new(4), ScalarArg::new(4), - ArrayArg::new(&results, 16), + ArrayArg::from_raw_parts(&results, 16, 1), UInt::new(8), UInt::new(8), config, diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs index a9a3e07dd..1ea5a3fdc 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs @@ -228,7 +228,7 @@ pub fn load_lhs_transposed_unit_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), ScalarArg::new(4), ScalarArg::new(4), 
ScalarArg::new(8), @@ -268,7 +268,7 @@ pub fn load_lhs_transposed_out_of_bounds_cube_test(device: &R::Devic &lhs.shape, vectorization_factor as u8, ), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), ScalarArg::new(0), config, true, @@ -300,7 +300,7 @@ pub fn load_lhs_transposed_cube_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), ScalarArg::new(0), config, true, @@ -333,7 +333,7 @@ pub fn load_lhs_transposed_offset_cube_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), ScalarArg::new(8), config, true, @@ -366,7 +366,7 @@ pub fn load_rhs_plain_unit_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), ScalarArg::new(4), ScalarArg::new(4), ScalarArg::new(8), @@ -400,7 +400,7 @@ pub fn load_rhs_plain_cube_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), ScalarArg::new(0), config, false, @@ -433,7 +433,7 @@ pub fn load_rhs_plain_cube_offset_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, TILE_SIZE as u8), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), ScalarArg::new(8), config, false, @@ -466,7 +466,7 @@ pub fn load_lhs_plain_unit_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), ScalarArg::new(4), ScalarArg::new(4), ScalarArg::new(8), @@ -501,7 +501,7 @@ pub fn load_lhs_plain_out_of_bounds_unit_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), ScalarArg::new(4), ScalarArg::new(4), ScalarArg::new(8), @@ -535,7 +535,7 @@ pub fn load_rhs_transposed_unit_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), ScalarArg::new(4), ScalarArg::new(4), ScalarArg::new(8), @@ -570,7 +570,7 @@ pub fn load_rhs_transposed_out_of_bounds_unit_test(device: &R::Devic cube_count, cube_dim, TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ArrayArg::from_raw_parts(&sm_out, 64, TILE_SIZE as u8), ScalarArg::new(4), ScalarArg::new(4), ScalarArg::new(8), diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs index 6d83ed888..41c2f2931 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs 
+++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs @@ -72,7 +72,7 @@ pub fn write_to_output_over_height_unit_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, TILE_SIZE as u8), - ArrayArg::new(&tile.handle, 16), + ArrayArg::from_raw_parts(&tile.handle, 16, 1), config, ); }; @@ -101,7 +101,7 @@ pub fn write_to_output_over_width_unit_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, TILE_SIZE as u8), - ArrayArg::new(&tile.handle, 16), + ArrayArg::from_raw_parts(&tile.handle, 16, 1), config, ); }; @@ -130,7 +130,7 @@ pub fn write_to_output_vectorized_less_than_tile_unit_test(device: & cube_count, cube_dim, TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, vectorization as u8), - ArrayArg::new(&tile.handle, 16), + ArrayArg::from_raw_parts(&tile.handle, 16, 1), config, ); }; @@ -161,7 +161,7 @@ pub fn write_to_output_scalar_unit_test(device: &R::Device) { cube_count, cube_dim, TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, vectorization as u8), - ArrayArg::new(&tile.handle, 16), + ArrayArg::from_raw_parts(&tile.handle, 16, 1), config, ); }; @@ -192,7 +192,7 @@ pub fn write_to_output_scalar_out_of_bounds_cube_test(device: &R::De cube_count, cube_dim, TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, vectorization), - ArrayArg::new(&results.handle, 16), + ArrayArg::from_raw_parts(&results.handle, 16, 1), config, ); }; diff --git a/crates/cubecl-linalg/src/tensor/base.rs b/crates/cubecl-linalg/src/tensor/base.rs index 75dd038f3..8d37e1bee 100644 --- a/crates/cubecl-linalg/src/tensor/base.rs +++ b/crates/cubecl-linalg/src/tensor/base.rs @@ -133,12 +133,14 @@ where cube_dim, ); - init::zeros_array::launch::( - client, - cube_count, - cube_dim, - ArrayArg::vectorized(vectorization_factor, &handle, num_elements), - ); + unsafe { + init::zeros_array::launch_unchecked::( + client, + cube_count, + cube_dim, + ArrayArg::from_raw_parts(&handle, num_elements, vectorization_factor), + ) + }; Self::new(shape, strides, handle) } @@ -148,7 +150,7 @@ pub(crate) mod init { use cubecl::prelude::*; use cubecl_core as cubecl; - #[cube(launch)] + #[cube(launch_unchecked)] pub fn zeros_array(output: &mut Array) { if ABSOLUTE_POS < output.len() { output[ABSOLUTE_POS] = C::from_int(0); diff --git a/examples/gelu/src/lib.rs b/examples/gelu/src/lib.rs index 79c834db4..80d76c949 100644 --- a/examples/gelu/src/lib.rs +++ b/examples/gelu/src/lib.rs @@ -1,6 +1,6 @@ use cubecl::prelude::*; -#[cube(launch)] +#[cube(launch_unchecked)] fn gelu_array(input: &Array, output: &mut Array) { if ABSOLUTE_POS < input.len() { output[ABSOLUTE_POS] = gelu_scalar::(input[ABSOLUTE_POS]); @@ -18,13 +18,15 @@ pub fn launch(device: &R::Device) { let output_handle = client.empty(input.len() * core::mem::size_of::()); let input_handle = client.create(f32::as_bytes(input)); - gelu_array::launch::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(input.len() as u32, 1, 1), - ArrayArg::new(&input_handle, input.len()), - ArrayArg::new(&output_handle, input.len()), - ); + unsafe { + gelu_array::launch_unchecked::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(input.len() as u32, 1, 1), + ArrayArg::from_raw_parts(&input_handle, input.len(), 1), + ArrayArg::from_raw_parts(&output_handle, input.len(), 1), + ) + }; let bytes = client.read(output_handle.binding()); let output = f32::from_bytes(&bytes);