diff --git a/crates/cubecl-linalg/src/matmul/cmma/launch.rs b/crates/cubecl-linalg/src/matmul/cmma/launch.rs index d2a19b6a3..ab1dd7c98 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/launch.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/launch.rs @@ -1,8 +1,10 @@ use std::cmp::max; use cubecl_core::{ - frontend::{Float, TensorArg, F16}, - Compiler, Runtime, + client::ComputeClient, + frontend::{Float, TensorArg, TensorHandleRef, F16}, + ir::{Elem, FloatKind}, + Compiler, Feature, Runtime, }; use crate::{ @@ -10,35 +12,105 @@ use crate::{ base::cmma_kernel, config::{cmma_cube_count, cmma_cube_dim, CmmaConfig, CmmaLaunchConfig}, }, - tensor::{MatrixLayout, TensorHandle}, + tensor::{matrix_layout, MatrixLayout, TensorHandle}, }; -/// Matrix multiplication using tiling 2d algorithm +/// Matrix multiplication using [cooperative matrix-multiply and accumulate operations](cubecl_core::cmma). pub fn matmul_cmma( + client: &ComputeClient, lhs: TensorHandle, rhs: TensorHandle, out: TensorHandle, - device: &R::Device, ) -> TensorHandle { - let rank = lhs.rank(); - let m = lhs.shape[rank - 2]; - let k = lhs.shape[rank - 1]; - let n = rhs.shape[rank - 1]; + matmul_cmma_ref::(client, lhs.as_ref(), rhs.as_ref(), out.as_ref()); + out +} - let client = R::client(device); +#[derive(Debug)] +pub enum UnavailabilityReason { + TransposedInput, // TODO: Support that case. + NotMultipleOf4, // TODO: Support that case. + HiglyPermutatedInput, + ShapeMemoryLimitBusted, + InvalidConfig(String), + CmmaInstructionsUnsupported, +} - let check_layout = |tensor: &TensorHandle| match tensor.matrix_layout() { - MatrixLayout::Contiguous => {} +/// Checks if the matmul cmma can be used. +pub fn check_cmma_availability( + client: &ComputeClient, + lhs: &TensorHandleRef<'_, R>, + rhs: &TensorHandleRef<'_, R>, + config: Option<&CmmaLaunchConfig>, +) -> Result<(), UnavailabilityReason> { + let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_layout(tensor.strides) { + MatrixLayout::Contiguous => Ok(()), MatrixLayout::MildlyPermuted { transposed: _, batch_swap: _, - } => panic!("Transposed input not supported yet."), - MatrixLayout::HighlyPermuted => { - panic!("Can't run on highly permuted tensor.") - } + } => Err(UnavailabilityReason::TransposedInput), + MatrixLayout::HighlyPermuted => Err(UnavailabilityReason::HiglyPermutatedInput), }; - check_layout(&lhs); - check_layout(&rhs); + + if !client.features().enabled(Feature::Cmma { + a: Elem::Float(FloatKind::F16), + b: Elem::Float(FloatKind::F16), + c: Elem::Float(FloatKind::F32), + m: 16, + k: 16, + n: 16, + }) { + return Err(UnavailabilityReason::CmmaInstructionsUnsupported); + } + + check_layout(lhs)?; + check_layout(rhs)?; + + let rank = lhs.shape.len(); + let m = lhs.shape[rank - 2]; + let k = lhs.shape[rank - 1]; + let n = rhs.shape[rank - 1]; + + if !(m % 4 == 0 && k % 4 == 0 && n % 4 == 0) { + return Err(UnavailabilityReason::NotMultipleOf4); + } + + if let Some(config) = config { + let (b_m, b_k, b_n) = ( + config.block_size_m, + config.block_size_k, + config.block_size_n, + ); + + if b_k * max(b_m, b_n) > ::max_shared_memory_size() { + return Err(UnavailabilityReason::ShapeMemoryLimitBusted); + } + + if b_m * b_n > ::max_shared_memory_size() { + return Err(UnavailabilityReason::ShapeMemoryLimitBusted); + } + + if b_k != 2 * config.tile_size { + return Err(UnavailabilityReason::InvalidConfig( + "Variable tile number per coop_units not supported".to_string(), + )); + } + } + + Ok(()) +} +/// Matrix multiplication using [cooperative matrix-multiply and 
accumulate operations](cubecl_core::cmma). +pub fn matmul_cmma_ref( + client: &ComputeClient, + lhs: TensorHandleRef<'_, R>, + rhs: TensorHandleRef<'_, R>, + out: TensorHandleRef<'_, R>, +) { + let rank = lhs.strides.len(); + + let m = lhs.shape[rank - 2]; + let k = lhs.shape[rank - 1]; + let n = rhs.shape[rank - 1]; let vectorization = |shape: usize| { [4, 2] @@ -53,41 +125,17 @@ pub fn matmul_cmma( let rhs_vectorization = vectorization(n); let out_vectorization = vectorization(n); - let cube_count = cmma_cube_count::(&out.shape, 64, 64); + let cube_count = cmma_cube_count::(out.shape, 64, 64); let cube_dim = cmma_cube_dim(); let launch_config = CmmaLaunchConfig::default(); - let (b_m, b_k, b_n) = ( - launch_config.block_size_m, - launch_config.block_size_k, - launch_config.block_size_n, - ); - - assert!( - lhs_vectorization == 4 && rhs_vectorization == 4 && out_vectorization == 4, - "Only vec4 is supported" - ); - assert!( - b_k * max(b_m, b_n) <= ::max_shared_memory_size(), - "Shared memory limit will be busted. " - ); - assert!( - b_m * b_n <= ::max_shared_memory_size(), - "Shared memory limit will be busted. " - ); - assert!( - b_k == 2 * launch_config.tile_size, - "Variable tile number per coop_units not supported" - ); cmma_kernel::launch::( - &client, + client, cube_count, cube_dim, - TensorArg::vectorized(lhs_vectorization, &lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::vectorized(rhs_vectorization, &rhs.handle, &rhs.strides, &rhs.shape), - TensorArg::vectorized(out_vectorization, &out.handle, &out.strides, &out.shape), + TensorArg::vectorized(lhs_vectorization, lhs.handle, lhs.strides, lhs.shape), + TensorArg::vectorized(rhs_vectorization, rhs.handle, rhs.strides, rhs.shape), + TensorArg::vectorized(out_vectorization, out.handle, out.strides, out.shape), CmmaConfig::new(m, k, n, launch_config), ); - - out } diff --git a/crates/cubecl-linalg/src/matmul/cmma/mod.rs b/crates/cubecl-linalg/src/matmul/cmma/mod.rs index e2151609a..7bfd49469 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/mod.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/mod.rs @@ -7,4 +7,6 @@ mod launch; pub(crate) mod load_shared_memory; pub(crate) mod write_output; -pub use launch::matmul_cmma; +pub use launch::check_cmma_availability as is_available; +pub use launch::matmul_cmma as launch; +pub use launch::matmul_cmma_ref as launch_ref; diff --git a/crates/cubecl-linalg/src/matmul/mod.rs b/crates/cubecl-linalg/src/matmul/mod.rs index 8b31a536c..6542c4b68 100644 --- a/crates/cubecl-linalg/src/matmul/mod.rs +++ b/crates/cubecl-linalg/src/matmul/mod.rs @@ -1,3 +1,5 @@ +use cubecl_core::prelude::*; + /// Contains algorithms for cooperative matrix multiplication. pub mod cmma; @@ -7,3 +9,17 @@ pub mod tiling2d; #[cfg(feature = "export_tests")] pub mod tests; + +/// Launch a matrix multiplication kernel. 
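+///
+/// Dispatches to the CMMA kernel when `cmma::is_available` returns `Ok` for the given client
+/// and inputs, and falls back to the tiling 2d kernel otherwise.
+///
+/// ```ignore
+/// // Hedged usage sketch: `MyRuntime`, `client`, and the three tensor handles are placeholders
+/// // assumed to exist at the call site; they are not defined in this patch.
+/// launch_ref::<MyRuntime, F32>(&client, lhs.as_ref(), rhs.as_ref(), out.as_ref());
+/// ```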
+pub fn launch_ref( + client: &ComputeClient, + lhs: TensorHandleRef<'_, R>, + rhs: TensorHandleRef<'_, R>, + out: TensorHandleRef<'_, R>, +) { + if cmma::is_available(client, &lhs, &rhs, None).is_ok() { + cmma::launch_ref::(client, lhs, rhs, out); + } else { + tiling2d::launch_ref::(client, lhs, rhs, out, Default::default()); + } +} diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs index 4f8161cca..de836ee8d 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs @@ -66,9 +66,10 @@ pub fn compute_loop_k_test(device: &R::Device) { let m = 16; let k = 32; let n = 16; - let lhs = range_tensor_f16::(m, k, device); - let rhs = range_tensor_f16::(k, n, device); - let results = create_empty::(m, n, device); + let client = R::client(device); + let lhs = range_tensor_f16::(&client, m, k); + let rhs = range_tensor_f16::(&client, k, n); + let results = create_empty::(&client, m, n); let cube_dim = CubeDim::new(32, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -128,7 +129,7 @@ pub fn compute_loop_k_test(device: &R::Device) { 3659328., 3671344., 3683360., 3695376., ]; - assert_equals::(results, expected, device); + assert_equals::(&client, results, expected); } /// Exported test @@ -141,9 +142,10 @@ pub fn compute_loop_warp_test(device: &R::Device) { let m = 16; let k = 32; let n = 32; - let lhs = range_tensor_f16::(m, k, device); - let rhs = range_tensor_f16::(k, n, device); - let results = create_empty::(m, n, device); + let client = R::client(device); + let lhs = range_tensor_f16::(&client, m, k); + let rhs = range_tensor_f16::(&client, k, n); + let results = create_empty::(&client, m, n); let cube_dim = CubeDim::new(32, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -231,7 +233,7 @@ pub fn compute_loop_warp_test(device: &R::Device) { 9763456., 9775472., 9787488., 9799504., 9811520., 9823536., 9835552., 9847568., ]; - assert_equals::(results, expected, device); + assert_equals::(&client, results, expected); } /// Exported test @@ -244,10 +246,11 @@ pub fn cmma_compute_loop_two_warps_same_tile_row_test(device: &R::De let m = 16; let k = 32; let n = 64; + let client = R::client(device); - let lhs = range_tensor_f16::(m, k, device); - let rhs = range_tensor_f16::(k, n, device); - let results = create_empty::(m, n, device); + let lhs = range_tensor_f16::(&client, m, k); + let rhs = range_tensor_f16::(&client, k, n); + let results = create_empty::(&client, m, n); let cube_dim = CubeDim::new(32, 2, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -263,7 +266,7 @@ pub fn cmma_compute_loop_two_warps_same_tile_row_test(device: &R::De }; compute_loop_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), @@ -413,5 +416,5 @@ pub fn cmma_compute_loop_two_warps_same_tile_row_test(device: &R::De 22103888.0, 22115904.0, 22127920.0, 22139936.0, 22151952.0, ]; - assert_equals::(results, expected, device); + assert_equals::(&client, results, expected); } diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs index ab1542b5c..08bed0ed0 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs @@ -76,8 +76,9 @@ fn load_rhs_test( /// Exported test pub fn 
load_shared_memory_lhs_unit_test(device: &R::Device) { - let lhs_tensor = range_tensor::(64, 64, device); - let lhs_sm = create_empty::(32, 64, device); + let client = R::client(device); + let lhs_tensor = range_tensor::(&client, 64, 64); + let lhs_sm = create_empty::(&client, 32, 64); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -127,13 +128,14 @@ pub fn load_shared_memory_lhs_unit_test(device: &R::Device) { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ]; - assert_equals_range::(lhs_sm, expected, 0..256, device); + assert_equals_range::(&client, lhs_sm, expected, 0..256); } /// Exported test pub fn load_shared_memory_rhs_unit_test(device: &R::Device) { - let rhs_tensor = range_tensor::(64, 64, device); - let rhs_sm = create_empty::(32, 64, device); + let client = R::client(device); + let rhs_tensor = range_tensor::(&client, 64, 64); + let rhs_sm = create_empty::(&client, 32, 64); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -183,13 +185,14 @@ pub fn load_shared_memory_rhs_unit_test(device: &R::Device) { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ]; - assert_equals_range::(rhs_sm, expected, 0..256, device); + assert_equals_range::(&client, rhs_sm, expected, 0..256); } /// Exported test pub fn load_shared_memory_lhs_warp_test(device: &R::Device) { - let lhs_tensor = range_tensor::(64, 64, device); - let lhs_sm = create_empty::(32, 64, device); + let client = R::client(device); + let lhs_tensor = range_tensor::(&client, 64, 64); + let lhs_sm = create_empty::(&client, 32, 64); let cube_dim = CubeDim::new(32, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -244,13 +247,14 @@ pub fn load_shared_memory_lhs_warp_test(device: &R::Device) { 960.0, 961.0, 962.0, 963.0, 964.0, 965.0, 966.0, 967.0, 968.0, 969.0, 970.0, 971.0, 972.0, 973.0, 974.0, 975.0, ]; - assert_equals_range::(lhs_sm, expected, 0..256, device); + assert_equals_range::(&client, lhs_sm, expected, 0..256); } /// Exported test pub fn load_shared_memory_lhs_vertical_out_of_bound_warp_test(device: &R::Device) { - let lhs_tensor = range_tensor::(12, 64, device); - let lhs_sm = create_empty::(32, 64, device); + let client = R::client(device); + let lhs_tensor = range_tensor::(&client, 12, 64); + let lhs_sm = create_empty::(&client, 32, 64); let cube_dim = CubeDim::new(32, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -303,13 +307,14 @@ pub fn load_shared_memory_lhs_vertical_out_of_bound_warp_test(device 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ]; - assert_equals_range::(lhs_sm, expected, 0..256, device); + assert_equals_range::(&client, lhs_sm, expected, 0..256); } /// Exported test pub fn load_shared_memory_lhs_horizontal_out_of_bound_warp_test(device: &R::Device) { - let lhs_tensor = range_tensor::(64, 12, device); - let lhs_sm = create_empty::(32, 64, device); + let client = R::client(device); + let lhs_tensor = range_tensor::(&client, 64, 12); + let lhs_sm = create_empty::(&client, 32, 64); let cube_dim = CubeDim::new(32, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -325,7 +330,7 @@ pub fn load_shared_memory_lhs_horizontal_out_of_bound_warp_test(devi }; load_lhs_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized( 
@@ -362,13 +367,14 @@ pub fn load_shared_memory_lhs_horizontal_out_of_bound_warp_test(devi 180.0, 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, 191.0, 0.0, 0.0, 0.0, 0.0, ]; - assert_equals_range::(lhs_sm, expected, 0..256, device); + assert_equals_range::(&client, lhs_sm, expected, 0..256); } /// Exported test pub fn load_shared_memory_lhs_whole_out_of_bound_warp_test(device: &R::Device) { - let lhs_tensor = range_tensor::(12, 12, device); - let lhs_sm = create_empty::(32, 64, device); + let client = R::client(device); + let lhs_tensor = range_tensor::(&client, 12, 12); + let lhs_sm = create_empty::(&client, 32, 64); let cube_dim = CubeDim::new(32, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -384,7 +390,7 @@ pub fn load_shared_memory_lhs_whole_out_of_bound_warp_test(device: & }; load_lhs_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized( @@ -420,13 +426,14 @@ pub fn load_shared_memory_lhs_whole_out_of_bound_warp_test(device: & 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ]; - assert_equals_range::(lhs_sm, expected, 0..256, device); + assert_equals_range::(&client, lhs_sm, expected, 0..256); } /// Exported test pub fn load_shared_memory_rhs_warp_test(device: &R::Device) { - let rhs_tensor = range_tensor::(64, 64, device); - let rhs_sm = create_empty::(32, 64, device); + let client = R::client(device); + let rhs_tensor = range_tensor::(&client, 64, 64); + let rhs_sm = create_empty::(&client, 32, 64); let cube_dim = CubeDim::new(32, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -442,7 +449,7 @@ pub fn load_shared_memory_rhs_warp_test(device: &R::Device) { }; load_rhs_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized( @@ -481,13 +488,14 @@ pub fn load_shared_memory_rhs_warp_test(device: &R::Device) { 960.0, 961.0, 962.0, 963.0, 964.0, 965.0, 966.0, 967.0, 968.0, 969.0, 970.0, 971.0, 972.0, 973.0, 974.0, 975.0, ]; - assert_equals_range::(rhs_sm, expected, 0..256, device); + assert_equals_range::(&client, rhs_sm, expected, 0..256); } /// Exported test pub fn load_shared_memory_lhs_second_warp_test(device: &R::Device) { - let lhs_tensor = range_tensor::(64, 64, device); - let lhs_sm = create_empty::(32, 64, device); + let client = R::client(device); + let lhs_tensor = range_tensor::(&client, 64, 64); + let lhs_sm = create_empty::(&client, 32, 64); let cube_dim = CubeDim::new(32, 2, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -503,7 +511,7 @@ pub fn load_shared_memory_lhs_second_warp_test(device: &R::Device) { }; load_lhs_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized( @@ -541,13 +549,14 @@ pub fn load_shared_memory_lhs_second_warp_test(device: &R::Device) { ]; // We are testing second warp - assert_equals_range::(lhs_sm, expected, 256..512, device); + assert_equals_range::(&client, lhs_sm, expected, 256..512); } /// Exported test pub fn load_shared_memory_rhs_second_warp_test(device: &R::Device) { - let rhs_tensor = range_tensor::(64, 64, device); - let rhs_sm = create_empty::(32, 64, device); + let client = R::client(device); + let rhs_tensor = range_tensor::(&client, 64, 64); + let rhs_sm = create_empty::(&client, 32, 64); let cube_dim = CubeDim::new(32, 2, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -563,7 +572,7 @@ pub fn load_shared_memory_rhs_second_warp_test(device: &R::Device) { }; load_rhs_test::launch::( - &R::client(device), + &client, 
cube_count, cube_dim, TensorArg::vectorized( @@ -604,13 +613,14 @@ pub fn load_shared_memory_rhs_second_warp_test(device: &R::Device) { ]; // We are testing second warp - assert_equals_range::(rhs_sm, expected, 256..512, device); + assert_equals_range::(&client, rhs_sm, expected, 256..512); } /// Exported test pub fn load_shared_memory_lhs_third_warp_test(device: &R::Device) { - let lhs_tensor = range_tensor::(64, 64, device); - let lhs_sm = create_empty::(32, 64, device); + let client = R::client(device); + let lhs_tensor = range_tensor::(&client, 64, 64); + let lhs_sm = create_empty::(&client, 32, 64); let cube_dim = CubeDim::new(32, 3, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -626,7 +636,7 @@ pub fn load_shared_memory_lhs_third_warp_test(device: &R::Device) { }; load_lhs_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized( @@ -667,13 +677,14 @@ pub fn load_shared_memory_lhs_third_warp_test(device: &R::Device) { ]; // We are testing second warp - assert_equals_range::(lhs_sm, expected, 512..768, device); + assert_equals_range::(&client, lhs_sm, expected, 512..768); } /// Exported test pub fn load_shared_memory_rhs_third_warp_test(device: &R::Device) { - let rhs_tensor = range_tensor::(64, 64, device); - let rhs_sm = create_empty::(32, 64, device); + let client = R::client(device); + let rhs_tensor = range_tensor::(&client, 64, 64); + let rhs_sm = create_empty::(&client, 32, 64); let cube_dim = CubeDim::new(32, 3, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -689,7 +700,7 @@ pub fn load_shared_memory_rhs_third_warp_test(device: &R::Device) { }; load_rhs_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized( @@ -727,13 +738,14 @@ pub fn load_shared_memory_rhs_third_warp_test(device: &R::Device) { ]; // We are testing second warp - assert_equals_range::(rhs_sm, expected, 512..768, device); + assert_equals_range::(&client, rhs_sm, expected, 512..768); } /// Exported test pub fn load_shared_memory_lhs_k_offset_test(device: &R::Device) { - let lhs_tensor = range_tensor::(64, 64, device); - let lhs_sm = create_empty::(32, 64, device); + let client = R::client(device); + let lhs_tensor = range_tensor::(&client, 64, 64); + let lhs_sm = create_empty::(&client, 32, 64); let cube_dim = CubeDim::new(32, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -749,7 +761,7 @@ pub fn load_shared_memory_lhs_k_offset_test(device: &R::Device) { }; load_lhs_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized( @@ -787,13 +799,14 @@ pub fn load_shared_memory_lhs_k_offset_test(device: &R::Device) { ]; // We are testing second warp - assert_equals_range::(lhs_sm, expected, 0..256, device); + assert_equals_range::(&client, lhs_sm, expected, 0..256); } /// Exported test pub fn load_shared_memory_rhs_k_offset_test(device: &R::Device) { - let rhs_tensor = range_tensor::(64, 64, device); - let rhs_sm = create_empty::(32, 64, device); + let client = R::client(device); + let rhs_tensor = range_tensor::(&client, 64, 64); + let rhs_sm = create_empty::(&client, 32, 64); let cube_dim = CubeDim::new(32, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -809,7 +822,7 @@ pub fn load_shared_memory_rhs_k_offset_test(device: &R::Device) { }; load_rhs_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized( @@ -850,5 +863,5 @@ pub fn load_shared_memory_rhs_k_offset_test(device: &R::Device) { ]; // We are testing second warp - 
assert_equals_range::(rhs_sm, expected, 0..256, device); + assert_equals_range::(&client, rhs_sm, expected, 0..256); } diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs index e656849a3..d593eb013 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs @@ -43,8 +43,9 @@ fn write_output_test( pub fn cmma_write_output_unit_test(device: &R::Device) { let m = 16; let n = 32; - let out = zeros_tensor::(m, n, device); - let acc_sm = range_tensor::(64, 64, device); + let client = R::client(device); + let out = zeros_tensor::(&client, m, n); + let acc_sm = range_tensor::(&client, 64, 64); let cube_dim = CubeDim::new(1, 1, 1); let cube_count: CubeCount = CubeCount::Static(1, 1, 1); @@ -60,7 +61,7 @@ pub fn cmma_write_output_unit_test(device: &R::Device) { }; write_output_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), @@ -101,15 +102,16 @@ pub fn cmma_write_output_unit_test(device: &R::Device) { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ]; - assert_equals::(out.handle, expected, device); + assert_equals::(&client, out.handle, expected); } /// Exported test pub fn cmma_write_output_warp_test(device: &R::Device) { let m = 16; let n = 32; - let out = zeros_tensor::(m, n, device); - let acc_sm = range_tensor::(64, 64, device); + let client = R::client(device); + let out = zeros_tensor::(&client, m, n); + let acc_sm = range_tensor::(&client, 64, 64); let cube_dim = CubeDim::new(32, 1, 1); let cube_count: CubeCount = CubeCount::Static(1, 1, 1); @@ -125,7 +127,7 @@ pub fn cmma_write_output_warp_test(device: &R::Device) { }; write_output_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), @@ -176,15 +178,16 @@ pub fn cmma_write_output_warp_test(device: &R::Device) { 252.0, 253.0, 254.0, 255.0, 496.0, 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0, 505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, ]; - assert_equals::(out.handle, expected, device); + assert_equals::(&client, out.handle, expected); } /// Exported test pub fn cmma_write_output_warp_horizontal_out_of_bounds_test(device: &R::Device) { let m = 16; let n = 28; - let out = zeros_tensor::(m, n, device); - let acc_sm = range_tensor::(64, 64, device); + let client = R::client(device); + let out = zeros_tensor::(&client, m, n); + let acc_sm = range_tensor::(&client, 64, 64); let cube_dim = CubeDim::new(32, 1, 1); let cube_count: CubeCount = CubeCount::Static(1, 1, 1); @@ -200,7 +203,7 @@ pub fn cmma_write_output_warp_horizontal_out_of_bounds_test(device: }; write_output_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), @@ -246,15 +249,16 @@ pub fn cmma_write_output_warp_horizontal_out_of_bounds_test(device: 249.0, 250.0, 251.0, 252.0, 253.0, 254.0, 255.0, 496.0, 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0, 505.0, 506.0, 507.0, ]; - assert_equals::(out.handle, expected, device); + assert_equals::(&client, out.handle, expected); } /// Exported test pub fn cmma_write_output_warp_vertical_out_of_bounds_test(device: &R::Device) { let m = 14; let n = 32; - let out = zeros_tensor::(m, n, device); - let acc_sm = 
range_tensor::(64, 64, device); + let client = R::client(device); + let out = zeros_tensor::(&client, m, n); + let acc_sm = range_tensor::(&client, 64, 64); let cube_dim = CubeDim::new(32, 1, 1); let cube_count: CubeCount = CubeCount::Static(1, 1, 1); @@ -270,7 +274,7 @@ pub fn cmma_write_output_warp_vertical_out_of_bounds_test(device: &R }; write_output_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), @@ -316,15 +320,16 @@ pub fn cmma_write_output_warp_vertical_out_of_bounds_test(device: &R 219.0, 220.0, 221.0, 222.0, 223.0, 464.0, 465.0, 466.0, 467.0, 468.0, 469.0, 470.0, 471.0, 472.0, 473.0, 474.0, 475.0, 476.0, 477.0, 478.0, 479.0, ]; - assert_equals::(out.handle, expected, device); + assert_equals::(&client, out.handle, expected); } /// Exported test pub fn cmma_write_output_warp_whole_out_of_bounds_test(device: &R::Device) { let m = 14; let n = 28; - let out = zeros_tensor::(m, n, device); - let acc_sm = range_tensor::(64, 64, device); + let client = R::client(device); + let out = zeros_tensor::(&client, m, n); + let acc_sm = range_tensor::(&client, 64, 64); let cube_dim = CubeDim::new(32, 1, 1); let cube_count: CubeCount = CubeCount::Static(1, 1, 1); @@ -340,7 +345,7 @@ pub fn cmma_write_output_warp_whole_out_of_bounds_test(device: &R::D }; write_output_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), @@ -382,15 +387,16 @@ pub fn cmma_write_output_warp_whole_out_of_bounds_test(device: &R::D 221.0, 222.0, 223.0, 464.0, 465.0, 466.0, 467.0, 468.0, 469.0, 470.0, 471.0, 472.0, 473.0, 474.0, 475.0, ]; - assert_equals::(out.handle, expected, device); + assert_equals::(&client, out.handle, expected); } /// Exported test pub fn cmma_write_output_second_warp_test(device: &R::Device) { let m = 16; let n = 64; - let out = zeros_tensor::(m, n, device); - let acc_sm = range_tensor::(64, 64, device); + let client = R::client(device); + let out = zeros_tensor::(&client, m, n); + let acc_sm = range_tensor::(&client, 64, 64); let cube_dim = CubeDim::new(32, 2, 1); let cube_count: CubeCount = CubeCount::Static(1, 1, 1); @@ -406,7 +412,7 @@ pub fn cmma_write_output_second_warp_test(device: &R::Device) { }; write_output_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), @@ -497,15 +503,16 @@ pub fn cmma_write_output_second_warp_test(device: &R::Device) { 1011.0, 1012.0, 1013.0, 1014.0, 1015.0, 1016.0, 1017.0, 1018.0, 1019.0, 1020.0, 1021.0, 1022.0, 1023.0, ]; - assert_equals::(out.handle, expected, device); + assert_equals::(&client, out.handle, expected); } /// Exported test pub fn cmma_write_output_third_fourth_warps_test(device: &R::Device) { let m = 32; let n = 64; - let out = zeros_tensor::(m, n, device); - let acc_sm = range_tensor::(64, 64, device); + let client = R::client(device); + let out = zeros_tensor::(&client, m, n); + let acc_sm = range_tensor::(&client, 64, 64); let cube_dim = CubeDim::new(32, 4, 1); let cube_count: CubeCount = CubeCount::Static(1, 1, 1); @@ -521,7 +528,7 @@ pub fn cmma_write_output_third_fourth_warps_test(device: &R::Device) }; write_output_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized(4, &out.handle, &out.strides, &out.shape), @@ -612,5 +619,5 @@ pub fn cmma_write_output_third_fourth_warps_test(device: &R::Device) 1785., 1786., 1787., 1788., 1789., 1790., 1791., 
2032., 2033., 2034., 2035., 2036., 2037., 2038., 2039., 2040., 2041., 2042., 2043., 2044., 2045., 2046., 2047., ]; - assert_equals_range::(out.handle, expected, 1024..2048, device); + assert_equals_range::(&client, out.handle, expected, 1024..2048); } diff --git a/crates/cubecl-linalg/src/matmul/tests/matmul_tests.rs b/crates/cubecl-linalg/src/matmul/tests/matmul_tests.rs index 6d98839fa..e22a49d91 100644 --- a/crates/cubecl-linalg/src/matmul/tests/matmul_tests.rs +++ b/crates/cubecl-linalg/src/matmul/tests/matmul_tests.rs @@ -2,7 +2,7 @@ use cubecl_core::{frontend::F32, CubeElement, Runtime}; use half::f16; use crate::{ - matmul::{cmma::matmul_cmma, tiling2d::matmul_tiling_2d}, + matmul::{cmma::launch, tiling2d}, tensor::TensorHandle, }; @@ -126,13 +126,14 @@ struct MatmulTestCase { impl MatmulTestCase { fn test_tiling2d(&self, device: &R::Device) { + let client = R::client(device); let tensor_1 = - range_tensor_with_factor::(self.batch, self.m, self.k, self.factor, device); + range_tensor_with_factor::(&client, self.batch, self.m, self.k, self.factor); let tensor_2 = - range_tensor_with_factor::(self.batch, self.k, self.n, self.factor, device); + range_tensor_with_factor::(&client, self.batch, self.k, self.n, self.factor); let out = TensorHandle::new_contiguous( vec![self.batch, self.m, self.n], - create_empty::(self.batch * self.m, self.n, device), + create_empty::(&client, self.batch * self.m, self.n), ); let expected = self.matmul_cpu( @@ -140,9 +141,9 @@ impl MatmulTestCase { f32::from_bytes(&R::client(device).read(tensor_2.handle.clone().binding())), ); - let out = matmul_tiling_2d::(tensor_1, tensor_2, out, Default::default(), device); + let out = tiling2d::launch::(&client, tensor_1, tensor_2, out, Default::default()); - assert_equals_approx::(out.handle, &expected, self.epsilon, device); + assert_equals_approx::(&client, out.handle, &expected, self.epsilon); } fn test_cmma(&self, device: &R::Device) { @@ -151,23 +152,24 @@ impl MatmulTestCase { return; } + let client = R::client(device); let tensor_1 = - range_tensor_with_factor::(self.batch, self.m, self.k, self.factor, device); + range_tensor_with_factor::(&client, self.batch, self.m, self.k, self.factor); let tensor_2 = - range_tensor_with_factor::(self.batch, self.k, self.n, self.factor, device); + range_tensor_with_factor::(&client, self.batch, self.k, self.n, self.factor); let out = TensorHandle::new_contiguous( vec![self.batch, self.m, self.n], - create_empty::(self.batch * self.m, self.n, device), + create_empty::(&client, self.batch * self.m, self.n), ); let expected = self.matmul_cpu( - f32::from_bytes(&R::client(device).read(tensor_1.handle.clone().binding())), - f32::from_bytes(&R::client(device).read(tensor_2.handle.clone().binding())), + f32::from_bytes(&client.read(tensor_1.handle.clone().binding())), + f32::from_bytes(&client.read(tensor_2.handle.clone().binding())), ); - let out = matmul_cmma::(tensor_1, tensor_2, out, device); + let out = launch::(&client, tensor_1, tensor_2, out); - assert_equals_approx::(out.handle, &expected, self.epsilon, device); + assert_equals_approx::(&client, out.handle, &expected, self.epsilon); } fn matmul_cpu(&self, lhs: &[f32], rhs: &[f32]) -> Vec { diff --git a/crates/cubecl-linalg/src/matmul/tests/test_utils.rs b/crates/cubecl-linalg/src/matmul/tests/test_utils.rs index 9f6c302c7..26fdf8c50 100644 --- a/crates/cubecl-linalg/src/matmul/tests/test_utils.rs +++ b/crates/cubecl-linalg/src/matmul/tests/test_utils.rs @@ -1,5 +1,6 @@ use bytemuck::cast_slice; use cubecl_core::{ + 
client::ComputeClient, frontend::{F16, F32}, ir::{Elem, FloatKind}, server::Handle, @@ -13,12 +14,11 @@ use crate::{ }; pub(crate) fn range_tensor_f16( + client: &ComputeClient, x: usize, y: usize, - device: &R::Device, ) -> TensorHandle { let n_elements = x * y; - let client = R::client(device); let mut data = Vec::with_capacity(n_elements); for i in 0..n_elements { @@ -31,12 +31,11 @@ pub(crate) fn range_tensor_f16( } pub(crate) fn range_tensor( + client: &ComputeClient, x: usize, y: usize, - device: &R::Device, ) -> TensorHandle { let n_elements = x * y; - let client = R::client(device); let mut data: Vec = Vec::with_capacity(n_elements); for i in 0..n_elements { @@ -49,14 +48,13 @@ pub(crate) fn range_tensor( } pub(crate) fn range_tensor_with_factor( + client: &ComputeClient, batch: usize, x: usize, y: usize, factor: f32, - device: &R::Device, ) -> TensorHandle { let n_elements = batch * x * y; - let client = R::client(device); let mut data: Vec = Vec::with_capacity(n_elements); for i in 0..n_elements { @@ -69,12 +67,11 @@ pub(crate) fn range_tensor_with_factor( } pub(crate) fn range_tensor_transposed( + client: &ComputeClient, x: usize, y: usize, - device: &R::Device, ) -> TensorHandle { let n_elements = x * y; - let client = R::client(device); let mut data: Vec = Vec::with_capacity(n_elements); for i in 0..y { @@ -90,12 +87,11 @@ pub(crate) fn range_tensor_transposed( } pub(crate) fn zeros_tensor( + client: &ComputeClient, x: usize, y: usize, - device: &R::Device, ) -> TensorHandle { let n_elements = x * y; - let client = R::client(device); let data: Vec = vec![0.; n_elements]; let handle = client.create(cast_slice(&data)); @@ -104,21 +100,18 @@ pub(crate) fn zeros_tensor( } pub(crate) fn create_empty( + client: &ComputeClient, x: usize, y: usize, - device: &R::Device, ) -> Handle<::Server> { - let client = R::client(device); client.empty(x * y * core::mem::size_of::()) } pub(crate) fn assert_equals( + client: &ComputeClient, output: Handle<::Server>, expected: &[f32], - device: &R::Device, ) { - let client = R::client(device); - let actual = client.read(output.binding()); let actual = f32::from_bytes(&actual); @@ -126,13 +119,11 @@ pub(crate) fn assert_equals( } pub(crate) fn assert_equals_approx( + client: &ComputeClient, output: Handle<::Server>, expected: &[f32], epsilon: f32, - device: &R::Device, ) { - let client = R::client(device); - let actual = client.read(output.binding()); let actual = f32::from_bytes(&actual); @@ -149,13 +140,11 @@ pub(crate) fn assert_equals_approx( } pub(crate) fn assert_equals_range( + client: &ComputeClient, output: Handle<::Server>, expected: &[f32], range: Range, - device: &R::Device, ) { - let client = R::client(device); - let actual = client.read(output.binding()); let actual = f32::from_bytes(&actual); diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs index a67351483..21447b0cd 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs @@ -43,7 +43,7 @@ pub fn tile_outer_product_vectorized_unit_test_2(device: &R::Device) let register_m = client.create(f32::as_bytes(&[16., 20., 24., 28.])); let register_n = client.create(f32::as_bytes(&[4., 5., 6., 7.])); - let results = create_empty::(4, 4, device); + let results = create_empty::(&client, 4, 4); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -64,7 +64,7 @@ pub fn 
tile_outer_product_vectorized_unit_test_2(device: &R::Device) 64.0, 80.0, 96.0, 112.0, 80.0, 100.0, 120.0, 140.0, 96.0, 120.0, 144.0, 168.0, 112.0, 140.0, 168.0, 196.0, ]; - assert_equals::(results, expected, device); + assert_equals::(&client, results, expected); } #[cube(launch)] @@ -117,7 +117,7 @@ pub fn tile_outer_product_vectorized_unit_test(device: &R::Device) { let client = R::client(device); let register_m = client.create(f32::as_bytes(&[0., 1., 2., 3.])); let register_n = client.create(f32::as_bytes(&[1., 2., 3., 4.])); - let results = create_empty::(4, 4, device); + let results = create_empty::(&client, 4, 4); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -137,14 +137,15 @@ pub fn tile_outer_product_vectorized_unit_test(device: &R::Device) { let expected = &[ 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0, 3.0, 6.0, 9.0, 12.0, ]; - assert_equals::(results, expected, device); + assert_equals::(&client, results, expected); } /// Exported test pub fn compute_loop_unit_test(device: &R::Device) { - let lhs = range_tensor::(8, 8, device); - let rhs = range_tensor::(8, 8, device); - let results = create_empty::(TILE_SIZE, TILE_SIZE, device); + let client = R::client(device); + let lhs = range_tensor::(&client, 8, 8); + let rhs = range_tensor::(&client, 8, 8); + let results = create_empty::(&client, TILE_SIZE, TILE_SIZE); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -169,14 +170,15 @@ pub fn compute_loop_unit_test(device: &R::Device) { 8960.0, 9184.0, 9408.0, 9632.0, 9184.0, 9416.0, 9648.0, 9880.0, 9408.0, 9648.0, 9888.0, 10128.0, 9632.0, 9880.0, 10128.0, 10376.0, ]; - assert_equals::(results, expected, device); + assert_equals::(&client, results, expected); } /// Exported test pub fn compute_loop_unit_offset_test(device: &R::Device) { - let lhs = range_tensor_transposed::(8, 4, device); - let rhs = range_tensor::(4, 8, device); - let results = create_empty::(TILE_SIZE, TILE_SIZE, device); + let client = R::client(device); + let lhs = range_tensor_transposed::(&client, 8, 4); + let rhs = range_tensor::(&client, 4, 8); + let results = create_empty::(&client, TILE_SIZE, TILE_SIZE); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -200,5 +202,5 @@ pub fn compute_loop_unit_offset_test(device: &R::Device) { 1160.0, 1230.0, 1300.0, 1370.0, 1416.0, 1502.0, 1588.0, 1674.0, 1672.0, 1774.0, 1876.0, 1978.0, 1928.0, 2046.0, 2164.0, 2282.0, ]; - assert_equals::(results, expected, device); + assert_equals::(&client, results, expected); } diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs index 02bc37589..b8059b522 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs @@ -214,15 +214,16 @@ fn load_tensor_multiple_tiles_test( /// Exported test pub fn load_lhs_transposed_unit_test(device: &R::Device) { - let lhs = range_tensor::(16, 16, device); - let sm_out = create_empty::(8, 8, device); + let client = R::client(device); + let lhs = range_tensor::(&client, 16, 16); + let sm_out = create_empty::(&client, 8, 8); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); let config = make_tiling2d_config(16, 16, 8); load_tensor_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), 
@@ -240,21 +241,22 @@ pub fn load_lhs_transposed_unit_test(device: &R::Device) { 76.0, 92.0, 108.0, 124.0, 0.0, 0.0, 0.0, 0.0, 77.0, 93.0, 109.0, 125.0, 0.0, 0.0, 0.0, 0.0, 78.0, 94.0, 110.0, 126.0, 0.0, 0.0, 0.0, 0.0, 79.0, 95.0, 111.0, 127.0, ]; - assert_equals::(sm_out, expected, device); + assert_equals::(&client, sm_out, expected); } /// Exported test pub fn load_lhs_transposed_out_of_bounds_cube_test(device: &R::Device) { + let client = R::client(device); let vectorization_factor = 1; - let lhs = range_tensor::(5, 1, device); - let sm_out = create_empty::(8, 8, device); + let lhs = range_tensor::(&client, 5, 1); + let sm_out = create_empty::(&client, 8, 8); let cube_dim = CubeDim::new(2, 2, 1); let cube_count = CubeCount::Static(1, 1, 1); let config = make_tiling2d_config(5, 1, 1); load_tensor_multiple_tiles_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized( @@ -275,20 +277,21 @@ pub fn load_lhs_transposed_out_of_bounds_cube_test(device: &R::Devic 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ]; - assert_equals::(sm_out, expected, device); + assert_equals::(&client, sm_out, expected); } /// Exported test pub fn load_lhs_transposed_cube_test(device: &R::Device) { - let lhs = range_tensor::(8, 8, device); - let sm_out = create_empty::(8, 8, device); + let client = R::client(device); + let lhs = range_tensor::(&client, 8, 8); + let sm_out = create_empty::(&client, 8, 8); let cube_dim = CubeDim::new(2, 2, 1); let cube_count = CubeCount::Static(1, 1, 1); let config = make_tiling2d_config(8, 8, 8); load_tensor_multiple_tiles_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), @@ -305,20 +308,21 @@ pub fn load_lhs_transposed_cube_test(device: &R::Device) { 53.0, 61.0, 6.0, 14.0, 22.0, 30.0, 38.0, 46.0, 54.0, 62.0, 7.0, 15.0, 23.0, 31.0, 39.0, 47.0, 55.0, 63.0, ]; - assert_equals::(sm_out, expected, device); + assert_equals::(&client, sm_out, expected); } /// Exported test pub fn load_lhs_transposed_offset_cube_test(device: &R::Device) { - let lhs = range_tensor::(8, 16, device); - let sm_out = create_empty::(8, 8, device); + let client = R::client(device); + let lhs = range_tensor::(&client, 8, 16); + let sm_out = create_empty::(&client, 8, 8); let cube_dim = CubeDim::new(2, 2, 1); let cube_count = CubeCount::Static(1, 1, 1); let config = make_tiling2d_config(8, 8, 16); load_tensor_multiple_tiles_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), @@ -335,20 +339,21 @@ pub fn load_lhs_transposed_offset_cube_test(device: &R::Device) { 61.0, 77.0, 93.0, 109.0, 125.0, 14.0, 30.0, 46.0, 62.0, 78.0, 94.0, 110.0, 126.0, 15.0, 31.0, 47.0, 63.0, 79.0, 95.0, 111.0, 127.0, ]; - assert_equals::(sm_out, expected, device); + assert_equals::(&client, sm_out, expected); } /// Exported test pub fn load_rhs_plain_unit_test(device: &R::Device) { - let rhs = range_tensor::(16, 16, device); - let sm_out = create_empty::(8, 8, device); + let client = R::client(device); + let rhs = range_tensor::(&client, 16, 16); + let sm_out = create_empty::(&client, 8, 8); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); let config = make_tiling2d_config(8, 16, 16); load_tensor_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, 
&rhs.shape), @@ -366,20 +371,21 @@ pub fn load_rhs_plain_unit_test(device: &R::Device) { 196.0, 197.0, 198.0, 199.0, 0.0, 0.0, 0.0, 0.0, 212.0, 213.0, 214.0, 215.0, 0.0, 0.0, 0.0, 0.0, 228.0, 229.0, 230.0, 231.0, 0.0, 0.0, 0.0, 0.0, 244.0, 245.0, 246.0, 247.0, ]; - assert_equals::(sm_out, expected, device); + assert_equals::(&client, sm_out, expected); } /// Exported test pub fn load_rhs_plain_cube_test(device: &R::Device) { - let rhs = range_tensor::(8, 8, device); - let sm_out = create_empty::(8, 8, device); + let client = R::client(device); + let rhs = range_tensor::(&client, 8, 8); + let sm_out = create_empty::(&client, 8, 8); let cube_dim = CubeDim::new(2, 2, 1); let cube_count = CubeCount::Static(1, 1, 1); let config = make_tiling2d_config(8, 8, 8); load_tensor_multiple_tiles_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), @@ -396,20 +402,21 @@ pub fn load_rhs_plain_cube_test(device: &R::Device) { 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, ]; - assert_equals::(sm_out, expected, device); + assert_equals::(&client, sm_out, expected); } /// Exported test pub fn load_rhs_plain_cube_offset_test(device: &R::Device) { - let rhs = range_tensor::(16, 8, device); - let sm_out = create_empty::(8, 8, device); + let client = R::client(device); + let rhs = range_tensor::(&client, 16, 8); + let sm_out = create_empty::(&client, 8, 8); let cube_dim = CubeDim::new(2, 2, 1); let cube_count = CubeCount::Static(1, 1, 1); let config = make_tiling2d_config(16, 16, 8); load_tensor_multiple_tiles_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape), @@ -426,20 +433,21 @@ pub fn load_rhs_plain_cube_offset_test(device: &R::Device) { 108.0, 109.0, 110.0, 111.0, 112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0, ]; - assert_equals::(sm_out, expected, device); + assert_equals::(&client, sm_out, expected); } /// Exported test pub fn load_lhs_plain_unit_test(device: &R::Device) { - let lhs = range_tensor::(16, 16, device); - let sm_out = create_empty::(8, 8, device); + let client = R::client(device); + let lhs = range_tensor::(&client, 16, 16); + let sm_out = create_empty::(&client, 8, 8); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); let config = make_tiling2d_config(16, 16, 8); load_tensor_permuted_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), @@ -457,14 +465,15 @@ pub fn load_lhs_plain_unit_test(device: &R::Device) { 196.0, 197.0, 198.0, 199.0, 0.0, 0.0, 0.0, 0.0, 212.0, 213.0, 214.0, 215.0, 0.0, 0.0, 0.0, 0.0, 228.0, 229.0, 230.0, 231.0, 0.0, 0.0, 0.0, 0.0, 244.0, 245.0, 246.0, 247.0, ]; - assert_equals::(sm_out, expected, device); + assert_equals::(&client, sm_out, expected); } /// Exported test pub fn load_lhs_plain_out_of_bounds_unit_test(device: &R::Device) { let (m, k) = (6, 14); - let lhs = range_tensor::(k, m, device); - let sm_out = create_empty::(8, 8, device); + let client = R::client(device); + let lhs = range_tensor::(&client, k, m); + let sm_out = create_empty::(&client, 8, 8); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -489,20 +498,21 @@ pub fn load_lhs_plain_out_of_bounds_unit_test(device: &R::Device) { 76.0, 77.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 82.0, 83.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ]; - assert_equals::(sm_out, expected, device); + assert_equals::(&client, sm_out, expected); } /// Exported test pub fn load_rhs_transposed_unit_test(device: &R::Device) { - let rhs = range_tensor::(16, 16, device); - let sm_out = create_empty::(8, 8, device); + let client = R::client(device); + let rhs = range_tensor::(&client, 16, 16); + let sm_out = create_empty::(&client, 8, 8); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); let config = make_tiling2d_config(16, 16, 8); load_tensor_permuted_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), @@ -520,21 +530,22 @@ pub fn load_rhs_transposed_unit_test(device: &R::Device) { 76.0, 92.0, 108.0, 124.0, 0.0, 0.0, 0.0, 0.0, 77.0, 93.0, 109.0, 125.0, 0.0, 0.0, 0.0, 0.0, 78.0, 94.0, 110.0, 126.0, 0.0, 0.0, 0.0, 0.0, 79.0, 95.0, 111.0, 127.0, ]; - assert_equals::(sm_out, expected, device); + assert_equals::(&client, sm_out, expected); } /// Exported test pub fn load_rhs_transposed_out_of_bounds_unit_test(device: &R::Device) { let (k, n) = (14, 6); - let rhs = range_tensor::(n, k, device); - let sm_out = create_empty::(8, 8, device); + let client = R::client(device); + let rhs = range_tensor::(&client, n, k); + let sm_out = create_empty::(&client, 8, 8); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); let config = make_tiling2d_config(8, k, n); load_tensor_permuted_test::launch::( - &R::client(device), + &client, cube_count, cube_dim, TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), @@ -552,5 +563,5 @@ pub fn load_rhs_transposed_out_of_bounds_unit_test(device: &R::Devic 68.0, 82.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 69.0, 83.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ]; - assert_equals::(sm_out, expected, device); + assert_equals::(&client, sm_out, expected); } diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs index fd61da094..f0a9e5939 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs @@ -58,8 +58,9 @@ fn write_results_to_output_out_of_bounds_test( /// Exported test pub fn write_to_output_over_height_unit_test(device: &R::Device) { - let out = zeros_tensor::(6, 8, device); - let tile = range_tensor::(4, 4, device); + let client = R::client(device); + let out = zeros_tensor::(&client, 6, 8); + let tile = range_tensor::(&client, 4, 4); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -79,13 +80,14 @@ pub fn write_to_output_over_height_unit_test(device: &R::Device) { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, ]; - assert_equals::(out.handle, expected, device); + assert_equals::(&client, out.handle, expected); } /// Exported test pub fn write_to_output_over_width_unit_test(device: &R::Device) { - let out = zeros_tensor::(8, 4, device); - let tile = range_tensor::(4, 4, device); + let client = R::client(device); + let out = zeros_tensor::(&client, 8, 4); + let tile = range_tensor::(&client, 4, 4); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -104,14 +106,15 @@ pub fn 
write_to_output_over_width_unit_test(device: &R::Device) { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ]; - assert_equals::(out.handle, expected, device); + assert_equals::(&client, out.handle, expected); } /// Exported test pub fn write_to_output_vectorized_less_than_tile_unit_test(device: &R::Device) { let vectorization = 2; - let out = zeros_tensor::(8, 8, device); - let tile = range_tensor::(4, 4, device); + let client = R::client(device); + let out = zeros_tensor::(&client, 8, 8); + let tile = range_tensor::(&client, 4, 4); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -132,14 +135,15 @@ pub fn write_to_output_vectorized_less_than_tile_unit_test(device: & 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0, ]; - assert_equals::(out.handle, expected, device); + assert_equals::(&client, out.handle, expected); } /// Exported test pub fn write_to_output_scalar_unit_test(device: &R::Device) { let vectorization = 1; - let out = zeros_tensor::(8, 8, device); - let tile = range_tensor::(4, 4, device); + let client = R::client(device); + let out = zeros_tensor::(&client, 8, 8); + let tile = range_tensor::(&client, 4, 4); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -160,14 +164,15 @@ pub fn write_to_output_scalar_unit_test(device: &R::Device) { 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0, ]; - assert_equals::(out.handle, expected, device); + assert_equals::(&client, out.handle, expected); } /// Exported test pub fn write_to_output_scalar_out_of_bounds_cube_test(device: &R::Device) { let vectorization = 1; - let out = zeros_tensor::(5, 1, device); - let results = range_tensor_transposed::(4, 4, device); + let client = R::client(device); + let out = zeros_tensor::(&client, 5, 1); + let results = range_tensor_transposed::(&client, 4, 4); let cube_dim = CubeDim::new(2, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); @@ -183,5 +188,5 @@ pub fn write_to_output_scalar_out_of_bounds_cube_test(device: &R::De ); let expected = &[0.0, 1.0, 2.0, 3.0, 0.0]; - assert_equals::(out.handle, expected, device); + assert_equals::(&client, out.handle, expected); } diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs index 9ac0a4be4..c77aff672 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs @@ -7,48 +7,58 @@ use crate::{ base::tiling2d_cube_kernel, config::{tiling2d_cube_count, tiling2d_cube_dim, CubeTiling2dConfig}, }, - tensor::{MatrixLayout, TensorHandle}, + tensor::{matrix_layout, MatrixLayout, TensorHandle}, }; use super::config::Tiling2dConfig; -/// Matrix multiplication using tiling 2d algorithm +/// Matrix multiplication using tiling 2d algorithm. pub fn matmul_tiling_2d( + client: &ComputeClient, lhs: TensorHandle, rhs: TensorHandle, out: TensorHandle, config: Tiling2dConfig, - device: &R::Device, ) -> TensorHandle { + matmul_tiling_2d_ref::(client, lhs.as_ref(), rhs.as_ref(), out.as_ref(), config); + + out +} + +/// Matrix multiplication using tiling 2d algorithm. 
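+///
+/// Same as `matmul_tiling_2d`, but borrows the tensors through `TensorHandleRef` and takes an
+/// explicit `ComputeClient` instead of a device, so callers keep ownership of their handles.
+///
+/// ```ignore
+/// // Hedged usage sketch: `MyRuntime`, `client`, and the tensor handles are placeholders assumed
+/// // to exist at the call site; they are not defined in this patch.
+/// matmul_tiling_2d_ref::<MyRuntime, F32>(
+///     &client,
+///     lhs.as_ref(),
+///     rhs.as_ref(),
+///     out.as_ref(),
+///     Tiling2dConfig::default(),
+/// );
+/// ```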
+pub fn matmul_tiling_2d_ref( + client: &ComputeClient, + lhs: TensorHandleRef<'_, R>, + rhs: TensorHandleRef<'_, R>, + out: TensorHandleRef<'_, R>, + config: Tiling2dConfig, +) { assert!( config.block_size_k * max(config.block_size_m, config.block_size_n) <= ::max_shared_memory_size(), "Shared memory limit will be busted. " ); - let rank = lhs.rank(); + let rank = lhs.strides.len(); let m = lhs.shape[rank - 2]; let k = lhs.shape[rank - 1]; let n = rhs.shape[rank - 1]; - let client = R::client(device); - - let check_layout = |tensor: TensorHandle| match tensor.matrix_layout() { - MatrixLayout::Contiguous => (tensor, false), + let check_layout = |strides: &[usize]| match matrix_layout(strides) { + MatrixLayout::Contiguous => false, MatrixLayout::MildlyPermuted { transposed, batch_swap: _, - } => (tensor, transposed), + } => transposed, MatrixLayout::HighlyPermuted => { panic!("Can't run on highly permuted tensor") } }; - let (lhs, lhs_transposed) = check_layout(lhs); - let (rhs, rhs_transposed) = check_layout(rhs); + let lhs_transposed = check_layout(lhs.strides); + let rhs_transposed = check_layout(rhs.strides); let vectorization = |shape: usize| { - [4, 2] - .into_iter() + [].into_iter() .filter(|v| shape % v == 0) .map(|v| v as u8) .next() @@ -65,19 +75,17 @@ pub fn matmul_tiling_2d( }; let out_vectorization = vectorization(n); - let cube_count = tiling2d_cube_count::(&out.shape, &config); + let cube_count = tiling2d_cube_count::(out.shape, &config); let cube_dim = tiling2d_cube_dim(&config); let cube_config = CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed); tiling2d_cube_kernel::launch::( - &client, + client, cube_count, cube_dim, - TensorArg::vectorized(lhs_vectorization, &lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::vectorized(rhs_vectorization, &rhs.handle, &rhs.strides, &rhs.shape), - TensorArg::vectorized(out_vectorization, &out.handle, &out.strides, &out.shape), + TensorArg::vectorized(lhs_vectorization, lhs.handle, lhs.strides, lhs.shape), + TensorArg::vectorized(rhs_vectorization, rhs.handle, rhs.strides, rhs.shape), + TensorArg::vectorized(out_vectorization, out.handle, out.strides, out.shape), cube_config, ); - - out } diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/mod.rs b/crates/cubecl-linalg/src/matmul/tiling2d/mod.rs index ff080f048..b760ae859 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/mod.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/mod.rs @@ -8,4 +8,5 @@ pub(crate) mod outer_product; pub(crate) mod tile; pub(crate) mod write_output; -pub use launch::matmul_tiling_2d; +pub use launch::matmul_tiling_2d as launch; +pub use launch::matmul_tiling_2d_ref as launch_ref; diff --git a/crates/cubecl-linalg/src/tensor/base.rs b/crates/cubecl-linalg/src/tensor/base.rs index 43f96781c..5716c711c 100644 --- a/crates/cubecl-linalg/src/tensor/base.rs +++ b/crates/cubecl-linalg/src/tensor/base.rs @@ -6,8 +6,6 @@ use cubecl_core::SUBCUBE_DIM_APPROX; use cubecl_runtime::server::Handle; use std::marker::PhantomData; -use super::layout::{memory_layout, MatrixLayout}; - /// Tensor representation containing a [server handle](Handle) as well as basic tensor metadata., pub struct TensorHandle where @@ -86,17 +84,12 @@ where self.handle.can_mut() } - /// Check if the current tensor is contiguous. 
- pub fn is_contiguous(&self) -> bool { - self.matrix_layout() == MatrixLayout::Contiguous - } - - pub(crate) fn matrix_layout(&self) -> MatrixLayout { - memory_layout(&self.strides) - } - - pub(crate) fn rank(&self) -> usize { - self.shape.len() + pub fn as_ref(&self) -> TensorHandleRef<'_, R> { + TensorHandleRef { + handle: &self.handle, + strides: &self.strides, + shape: &self.shape, + } } fn contiguous_strides(shape: &[usize]) -> Vec { diff --git a/crates/cubecl-linalg/src/tensor/layout.rs b/crates/cubecl-linalg/src/tensor/layout.rs index d6f0f4652..0b4e0b0a3 100644 --- a/crates/cubecl-linalg/src/tensor/layout.rs +++ b/crates/cubecl-linalg/src/tensor/layout.rs @@ -15,7 +15,7 @@ pub enum MatrixLayout { HighlyPermuted, } -pub fn memory_layout(strides: &[usize]) -> MatrixLayout { +pub fn matrix_layout(strides: &[usize]) -> MatrixLayout { let rank = strides.len(); if rank <= 1 { return MatrixLayout::Contiguous; @@ -59,13 +59,13 @@ mod tests { #[test] fn layout_is_contiguous() { let strides = &[8, 4, 2, 1]; - assert_eq!(memory_layout(strides), MatrixLayout::Contiguous); + assert_eq!(matrix_layout(strides), MatrixLayout::Contiguous); } #[test] fn vector_is_contiguous() { let strides = &[1]; - assert_eq!(memory_layout(strides), MatrixLayout::Contiguous) + assert_eq!(matrix_layout(strides), MatrixLayout::Contiguous) } #[test] @@ -74,7 +74,7 @@ mod tests { if let MatrixLayout::MildlyPermuted { transposed, batch_swap, - } = memory_layout(strides) + } = matrix_layout(strides) { assert!(transposed && !batch_swap); } else { @@ -88,7 +88,7 @@ mod tests { if let MatrixLayout::MildlyPermuted { transposed, batch_swap, - } = memory_layout(strides) + } = matrix_layout(strides) { assert!(!transposed && batch_swap); } else { @@ -102,7 +102,7 @@ mod tests { if let MatrixLayout::MildlyPermuted { transposed, batch_swap, - } = memory_layout(strides) + } = matrix_layout(strides) { assert!(transposed && batch_swap); } else { @@ -113,12 +113,12 @@ mod tests { #[test] fn layout_has_batch_swapped_with_row() { let strides = &[8, 2, 4, 1]; - assert_eq!(memory_layout(strides), MatrixLayout::HighlyPermuted); + assert_eq!(matrix_layout(strides), MatrixLayout::HighlyPermuted); } #[test] fn layout_has_batch_swapped_with_col() { let strides = &[1, 4, 2, 8]; - assert_eq!(memory_layout(strides), MatrixLayout::HighlyPermuted); + assert_eq!(matrix_layout(strides), MatrixLayout::HighlyPermuted); } } diff --git a/crates/cubecl/benches/matmul.rs b/crates/cubecl/benches/matmul.rs index b31a8886d..0d30b3c35 100644 --- a/crates/cubecl/benches/matmul.rs +++ b/crates/cubecl/benches/matmul.rs @@ -4,8 +4,7 @@ use std::marker::PhantomData; use cubecl::benchmark::Benchmark; use cubecl::client::SyncType; use cubecl::frontend::Float; -use cubecl_linalg::matmul::cmma::matmul_cmma; -use cubecl_linalg::matmul::tiling2d::matmul_tiling_2d; +use cubecl_linalg::matmul; use cubecl_linalg::tensor::TensorHandle; impl Benchmark for Tiling2dBench { @@ -24,22 +23,25 @@ impl Benchmark for Tiling2dBench { fn execute(&self, (lhs, rhs, out): Self::Args) { match self.kind { MatmulKind::Tiling2d => { - matmul_tiling_2d(lhs, rhs, out, Default::default(), &self.device); + matmul::tiling2d::launch(&self.client, lhs, rhs, out, Default::default()); } MatmulKind::Cmma => { - matmul_cmma(lhs, rhs, out, &self.device); + matmul::cmma::launch(&self.client, lhs, rhs, out); } } } + fn num_samples(&self) -> usize { + 100 + } + fn name(&self) -> String { let elem = E::as_elem(); format!("tiling2d-{}-{:?}-{:?}", R::name(), elem, self.kind) } fn sync(&self) { - let client = 
R::client(&self.device); - client.sync(SyncType::Wait); + self.client.sync(SyncType::Wait); } } @@ -51,6 +53,7 @@ struct Tiling2dBench { n: usize, kind: MatmulKind, device: R::Device, + client: ComputeClient, _e: PhantomData, } @@ -68,6 +71,7 @@ fn run(device: R::Device, kind: MatmulKind) { m: 1024, k: 1024, n: 1024, + client: R::client(&device), device, kind, _e: PhantomData,
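A note on how call sites change (not part of the patch): kernels are now launched through a `ComputeClient` borrowed by the caller rather than a `&R::Device`, and the new `matmul::launch_ref` dispatcher picks the CMMA kernel when `cmma::is_available` succeeds and the tiling 2d kernel otherwise. A minimal sketch of the new flow; the import paths, the `::<R, F32>` turbofish, and the helper name are assumptions, since the generic parameters are not visible in this patch:

```rust
use cubecl_core::prelude::*;
use cubecl_linalg::{matmul, tensor::TensorHandle};

// Hypothetical caller: the handles are assumed to be contiguous, F32, and created elsewhere.
fn matmul_example<R: Runtime>(
    device: &R::Device,
    lhs: TensorHandle<R, F32>,
    rhs: TensorHandle<R, F32>,
    out: TensorHandle<R, F32>,
) {
    // One client per device; every launch below borrows it instead of re-creating one.
    let client = R::client(device);

    // Let the library pick the kernel: CMMA when available, tiling 2d otherwise.
    matmul::launch_ref::<R, F32>(&client, lhs.as_ref(), rhs.as_ref(), out.as_ref());
}
```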