
Commit

cleanup useless
louisfd committed Jul 18, 2024
1 parent 803213c commit b8639e8
Showing 2 changed files with 1 addition and 289 deletions.
271 changes: 1 addition & 270 deletions crates/cubecl-lac/src/matmul/cmma/compute_loop.rs
@@ -99,104 +99,12 @@ pub mod tests {
use crate::matmul::{
cmma::base::{make_accumulators, SharedMemoriesExpand},
test_utils::{
assert_equals, assert_equals_range, cmma_available, create_empty, range_tensor_f16,
assert_equals, cmma_available, create_empty, range_tensor_f16,
},
};

use super::*;

#[cube]
pub fn cmma_computation<F: Float, FC: Float>(
lhs: &Slice<FC>,
rhs: &Slice<FC>,
out: &mut SliceMut<F>,
) {
let a = cmma::Matrix::<FC>::new(
cmma::MatrixIdent::A,
16,
16,
16,
cmma::MatrixLayout::RowMajor,
);
let b = cmma::Matrix::<FC>::new(
cmma::MatrixIdent::B,
16,
16,
16,
cmma::MatrixLayout::RowMajor,
);
let c = cmma::Matrix::<F>::new(
cmma::MatrixIdent::Accumulator,
16,
16,
16,
cmma::MatrixLayout::Undefined,
);
cmma::fill::<F>(&c, F::new(0.0));

cmma::load::<FC>(&a, lhs, UInt::new(16));
cmma::load::<FC>(&b, rhs, UInt::new(16));

cmma::execute::<FC, FC, F, F>(&a, &b, &c, &c);

cmma::store::<F>(out, &c, UInt::new(16), cmma::MatrixLayout::RowMajor);
}

#[cube(launch)]
fn compute_loop_cmma_test<F: Float, FC: Float>(
lhs: &Tensor<FC>,
rhs: &Tensor<FC>,
result: &mut Array<F>,
) {
cmma_computation(lhs.as_slice(), rhs.as_slice(), result.as_slice_mut());
}

#[cube(launch)]
fn compute_loop_cmma_offseted_slice_test<F: Float, FC: Float>(
lhs: &Tensor<FC>,
rhs: &Tensor<FC>,
result: &mut Array<F>,
) {
cmma_computation(
lhs.slice(256, 512),
rhs.slice(256, 512),
result.slice_mut(768, 1024),
);
}

#[cube(launch)]
fn compute_loop_cmma_offseted_slice_in_shared_memory_test<F: Float, FC: Float>(
lhs_tensor: &Tensor<FC>,
rhs_tensor: &Tensor<FC>,
accumulate_array: &mut Array<F>,
m: Comptime<UInt>,
k: Comptime<UInt>,
n: Comptime<UInt>,
) {
let mut lhs = SharedMemory::<FC>::new(Comptime::get(m * k));
let mut rhs = SharedMemory::<FC>::new(Comptime::get(k * n));
let mut accumulate = SharedMemory::<F>::new(Comptime::get(m * n));
for i in range(0u32, Comptime::get(m * k), Comptime::new(false)) {
lhs[i] = lhs_tensor[i];
}
for i in range(0u32, Comptime::get(k * n), Comptime::new(false)) {
rhs[i] = rhs_tensor[i];
}
for i in range(0u32, Comptime::get(m * n), Comptime::new(false)) {
accumulate[i] = F::new(0.);
}

cmma_computation(
lhs.slice(256, 512),
rhs.slice(256, 512),
accumulate.slice_mut(768, 1024),
);

for i in range(0u32, Comptime::get(m * n), Comptime::new(false)) {
accumulate_array[i] = accumulate[i];
}
}

#[cube(launch)]
fn compute_loop_test<F: Float, FC: Float>(
lhs_tensor: &Tensor<FC>,
@@ -243,183 +151,6 @@ pub mod tests {
);
}

/// Exported test
pub fn cmma_warp_test<R: Runtime>(device: &R::Device) {
if !cmma_available::<R>(device) {
// We can't execute the test, skip.
return;
}

let lhs = range_tensor_f16::<R>(16, 16, device);
let rhs = range_tensor_f16::<R>(16, 16, device);
let results = create_empty::<R>(16, 16, device);
let cube_dim = CubeDim::new(32, 1, 1);
let cube_count = CubeCount::Static(1, 1, 1);

compute_loop_cmma_test::launch::<F32, F16, R>(
R::client(device),
cube_count,
cube_dim,
TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape),
TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape),
ArrayArg::new(&results, 256),
);

