From 45344997893a659c00148a26a1963fdb98f2b607 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 21 Nov 2024 12:59:59 +0000 Subject: [PATCH] rebase get_smem api --- ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 83 ++++++++++--------- 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index c765b3ce9d..5a612b05c2 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -90,47 +90,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy return b_lds_block_desc; } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() - { - constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * - MakeALdsBlockDescriptor().get_element_space_size(); - return smem_size_a; - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() - { - constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * - MakeBLdsBlockDescriptor().get_element_space_size(); - return smem_size_b; - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - constexpr index_t smem_size_a = GetSmemSizeA(); - constexpr index_t smem_size_b = GetSmemSizeB(); - index_t smem_size = 0; - smem_size += smem_size_a + smem_size_b; - - return smem_size; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() - { - using ADataType = remove_cvref_t; - return Problem::VectorLoadSize / sizeof(ADataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() - { - using BDataType = remove_cvref_t; - return Problem::VectorLoadSize / sizeof(BDataType); - } #elif 1 // fake XOR template @@ -203,6 +162,48 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy } #endif + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + { + constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size(); + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() + { + constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * + MakeBLdsBlockDescriptor().get_element_space_size(); + return smem_size_b; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + constexpr index_t smem_size_b = GetSmemSizeB(); + index_t smem_size = 0; + smem_size += smem_size_a + smem_size_b; + + return smem_size; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + { + using ADataType = remove_cvref_t; + return Problem::VectorLoadSize / sizeof(ADataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() + { + using BDataType = remove_cvref_t; + return Problem::VectorLoadSize / sizeof(BDataType); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() {