From b8639e8131d7843a563af4f5f91932670fcab5cb Mon Sep 17 00:00:00 2001 From: louisfd Date: Thu, 18 Jul 2024 14:03:11 -0400 Subject: [PATCH] cleanup useless --- .../src/matmul/cmma/compute_loop.rs | 271 +----------------- .../cubecl-lac/src/tests/matmul_internal.rs | 19 -- 2 files changed, 1 insertion(+), 289 deletions(-) diff --git a/crates/cubecl-lac/src/matmul/cmma/compute_loop.rs b/crates/cubecl-lac/src/matmul/cmma/compute_loop.rs index aa4feb4e9..2932b4d15 100644 --- a/crates/cubecl-lac/src/matmul/cmma/compute_loop.rs +++ b/crates/cubecl-lac/src/matmul/cmma/compute_loop.rs @@ -99,104 +99,12 @@ pub mod tests { use crate::matmul::{ cmma::base::{make_accumulators, SharedMemoriesExpand}, test_utils::{ - assert_equals, assert_equals_range, cmma_available, create_empty, range_tensor_f16, + assert_equals, cmma_available, create_empty, range_tensor_f16, }, }; use super::*; - #[cube] - pub fn cmma_computation( - lhs: &Slice, - rhs: &Slice, - out: &mut SliceMut, - ) { - let a = cmma::Matrix::::new( - cmma::MatrixIdent::A, - 16, - 16, - 16, - cmma::MatrixLayout::RowMajor, - ); - let b = cmma::Matrix::::new( - cmma::MatrixIdent::B, - 16, - 16, - 16, - cmma::MatrixLayout::RowMajor, - ); - let c = cmma::Matrix::::new( - cmma::MatrixIdent::Accumulator, - 16, - 16, - 16, - cmma::MatrixLayout::Undefined, - ); - cmma::fill::(&c, F::new(0.0)); - - cmma::load::(&a, lhs, UInt::new(16)); - cmma::load::(&b, rhs, UInt::new(16)); - - cmma::execute::(&a, &b, &c, &c); - - cmma::store::(out, &c, UInt::new(16), cmma::MatrixLayout::RowMajor); - } - - #[cube(launch)] - fn compute_loop_cmma_test( - lhs: &Tensor, - rhs: &Tensor, - result: &mut Array, - ) { - cmma_computation(lhs.as_slice(), rhs.as_slice(), result.as_slice_mut()); - } - - #[cube(launch)] - fn compute_loop_cmma_offseted_slice_test( - lhs: &Tensor, - rhs: &Tensor, - result: &mut Array, - ) { - cmma_computation( - lhs.slice(256, 512), - rhs.slice(256, 512), - result.slice_mut(768, 1024), - ); - } - - #[cube(launch)] - fn compute_loop_cmma_offseted_slice_in_shared_memory_test( - lhs_tensor: &Tensor, - rhs_tensor: &Tensor, - accumulate_array: &mut Array, - m: Comptime, - k: Comptime, - n: Comptime, - ) { - let mut lhs = SharedMemory::::new(Comptime::get(m * k)); - let mut rhs = SharedMemory::::new(Comptime::get(k * n)); - let mut accumulate = SharedMemory::::new(Comptime::get(m * n)); - for i in range(0u32, Comptime::get(m * k), Comptime::new(false)) { - lhs[i] = lhs_tensor[i]; - } - for i in range(0u32, Comptime::get(k * n), Comptime::new(false)) { - rhs[i] = rhs_tensor[i]; - } - for i in range(0u32, Comptime::get(m * n), Comptime::new(false)) { - accumulate[i] = F::new(0.); - } - - cmma_computation( - lhs.slice(256, 512), - rhs.slice(256, 512), - accumulate.slice_mut(768, 1024), - ); - - for i in range(0u32, Comptime::get(m * n), Comptime::new(false)) { - accumulate_array[i] = accumulate[i]; - } - } - #[cube(launch)] fn compute_loop_test( lhs_tensor: &Tensor, @@ -243,183 +151,6 @@ pub mod tests { ); } - /// Exported test - pub fn cmma_warp_test(device: &R::Device) { - if !cmma_available::(device) { - // We can't execute the test, skip. - return; - } - - let lhs = range_tensor_f16::(16, 16, device); - let rhs = range_tensor_f16::(16, 16, device); - let results = create_empty::(16, 16, device); - let cube_dim = CubeDim::new(32, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - compute_loop_cmma_test::launch::( - R::client(device), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::new(&results, 256), - ); - - let expected = &[ - 19840., 19960., 20080., 20200., 20320., 20440., 20560., 20680., 20800., 20920., 21040., - 21160., 21280., 21400., 21520., 21640., 50560., 50936., 51312., 51688., 52064., 52440., - 52816., 53192., 53568., 53944., 54320., 54696., 55072., 55448., 55824., 56200., 81280., - 81912., 82544., 83176., 83808., 84440., 85072., 85704., 86336., 86968., 87600., 88232., - 88864., 89496., 90128., 90760., 112000., 112888., 113776., 114664., 115552., 116440., - 117328., 118216., 119104., 119992., 120880., 121768., 122656., 123544., 124432., - 125320., 142720., 143864., 145008., 146152., 147296., 148440., 149584., 150728., - 151872., 153016., 154160., 155304., 156448., 157592., 158736., 159880., 173440., - 174840., 176240., 177640., 179040., 180440., 181840., 183240., 184640., 186040., - 187440., 188840., 190240., 191640., 193040., 194440., 204160., 205816., 207472., - 209128., 210784., 212440., 214096., 215752., 217408., 219064., 220720., 222376., - 224032., 225688., 227344., 229000., 234880., 236792., 238704., 240616., 242528., - 244440., 246352., 248264., 250176., 252088., 254000., 255912., 257824., 259736., - 261648., 263560., 265600., 267768., 269936., 272104., 274272., 276440., 278608., - 280776., 282944., 285112., 287280., 289448., 291616., 293784., 295952., 298120., - 296320., 298744., 301168., 303592., 306016., 308440., 310864., 313288., 315712., - 318136., 320560., 322984., 325408., 327832., 330256., 332680., 327040., 329720., - 332400., 335080., 337760., 340440., 343120., 345800., 348480., 351160., 353840., - 356520., 359200., 361880., 364560., 367240., 357760., 360696., 363632., 366568., - 369504., 372440., 375376., 378312., 381248., 384184., 387120., 390056., 392992., - 395928., 398864., 401800., 388480., 391672., 394864., 398056., 401248., 404440., - 407632., 410824., 414016., 417208., 420400., 423592., 426784., 429976., 433168., - 436360., 419200., 422648., 426096., 429544., 432992., 436440., 439888., 443336., - 446784., 450232., 453680., 457128., 460576., 464024., 467472., 470920., 449920., - 453624., 457328., 461032., 464736., 468440., 472144., 475848., 479552., 483256., - 486960., 490664., 494368., 498072., 501776., 505480., 480640., 484600., 488560., - 492520., 496480., 500440., 504400., 508360., 512320., 516280., 520240., 524200., - 528160., 532120., 536080., 540040., - ]; - - assert_equals::(results, expected, device); - } - - /// Exported test - pub fn compute_loop_cmma_offseted_warp_test(device: &R::Device) { - if !cmma_available::(device) { - // We can't execute the test, skip. - return; - } - - let lhs = range_tensor_f16::(32, 16, device); - let rhs = range_tensor_f16::(16, 32, device); - let results = create_empty::(32, 32, device); - let cube_dim = CubeDim::new(32, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - compute_loop_cmma_offseted_slice_test::launch::( - R::client(device), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::new(&results, 256), - ); - - let expected = &[ - 1590656.0, 1594872.0, 1599088.0, 1603304.0, 1607520.0, 1611736.0, 1615952.0, 1620168.0, - 1624384.0, 1628600.0, 1632816.0, 1637032.0, 1641248.0, 1645464.0, 1649680.0, 1653896.0, - 1686912.0, 1691384.0, 1695856.0, 1700328.0, 1704800.0, 1709272.0, 1713744.0, 1718216.0, - 1722688.0, 1727160.0, 1731632.0, 1736104.0, 1740576.0, 1745048.0, 1749520.0, 1753992.0, - 1783168.0, 1787896.0, 1792624.0, 1797352.0, 1802080.0, 1806808.0, 1811536.0, 1816264.0, - 1820992.0, 1825720.0, 1830448.0, 1835176.0, 1839904.0, 1844632.0, 1849360.0, 1854088.0, - 1879424.0, 1884408.0, 1889392.0, 1894376.0, 1899360.0, 1904344.0, 1909328.0, 1914312.0, - 1919296.0, 1924280.0, 1929264.0, 1934248.0, 1939232.0, 1944216.0, 1949200.0, 1954184.0, - 1975680.0, 1980920.0, 1986160.0, 1991400.0, 1996640.0, 2001880.0, 2007120.0, 2012360.0, - 2017600.0, 2022840.0, 2028080.0, 2033320.0, 2038560.0, 2043800.0, 2049040.0, 2054280.0, - 2071936.0, 2077432.0, 2082928.0, 2088424.0, 2093920.0, 2099416.0, 2104912.0, 2110408.0, - 2115904.0, 2121400.0, 2126896.0, 2132392.0, 2137888.0, 2143384.0, 2148880.0, 2154376.0, - 2168192.0, 2173944.0, 2179696.0, 2185448.0, 2191200.0, 2196952.0, 2202704.0, 2208456.0, - 2214208.0, 2219960.0, 2225712.0, 2231464.0, 2237216.0, 2242968.0, 2248720.0, 2254472.0, - 2264448.0, 2270456.0, 2276464.0, 2282472.0, 2288480.0, 2294488.0, 2300496.0, 2306504.0, - 2312512.0, 2318520.0, 2324528.0, 2330536.0, 2336544.0, 2342552.0, 2348560.0, 2354568.0, - 2360704.0, 2366968.0, 2373232.0, 2379496.0, 2385760.0, 2392024.0, 2398288.0, 2404552.0, - 2410816.0, 2417080.0, 2423344.0, 2429608.0, 2435872.0, 2442136.0, 2448400.0, 2454664.0, - 2456960.0, 2463480.0, 2470000.0, 2476520.0, 2483040.0, 2489560.0, 2496080.0, 2502600.0, - 2509120.0, 2515640.0, 2522160.0, 2528680.0, 2535200.0, 2541720.0, 2548240.0, 2554760.0, - 2553216.0, 2559992.0, 2566768.0, 2573544.0, 2580320.0, 2587096.0, 2593872.0, 2600648.0, - 2607424.0, 2614200.0, 2620976.0, 2627752.0, 2634528.0, 2641304.0, 2648080.0, 2654856.0, - 2649472.0, 2656504.0, 2663536.0, 2670568.0, 2677600.0, 2684632.0, 2691664.0, 2698696.0, - 2705728.0, 2712760.0, 2719792.0, 2726824.0, 2733856.0, 2740888.0, 2747920.0, 2754952.0, - 2745728.0, 2753016.0, 2760304.0, 2767592.0, 2774880.0, 2782168.0, 2789456.0, 2796744.0, - 2804032.0, 2811320.0, 2818608.0, 2825896.0, 2833184.0, 2840472.0, 2847760.0, 2855048.0, - 2841984.0, 2849528.0, 2857072.0, 2864616.0, 2872160.0, 2879704.0, 2887248.0, 2894792.0, - 2902336.0, 2909880.0, 2917424.0, 2924968.0, 2932512.0, 2940056.0, 2947600.0, 2955144.0, - 2938240.0, 2946040.0, 2953840.0, 2961640.0, 2969440.0, 2977240.0, 2985040.0, 2992840.0, - 3000640.0, 3008440.0, 3016240.0, 3024040.0, 3031840.0, 3039640.0, 3047440.0, 3055240.0, - 3034496.0, 3042552.0, 3050608.0, 3058664.0, 3066720.0, 3074776.0, 3082832.0, 3090888.0, - 3098944.0, 3107000.0, 3115056.0, 3123112.0, 3131168.0, 3139224.0, 3147280.0, 3155336.0, - ]; - assert_equals_range::(results, expected, 768..1024, device); - } - - /// Exported test - pub fn compute_loop_cmma_offseted_warp_in_shared_memory_test(device: &R::Device) { - if !cmma_available::(device) { - // We can't execute the test, skip. - return; - } - - let lhs = range_tensor_f16::(32, 16, device); - let rhs = range_tensor_f16::(16, 32, device); - let results = create_empty::(32, 32, device); - let cube_dim = CubeDim::new(32, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - compute_loop_cmma_offseted_slice_in_shared_memory_test::launch::( - R::client(device), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::new(&results, 256), - UInt::new(32), - UInt::new(16), - UInt::new(32), - ); - - let expected = &[ - 1590656.0, 1594872.0, 1599088.0, 1603304.0, 1607520.0, 1611736.0, 1615952.0, 1620168.0, - 1624384.0, 1628600.0, 1632816.0, 1637032.0, 1641248.0, 1645464.0, 1649680.0, 1653896.0, - 1686912.0, 1691384.0, 1695856.0, 1700328.0, 1704800.0, 1709272.0, 1713744.0, 1718216.0, - 1722688.0, 1727160.0, 1731632.0, 1736104.0, 1740576.0, 1745048.0, 1749520.0, 1753992.0, - 1783168.0, 1787896.0, 1792624.0, 1797352.0, 1802080.0, 1806808.0, 1811536.0, 1816264.0, - 1820992.0, 1825720.0, 1830448.0, 1835176.0, 1839904.0, 1844632.0, 1849360.0, 1854088.0, - 1879424.0, 1884408.0, 1889392.0, 1894376.0, 1899360.0, 1904344.0, 1909328.0, 1914312.0, - 1919296.0, 1924280.0, 1929264.0, 1934248.0, 1939232.0, 1944216.0, 1949200.0, 1954184.0, - 1975680.0, 1980920.0, 1986160.0, 1991400.0, 1996640.0, 2001880.0, 2007120.0, 2012360.0, - 2017600.0, 2022840.0, 2028080.0, 2033320.0, 2038560.0, 2043800.0, 2049040.0, 2054280.0, - 2071936.0, 2077432.0, 2082928.0, 2088424.0, 2093920.0, 2099416.0, 2104912.0, 2110408.0, - 2115904.0, 2121400.0, 2126896.0, 2132392.0, 2137888.0, 2143384.0, 2148880.0, 2154376.0, - 2168192.0, 2173944.0, 2179696.0, 2185448.0, 2191200.0, 2196952.0, 2202704.0, 2208456.0, - 2214208.0, 2219960.0, 2225712.0, 2231464.0, 2237216.0, 2242968.0, 2248720.0, 2254472.0, - 2264448.0, 2270456.0, 2276464.0, 2282472.0, 2288480.0, 2294488.0, 2300496.0, 2306504.0, - 2312512.0, 2318520.0, 2324528.0, 2330536.0, 2336544.0, 2342552.0, 2348560.0, 2354568.0, - 2360704.0, 2366968.0, 2373232.0, 2379496.0, 2385760.0, 2392024.0, 2398288.0, 2404552.0, - 2410816.0, 2417080.0, 2423344.0, 2429608.0, 2435872.0, 2442136.0, 2448400.0, 2454664.0, - 2456960.0, 2463480.0, 2470000.0, 2476520.0, 2483040.0, 2489560.0, 2496080.0, 2502600.0, - 2509120.0, 2515640.0, 2522160.0, 2528680.0, 2535200.0, 2541720.0, 2548240.0, 2554760.0, - 2553216.0, 2559992.0, 2566768.0, 2573544.0, 2580320.0, 2587096.0, 2593872.0, 2600648.0, - 2607424.0, 2614200.0, 2620976.0, 2627752.0, 2634528.0, 2641304.0, 2648080.0, 2654856.0, - 2649472.0, 2656504.0, 2663536.0, 2670568.0, 2677600.0, 2684632.0, 2691664.0, 2698696.0, - 2705728.0, 2712760.0, 2719792.0, 2726824.0, 2733856.0, 2740888.0, 2747920.0, 2754952.0, - 2745728.0, 2753016.0, 2760304.0, 2767592.0, 2774880.0, 2782168.0, 2789456.0, 2796744.0, - 2804032.0, 2811320.0, 2818608.0, 2825896.0, 2833184.0, 2840472.0, 2847760.0, 2855048.0, - 2841984.0, 2849528.0, 2857072.0, 2864616.0, 2872160.0, 2879704.0, 2887248.0, 2894792.0, - 2902336.0, 2909880.0, 2917424.0, 2924968.0, 2932512.0, 2940056.0, 2947600.0, 2955144.0, - 2938240.0, 2946040.0, 2953840.0, 2961640.0, 2969440.0, 2977240.0, 2985040.0, 2992840.0, - 3000640.0, 3008440.0, 3016240.0, 3024040.0, 3031840.0, 3039640.0, 3047440.0, 3055240.0, - 3034496.0, 3042552.0, 3050608.0, 3058664.0, 3066720.0, 3074776.0, 3082832.0, 3090888.0, - 3098944.0, 3107000.0, 3115056.0, 3123112.0, 3131168.0, 3139224.0, 3147280.0, 3155336.0, - ]; - assert_equals_range::(results, expected, 768..1024, device); - } - /// Exported test pub fn compute_loop_k_test(device: &R::Device) { if !cmma_available::(device) { diff --git a/crates/cubecl-lac/src/tests/matmul_internal.rs b/crates/cubecl-lac/src/tests/matmul_internal.rs index 87eb00a59..a7fdc9d97 100644 --- a/crates/cubecl-lac/src/tests/matmul_internal.rs +++ b/crates/cubecl-lac/src/tests/matmul_internal.rs @@ -152,25 +152,6 @@ macro_rules! testgen_matmul_internal { cmma_compute_loop_tests::compute_loop_warp_test::(&Default::default()) } - #[test] - pub fn cmma_compute_loop_cmma_offseted_warp_test() { - cmma_compute_loop_tests::compute_loop_cmma_offseted_warp_test::( - &Default::default(), - ) - } - - #[test] - pub fn cmma_compute_loop_cmma_offseted_warp_in_shared_memory_test() { - cmma_compute_loop_tests::compute_loop_cmma_offseted_warp_in_shared_memory_test::< - TestRuntime, - >(&Default::default()) - } - - #[test] - pub fn cmma_warp_test() { - cmma_compute_loop_tests::cmma_warp_test::(&Default::default()) - } - #[test] pub fn cmma_load_shared_memory_lhs_unit_test() { cmma_load_shared_memory_tests::load_shared_memory_lhs_unit_test::(