let expected = &[
19840., 19960., 20080., 20200., 20320., 20440., 20560., 20680., 20800., 20920., 21040.,
21160., 21280., 21400., 21520., 21640., 50560., 50936., 51312., 51688., 52064., 52440.,
52816., 53192., 53568., 53944., 54320., 54696., 55072., 55448., 55824., 56200., 81280.,
81912., 82544., 83176., 83808., 84440., 85072., 85704., 86336., 86968., 87600., 88232.,
88864., 89496., 90128., 90760., 112000., 112888., 113776., 114664., 115552., 116440.,
117328., 118216., 119104., 119992., 120880., 121768., 122656., 123544., 124432.,
125320., 142720., 143864., 145008., 146152., 147296., 148440., 149584., 150728.,
151872., 153016., 154160., 155304., 156448., 157592., 158736., 159880., 173440.,
174840., 176240., 177640., 179040., 180440., 181840., 183240., 184640., 186040.,
187440., 188840., 190240., 191640., 193040., 194440., 204160., 205816., 207472.,
209128., 210784., 212440., 214096., 215752., 217408., 219064., 220720., 222376.,
224032., 225688., 227344., 229000., 234880., 236792., 238704., 240616., 242528.,
244440., 246352., 248264., 250176., 252088., 254000., 255912., 257824., 259736.,
261648., 263560., 265600., 267768., 269936., 272104., 274272., 276440., 278608.,
280776., 282944., 285112., 287280., 289448., 291616., 293784., 295952., 298120.,
296320., 298744., 301168., 303592., 306016., 308440., 310864., 313288., 315712.,
318136., 320560., 322984., 325408., 327832., 330256., 332680., 327040., 329720.,
332400., 335080., 337760., 340440., 343120., 345800., 348480., 351160., 353840.,
356520., 359200., 361880., 364560., 367240., 357760., 360696., 363632., 366568.,
369504., 372440., 375376., 378312., 381248., 384184., 387120., 390056., 392992.,
395928., 398864., 401800., 388480., 391672., 394864., 398056., 401248., 404440.,
407632., 410824., 414016., 417208., 420400., 423592., 426784., 429976., 433168.,
436360., 419200., 422648., 426096., 429544., 432992., 436440., 439888., 443336.,
446784., 450232., 453680., 457128., 460576., 464024., 467472., 470920., 449920.,
453624., 457328., 461032., 464736., 468440., 472144., 475848., 479552., 483256.,
486960., 490664., 494368., 498072., 501776., 505480., 480640., 484600., 488560.,
492520., 496480., 500440., 504400., 508360., 512320., 516280., 520240., 524200.,
528160., 532120., 536080., 540040.,
];

assert_equals::<R>(results, expected, device);
}

/// Exported test
pub fn compute_loop_cmma_offseted_warp_test<R: Runtime>(device: &R::Device) {
if !cmma_available::<R>(device) {
// We can't execute the test, skip.
return;
}

let lhs = range_tensor_f16::<R>(32, 16, device);
let rhs = range_tensor_f16::<R>(16, 32, device);
let results = create_empty::<R>(32, 32, device);
let cube_dim = CubeDim::new(32, 1, 1);
let cube_count = CubeCount::Static(1, 1, 1);

compute_loop_cmma_offseted_slice_test::launch::<F32, F16, R>(
R::client(device),
cube_count,
cube_dim,
TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape),
TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape),
ArrayArg::new(&results, 256),
);

