From 0fcbb25f70e50ec6224cf2d3dd5704581fbfb54b Mon Sep 17 00:00:00 2001 From: deepsek <166548550+deepsek@users.noreply.github.com> Date: Thu, 16 Jan 2025 23:31:15 -0500 Subject: [PATCH] fix: preprocessor directives logic error if/else (#1764) * fix: preprocessors logic error if/else * fix: added macros as preferred by CK team --- .../src/profile_grouped_gemm_fixed_nk.cpp | 104 ++++++++---------- 1 file changed, 47 insertions(+), 57 deletions(-) diff --git a/profiler/src/profile_grouped_gemm_fixed_nk.cpp b/profiler/src/profile_grouped_gemm_fixed_nk.cpp index e33d798504..093557e7f3 100644 --- a/profiler/src/profile_grouped_gemm_fixed_nk.cpp +++ b/profiler/src/profile_grouped_gemm_fixed_nk.cpp @@ -21,7 +21,6 @@ enum struct GemmDataType F16_F16_F16, // 1 F16_F8_F16, // 2 F16_I8_F16, // 3 - }; #define OP_NAME "grouped_gemm_fixed_nk" @@ -39,7 +38,6 @@ std::vector argToIntArray(char* input) { out.push_back(std::stoi(item)); } - return out; } @@ -83,14 +81,6 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) const auto StrideCs = argToIntArray(argv[13]); const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1; - using F32 = float; - using F16 = ck::half_t; -#if defined(CK_ENABLE_FP8) - using F8 = ck::f8_t; -#endif - using BF16 = ck::bhalf_t; - using I8 = int8_t; - int n_warmup = 1; int n_iter = 10; if(argc == 17) @@ -99,13 +89,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) n_iter = std::stoi(argv[16]); } -#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) - if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_grouped_gemm_fixed_nk_impl( @@ -123,12 +112,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) n_warmup, n_iter); } - else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_grouped_gemm_fixed_nk_impl( @@ -146,14 +135,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) n_warmup, n_iter); } -#endif -#if defined(CK_ENABLE_FP16) - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) +#if defined(CK_ENABLE_FP8) + else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_grouped_gemm_fixed_nk_impl( @@ -171,12 +159,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) n_warmup, n_iter); } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_grouped_gemm_fixed_nk_impl( @@ -195,13 +183,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) n_iter); } #endif -#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) - else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN) +#if defined(CK_ENABLE_INT8) + else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_grouped_gemm_fixed_nk_impl( @@ -219,12 +207,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) n_warmup, n_iter); } - else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) + else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_grouped_gemm_fixed_nk_impl( @@ -238,18 +226,19 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) StrideAs, StrideBs, StrideCs, - kbatch, + 1, n_warmup, n_iter); } #endif -#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8) - else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN) +#if defined(CK_ENABLE_BF16) +#if defined(CK_ENABLE_INT8) + else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_grouped_gemm_fixed_nk_impl( @@ -267,12 +256,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) n_warmup, n_iter); } - else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN) + else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_grouped_gemm_fixed_nk_impl( @@ -286,10 +275,11 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) StrideAs, StrideBs, StrideCs, - 1, + kbatch, n_warmup, n_iter); } +#endif #endif else {