diff --git a/crates/cubecl-linalg/src/matmul/cmma/launch.rs b/crates/cubecl-linalg/src/matmul/cmma/launch.rs
index ab1dd7c98..bbbe24852 100644
--- a/crates/cubecl-linalg/src/matmul/cmma/launch.rs
+++ b/crates/cubecl-linalg/src/matmul/cmma/launch.rs
@@ -12,7 +12,7 @@ use crate::{
         base::cmma_kernel,
         config::{cmma_cube_count, cmma_cube_dim, CmmaConfig, CmmaLaunchConfig},
     },
-    tensor::{matrix_layout, MatrixLayout, TensorHandle},
+    tensor::{into_contiguous, matrix_layout, MatrixLayout, TensorHandle},
 };
 
 /// Matrix multiplication using [cooperative matrix-multiply and accumulate operations](cubecl_core::cmma).
@@ -28,8 +28,7 @@ pub fn matmul_cmma<R: Runtime, F: Float>(
 
 #[derive(Debug)]
 pub enum UnavailabilityReason {
-    TransposedInput, // TODO: Support that case.
-    NotMultipleOf4,  // TODO: Support that case.
+    NotMultipleOf4, // TODO: Support that case.
     HiglyPermutatedInput,
     ShapeMemoryLimitBusted,
     InvalidConfig(String),
@@ -43,15 +42,6 @@ pub fn check_cmma_availability<R: Runtime>(
     rhs: &TensorHandleRef<'_, R>,
     config: Option<&CmmaLaunchConfig>,
 ) -> Result<(), UnavailabilityReason> {
-    let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_layout(tensor.strides) {
-        MatrixLayout::Contiguous => Ok(()),
-        MatrixLayout::MildlyPermuted {
-            transposed: _,
-            batch_swap: _,
-        } => Err(UnavailabilityReason::TransposedInput),
-        MatrixLayout::HighlyPermuted => Err(UnavailabilityReason::HiglyPermutatedInput),
-    };
-
     if !client.features().enabled(Feature::Cmma {
         a: Elem::Float(FloatKind::F16),
         b: Elem::Float(FloatKind::F16),
@@ -63,9 +53,6 @@ pub fn check_cmma_availability<R: Runtime>(
         return Err(UnavailabilityReason::CmmaInstructionsUnsupported);
     }
 
-    check_layout(lhs)?;
-    check_layout(rhs)?;
-
     let rank = lhs.shape.len();
     let m = lhs.shape[rank - 2];
     let k = lhs.shape[rank - 1];
@@ -105,6 +92,47 @@ pub fn matmul_cmma_ref<R: Runtime, F: Float>(
     lhs: TensorHandleRef<'_, R>,
     rhs: TensorHandleRef<'_, R>,
     out: TensorHandleRef<'_, R>,
+) {
+    let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_layout(tensor.strides) {
+        MatrixLayout::Contiguous => true,
+        MatrixLayout::MildlyPermuted {
+            transposed: _,
+            batch_swap: _,
+        } => false,
+        MatrixLayout::HighlyPermuted => false,
+    };
+
+    let lhs_correct_layout = check_layout(&lhs);
+    let rhs_correct_layout = check_layout(&rhs);
+
+    match (lhs_correct_layout, rhs_correct_layout) {
+        (true, true) => matmul_cmma_ref_no_check::<R, F>(client, lhs, rhs, out),
+        (true, false) => matmul_cmma_ref_no_check::<R, F>(
+            client,
+            lhs,
+            into_contiguous::<R, F>(client, rhs).as_ref(),
+            out,
+        ),
+        (false, true) => matmul_cmma_ref_no_check::<R, F>(
+            client,
+            into_contiguous::<R, F>(client, lhs).as_ref(),
+            rhs,
+            out,
+        ),
+        (false, false) => matmul_cmma_ref_no_check::<R, F>(
+            client,
+            into_contiguous::<R, F>(client, lhs).as_ref(),
+            into_contiguous::<R, F>(client, rhs).as_ref(),
+            out,
+        ),
+    }
+}
+
+fn matmul_cmma_ref_no_check<R: Runtime, F: Float>(
+    client: &ComputeClient<R::Server, R::Channel>,
+    lhs: TensorHandleRef<'_, R>,
+    rhs: TensorHandleRef<'_, R>,
+    out: TensorHandleRef<'_, R>,
 ) {
     let rank = lhs.strides.len();
 
diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs
index c77aff672..5499d3e81 100644
--- a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs
+++ b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs
@@ -7,7 +7,7 @@ use crate::{
         base::tiling2d_cube_kernel,
         config::{tiling2d_cube_count, tiling2d_cube_dim, CubeTiling2dConfig},
     },
-    tensor::{matrix_layout, MatrixLayout, TensorHandle},
+    tensor::{into_contiguous, matrix_layout, MatrixLayout, TensorHandle},
 };
 
 use super::config::Tiling2dConfig;
@@ -38,6 +38,51 @@ pub fn matmul_tiling_2d_ref<R: Runtime, F: Float>(
             <= <R::Compiler as Compiler>::max_shared_memory_size(),
         "Shared memory limit will be busted. "
     );
+    let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_layout(tensor.strides) {
+        MatrixLayout::Contiguous => true,
+        MatrixLayout::MildlyPermuted {
+            transposed: _,
+            batch_swap: _,
+        } => true,
+        MatrixLayout::HighlyPermuted => false,
+    };
+    let lhs_correct_layout = check_layout(&lhs);
+    let rhs_correct_layout = check_layout(&rhs);
+
+    match (lhs_correct_layout, rhs_correct_layout) {
+        (true, true) => matmul_tiling_2d_ref_no_check::<R, F>(client, lhs, rhs, out, config),
+        (true, false) => matmul_tiling_2d_ref_no_check::<R, F>(
+            client,
+            lhs,
+            into_contiguous::<R, F>(client, rhs).as_ref(),
+            out,
+            config,
+        ),
+        (false, true) => matmul_tiling_2d_ref_no_check::<R, F>(
+            client,
+            into_contiguous::<R, F>(client, lhs).as_ref(),
+            rhs,
+            out,
+            config,
+        ),
+        (false, false) => matmul_tiling_2d_ref_no_check::<R, F>(
+            client,
+            into_contiguous::<R, F>(client, lhs).as_ref(),
+            into_contiguous::<R, F>(client, rhs).as_ref(),
+            out,
+            config,
+        ),
+    }
+}
+
+/// Matrix multiplication using tiling 2d algorithm.
+fn matmul_tiling_2d_ref_no_check<R: Runtime, F: Float>(
+    client: &ComputeClient<R::Server, R::Channel>,
+    lhs: TensorHandleRef<'_, R>,
+    rhs: TensorHandleRef<'_, R>,
+    out: TensorHandleRef<'_, R>,
+    config: Tiling2dConfig,
+) {
     let rank = lhs.strides.len();
 
     let m = lhs.shape[rank - 2];
@@ -58,7 +103,8 @@ pub fn matmul_tiling_2d_ref<R: Runtime, F: Float>(
     let rhs_transposed = check_layout(rhs.strides);
 
     let vectorization = |shape: usize| {
-        [].into_iter()
+        [4, 2]
+            .into_iter()
             .filter(|v| shape % v == 0)
             .map(|v| v as u8)
             .next()
diff --git a/crates/cubecl-linalg/src/tensor/contiguous.rs b/crates/cubecl-linalg/src/tensor/contiguous.rs
new file mode 100644
index 000000000..87e344395
--- /dev/null
+++ b/crates/cubecl-linalg/src/tensor/contiguous.rs
@@ -0,0 +1,95 @@
+use cubecl_core::{
+    self as cubecl, calculate_cube_count_elemwise, tensor_vectorization_factor, SUBCUBE_DIM_APPROX,
+};
+
+use cubecl::prelude::*;
+
+use super::TensorHandle;
+
+/// Returns the offset of the tensor corresponding to the layout tensor.
+#[cube]
+pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
+    tensor: &Tensor<N>,
+    layout: &Tensor<L>,
+    offset_layout: UInt,
+    dim_start: UInt,
+    dim_end: UInt,
+    unroll: Comptime<bool>,
+) -> UInt {
+    let vectorization_factor = Comptime::vectorization(tensor);
+    let vectorization_factor_runtime = Comptime::runtime(vectorization_factor);
+
+    let offset_ref = offset_layout * vectorization_factor_runtime;
+    let mut offset = UInt::new(0);
+
+    for i in range(dim_start, dim_end, unroll) {
+        let ogwl = offset_ref / layout.stride(i);
+        offset += ogwl % tensor.shape(i) * tensor.stride(i);
+    }
+
+    offset / vectorization_factor_runtime
+}
+
+#[cube(launch)]
+fn into_contiguous_kernel<N: CubePrimitive>(
+    input: &Tensor<N>,
+    output: &mut Tensor<N>,
+    rank: Comptime<Option<UInt>>,
+) {
+    let offset_output = ABSOLUTE_POS;
+
+    if offset_output >= output.len() {
+        return;
+    }
+
+    let offset_input = index_offset_with_layout::<N, N>(
+        input,
+        output,
+        offset_output,
+        UInt::new(0),
+        Comptime::unwrap_or_else(rank, || output.rank()),
+        Comptime::is_some(rank),
+    );
+
+    output[offset_output] = input[offset_input];
+}
+
+/// Make a jit tensor contiguous.
+pub fn into_contiguous<R: Runtime, E: CubePrimitive>(
+    client: &ComputeClient<R::Server, R::Channel>,
+    input: TensorHandleRef<'_, R>,
+) -> TensorHandle<R, E> {
+    // Vectorization is only enabled when the last dimension is contiguous.
+    let rank = input.strides.len();
+    let vectorization_factor =
+        tensor_vectorization_factor(&[4, 2], &input.shape, &input.strides, rank - 1);
+
+    let num_elems: usize = input.shape.iter().product();
+    let cube_count = calculate_cube_count_elemwise(
+        num_elems / vectorization_factor as usize,
+        SUBCUBE_DIM_APPROX,
+    );
+    let handle = client.empty(num_elems * E::as_elem().size());
+    let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle);
+
+    into_contiguous_kernel::launch::<E, R>(
+        &client,
+        cube_count,
+        CubeDim::default(),
+        TensorArg::vectorized(
+            vectorization_factor,
+            &input.handle,
+            &input.strides,
+            &input.shape,
+        ),
+        TensorArg::vectorized(
+            vectorization_factor,
+            &output.handle,
+            &output.strides,
+            &output.shape,
+        ),
+        Some(UInt::new(rank as u32)),
+    );
+
+    output
+}
diff --git a/crates/cubecl-linalg/src/tensor/mod.rs b/crates/cubecl-linalg/src/tensor/mod.rs
index c06bc8555..1cc0a8367 100644
--- a/crates/cubecl-linalg/src/tensor/mod.rs
+++ b/crates/cubecl-linalg/src/tensor/mod.rs
@@ -1,4 +1,7 @@
 mod base;
+mod contiguous;
 mod layout;
+
 pub use base::*;
+pub use contiguous::*;
 pub use layout::*;
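
Reviewer note: the sketch below shows how the new dispatch path is meant to be consumed after this patch; it is not part of the diff. The import paths, the `launch_matmul` helper, and the assumption that `ComputeClient`, `Runtime`, `TensorHandleRef`, and `F32` are reachable through the `cubecl` prelude are illustrative only.

    // A minimal sketch, assuming the cubecl prelude re-exports the types used
    // below and that `matmul_cmma_ref` is reachable at this module path.
    use cubecl::prelude::*;
    use cubecl_linalg::matmul::cmma::launch::matmul_cmma_ref;

    // Hypothetical caller: no layout pre-check is needed anymore, because
    // `matmul_cmma_ref` inspects the strides via `matrix_layout` itself and
    // routes mildly or highly permuted inputs through `into_contiguous`
    // before launching the kernel.
    fn launch_matmul<R: Runtime>(
        client: &ComputeClient<R::Server, R::Channel>,
        lhs: TensorHandleRef<'_, R>, // strides may describe a permuted view
        rhs: TensorHandleRef<'_, R>,
        out: TensorHandleRef<'_, R>, // output is assumed contiguous, as before
    ) {
        matmul_cmma_ref::<R, F32>(client, lhs, rhs, out);
    }

The tiling2d entry point is analogous, except that its `check_layout` accepts `MildlyPermuted` layouts (the kernel handles transposed inputs in-place) and only `HighlyPermuted` inputs are copied to a contiguous buffer.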