let expected = &[
1590656.0, 1594872.0, 1599088.0, 1603304.0, 1607520.0, 1611736.0, 1615952.0, 1620168.0,
1624384.0, 1628600.0, 1632816.0, 1637032.0, 1641248.0, 1645464.0, 1649680.0, 1653896.0,
1686912.0, 1691384.0, 1695856.0, 1700328.0, 1704800.0, 1709272.0, 1713744.0, 1718216.0,
1722688.0, 1727160.0, 1731632.0, 1736104.0, 1740576.0, 1745048.0, 1749520.0, 1753992.0,
1783168.0, 1787896.0, 1792624.0, 1797352.0, 1802080.0, 1806808.0, 1811536.0, 1816264.0,
1820992.0, 1825720.0, 1830448.0, 1835176.0, 1839904.0, 1844632.0, 1849360.0, 1854088.0,
1879424.0, 1884408.0, 1889392.0, 1894376.0, 1899360.0, 1904344.0, 1909328.0, 1914312.0,
1919296.0, 1924280.0, 1929264.0, 1934248.0, 1939232.0, 1944216.0, 1949200.0, 1954184.0,
1975680.0, 1980920.0, 1986160.0, 1991400.0, 1996640.0, 2001880.0, 2007120.0, 2012360.0,
2017600.0, 2022840.0, 2028080.0, 2033320.0, 2038560.0, 2043800.0, 2049040.0, 2054280.0,
2071936.0, 2077432.0, 2082928.0, 2088424.0, 2093920.0, 2099416.0, 2104912.0, 2110408.0,
2115904.0, 2121400.0, 2126896.0, 2132392.0, 2137888.0, 2143384.0, 2148880.0, 2154376.0,
2168192.0, 2173944.0, 2179696.0, 2185448.0, 2191200.0, 2196952.0, 2202704.0, 2208456.0,
2214208.0, 2219960.0, 2225712.0, 2231464.0, 2237216.0, 2242968.0, 2248720.0, 2254472.0,
2264448.0, 2270456.0, 2276464.0, 2282472.0, 2288480.0, 2294488.0, 2300496.0, 2306504.0,
2312512.0, 2318520.0, 2324528.0, 2330536.0, 2336544.0, 2342552.0, 2348560.0, 2354568.0,
2360704.0, 2366968.0, 2373232.0, 2379496.0, 2385760.0, 2392024.0, 2398288.0, 2404552.0,
2410816.0, 2417080.0, 2423344.0, 2429608.0, 2435872.0, 2442136.0, 2448400.0, 2454664.0,
2456960.0, 2463480.0, 2470000.0, 2476520.0, 2483040.0, 2489560.0, 2496080.0, 2502600.0,
2509120.0, 2515640.0, 2522160.0, 2528680.0, 2535200.0, 2541720.0, 2548240.0, 2554760.0,
2553216.0, 2559992.0, 2566768.0, 2573544.0, 2580320.0, 2587096.0, 2593872.0, 2600648.0,
2607424.0, 2614200.0, 2620976.0, 2627752.0, 2634528.0, 2641304.0, 2648080.0, 2654856.0,
2649472.0, 2656504.0, 2663536.0, 2670568.0, 2677600.0, 2684632.0, 2691664.0, 2698696.0,
2705728.0, 2712760.0, 2719792.0, 2726824.0, 2733856.0, 2740888.0, 2747920.0, 2754952.0,
2745728.0, 2753016.0, 2760304.0, 2767592.0, 2774880.0, 2782168.0, 2789456.0, 2796744.0,
2804032.0, 2811320.0, 2818608.0, 2825896.0, 2833184.0, 2840472.0, 2847760.0, 2855048.0,
2841984.0, 2849528.0, 2857072.0, 2864616.0, 2872160.0, 2879704.0, 2887248.0, 2894792.0,
2902336.0, 2909880.0, 2917424.0, 2924968.0, 2932512.0, 2940056.0, 2947600.0, 2955144.0,
2938240.0, 2946040.0, 2953840.0, 2961640.0, 2969440.0, 2977240.0, 2985040.0, 2992840.0,
3000640.0, 3008440.0, 3016240.0, 3024040.0, 3031840.0, 3039640.0, 3047440.0, 3055240.0,
3034496.0, 3042552.0, 3050608.0, 3058664.0, 3066720.0, 3074776.0, 3082832.0, 3090888.0,
3098944.0, 3107000.0, 3115056.0, 3123112.0, 3131168.0, 3139224.0, 3147280.0, 3155336.0,
];
assert_equals_range::<R>(results, expected, 768..1024, device);
}

/// Exported test
pub fn compute_loop_cmma_offseted_warp_in_shared_memory_test<R: Runtime>(device: &R::Device) {
if !cmma_available::<R>(device) {
// We can't execute the test, skip.
return;
}

let lhs = range_tensor_f16::<R>(32, 16, device);
let rhs = range_tensor_f16::<R>(16, 32, device);
let results = create_empty::<R>(32, 32, device);
let cube_dim = CubeDim::new(32, 1, 1);
let cube_count = CubeCount::Static(1, 1, 1);

compute_loop_cmma_offseted_slice_in_shared_memory_test::launch::<F32, F16, R>(
R::client(device),
cube_count,
cube_dim,
TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape),
TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape),
ArrayArg::new(&results, 256),
UInt::new(32),
UInt::new(16),
UInt::new(32),
);

let expected = &[
1590656.0, 1594872.0, 1599088.0, 1603304.0, 1607520.0, 1611736.0, 1615952.0, 1620168.0,
1624384.0, 1628600.0, 1632816.0, 1637032.0, 1641248.0, 1645464.0, 1649680.0, 1653896.0,
1686912.0, 1691384.0, 1695856.0, 1700328.0, 1704800.0, 1709272.0, 1713744.0, 1718216.0,
1722688.0, 1727160.0, 1731632.0, 1736104.0, 1740576.0, 1745048.0, 1749520.0, 1753992.0,
1783168.0, 1787896.0, 1792624.0, 1797352.0, 1802080.0, 1806808.0, 1811536.0, 1816264.0,
1820992.0, 1825720.0, 1830448.0, 1835176.0, 1839904.0, 1844632.0, 1849360.0, 1854088.0,
1879424.0, 1884408.0, 1889392.0, 1894376.0, 1899360.0, 1904344.0, 1909328.0, 1914312.0,
1919296.0, 1924280.0, 1929264.0, 1934248.0, 1939232.0, 1944216.0, 1949200.0, 1954184.0,
1975680.0, 1980920.0, 1986160.0, 1991400.0, 1996640.0, 2001880.0, 2007120.0, 2012360.0,
2017600.0, 2022840.0, 2028080.0, 2033320.0, 2038560.0, 2043800.0, 2049040.0, 2054280.0,
2071936.0, 2077432.0, 2082928.0, 2088424.0, 2093920.0, 2099416.0, 2104912.0, 2110408.0,
2115904.0, 2121400.0, 2126896.0, 2132392.0, 2137888.0, 2143384.0, 2148880.0, 2154376.0,
2168192.0, 2173944.0, 2179696.0, 2185448.0, 2191200.0, 2196952.0, 2202704.0, 2208456.0,
2214208.0, 2219960.0, 2225712.0, 2231464.0, 2237216.0, 2242968.0, 2248720.0, 2254472.0,
2264448.0, 2270456.0, 2276464.0, 2282472.0, 2288480.0, 2294488.0, 2300496.0, 2306504.0,
2312512.0, 2318520.0, 2324528.0, 2330536.0, 2336544.0, 2342552.0, 2348560.0, 2354568.0,
2360704.0, 2366968.0, 2373232.0, 2379496.0, 2385760.0, 2392024.0, 2398288.0, 2404552.0,
2410816.0, 2417080.0, 2423344.0, 2429608.0, 2435872.0, 2442136.0, 2448400.0, 2454664.0,
2456960.0, 2463480.0, 2470000.0, 2476520.0, 2483040.0, 2489560.0, 2496080.0, 2502600.0,
2509120.0, 2515640.0, 2522160.0, 2528680.0, 2535200.0, 2541720.0, 2548240.0, 2554760.0,
2553216.0, 2559992.0, 2566768.0, 2573544.0, 2580320.0, 2587096.0, 2593872.0, 2600648.0,
2607424.0, 2614200.0, 2620976.0, 2627752.0, 2634528.0, 2641304.0, 2648080.0, 2654856.0,
2649472.0, 2656504.0, 2663536.0, 2670568.0, 2677600.0, 2684632.0, 2691664.0, 2698696.0,
2705728.0, 2712760.0, 2719792.0, 2726824.0, 2733856.0, 2740888.0, 2747920.0, 2754952.0,
2745728.0, 2753016.0, 2760304.0, 2767592.0, 2774880.0, 2782168.0, 2789456.0, 2796744.0,
2804032.0, 2811320.0, 2818608.0, 2825896.0, 2833184.0, 2840472.0, 2847760.0, 2855048.0,
2841984.0, 2849528.0, 2857072.0, 2864616.0, 2872160.0, 2879704.0, 2887248.0, 2894792.0,
2902336.0, 2909880.0, 2917424.0, 2924968.0, 2932512.0, 2940056.0, 2947600.0, 2955144.0,
2938240.0, 2946040.0, 2953840.0, 2961640.0, 2969440.0, 2977240.0, 2985040.0, 2992840.0,
3000640.0, 3008440.0, 3016240.0, 3024040.0, 3031840.0, 3039640.0, 3047440.0, 3055240.0,
3034496.0, 3042552.0, 3050608.0, 3058664.0, 3066720.0, 3074776.0, 3082832.0, 3090888.0,
3098944.0, 3107000.0, 3115056.0, 3123112.0, 3131168.0, 3139224.0, 3147280.0, 3155336.0,
];
assert_equals_range::<R>(results, expected, 768..1024, device);
}

/// Exported test
pub fn compute_loop_k_test<R: Runtime>(device: &R::Device) {
if !cmma_available::<R>(device) {
19 changes: 0 additions & 19 deletions crates/cubecl-lac/src/tests/matmul_internal.rs
@@ -152,25 +152,6 @@ macro_rules! testgen_matmul_internal {
cmma_compute_loop_tests::compute_loop_warp_test::<TestRuntime>(&Default::default())
}

#[test]
pub fn cmma_compute_loop_cmma_offseted_warp_test() {
cmma_compute_loop_tests::compute_loop_cmma_offseted_warp_test::<TestRuntime>(
&Default::default(),
)
}

#[test]
pub fn cmma_compute_loop_cmma_offseted_warp_in_shared_memory_test() {
cmma_compute_loop_tests::compute_loop_cmma_offseted_warp_in_shared_memory_test::<
TestRuntime,
>(&Default::default())
}

#[test]
pub fn cmma_warp_test() {
cmma_compute_loop_tests::cmma_warp_test::<TestRuntime>(&Default::default())
}

#[test]
pub fn cmma_load_shared_memory_lhs_unit_test() {
cmma_load_shared_memory_tests::load_shared_memory_lhs_unit_test::<TestRuntime>(
