From 896f8b4ccf01ccaabbfcc89d4f4769d9441a2836 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Fri, 10 Jan 2025 11:57:03 +0000 Subject: [PATCH 1/3] add gemm_api and instances --- example/ck_tile/03_gemm/CMakeLists.txt | 27 +- example/ck_tile/03_gemm/gemm_basic.cpp | 31 +- example/ck_tile/03_gemm/gemm_basic.hpp | 67 ++- .../ck_tile/03_gemm/instances/gemm_api.cpp | 482 ++++++++++++++++++ ...universal_comp_bf16_bf16_bf16_km_kn_mn.cpp | 27 + ...universal_comp_bf16_bf16_bf16_km_nk_mn.cpp | 27 + ...universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp | 26 + ...universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp | 27 + ...mm_universal_comp_f16_f16_f16_km_kn_mn.cpp | 27 + ...mm_universal_comp_f16_f16_f16_km_nk_mn.cpp | 27 + ...mm_universal_comp_f16_f16_f16_mk_kn_mn.cpp | 26 + ...mm_universal_comp_f16_f16_f16_mk_nk_mn.cpp | 27 + .../gemm_universal_comp_instance_common.hpp | 206 ++++++++ ..._universal_mem_bf16_bf16_bf16_km_kn_mn.cpp | 27 + ..._universal_mem_bf16_bf16_bf16_km_nk_mn.cpp | 27 + ..._universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp | 26 + ..._universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp | 27 + ...emm_universal_mem_f16_f16_f16_km_kn_mn.cpp | 27 + ...emm_universal_mem_f16_f16_f16_km_nk_mn.cpp | 27 + ...emm_universal_mem_f16_f16_f16_mk_kn_mn.cpp | 26 + ...emm_universal_mem_f16_f16_f16_mk_nk_mn.cpp | 27 + .../gemm_universal_mem_instance_common.hpp | 206 ++++++++ example/ck_tile/03_gemm/run_gemm_example.inc | 30 +- example/ck_tile/03_gemm/universal_gemm.cpp | 203 +------- 24 files changed, 1453 insertions(+), 227 deletions(-) create mode 100644 example/ck_tile/03_gemm/instances/gemm_api.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index bc3799f015..f682ef0ac9 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,2 +1,27 @@ +# add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) +# add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp) + +function (add_gemm_example TARGET_NAME MAIN_SRC) +message("adding ${TARGET_NAME}") +# not using add_example_executable() to add target, since we don't want this to have +# to be included in "make all/install/check" +add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC}) +target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + +foreach(source IN LISTS ARGN) + list(APPEND INSTANCE_SRCS ${source}) +endforeach() + +target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS}) + +set(COMPILE_OPTIONS) +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) +endfunction(add_gemm_example TARGET_NAME MAIN_SRC) + +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_gemm_example(tile_example_gemm_universal universal_gemm.cpp ${INSTANCE_SRCS}) + add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) -add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 4c630375f4..71c508bd45 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -9,13 +9,10 @@ #include #include -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/host.hpp" #include "gemm_basic.hpp" template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. constexpr bool kPadM = false; @@ -103,6 +100,30 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& return ave_time; } +float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +{ + if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + return gemm_(args, s); + } + else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + return gemm_(args, s); + } + else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + return gemm_(args, s); + } + else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + return gemm_(args, s); + } + else + { + throw std::runtime_error("Wrong! Layouts not supported!\n"); + } +} + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 38c0a279db..e659890f4b 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -1,13 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include - -#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/epilogue.hpp" template struct GemmBasicTypeConfig; @@ -51,6 +52,59 @@ using BDataType = Types::BDataType; using AccDataType = Types::AccDataType; using CDataType = Types::CDataType; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +struct gemm_traits +{ + std::string data_type; + bool is_a_rowmajor; + bool is_b_rowmajor; + bool is_c_rowmajor; +}; + +template +struct gemm_traits_ +{ + using ADataType = ck_tile::remove_cvref_t; + using BDataType = ck_tile::remove_cvref_t; + using AccDataType = ck_tile::remove_cvref_t; + using CDataType = ck_tile::remove_cvref_t; + using ALayout = ck_tile::remove_cvref_t; + using BLayout = ck_tile::remove_cvref_t; + using CLayout = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t M_Tile = M_Tile_; + static constexpr ck_tile::index_t N_Tile = N_Tile_; + static constexpr ck_tile::index_t K_Tile = K_Tile_; + static constexpr ck_tile::index_t M_Warp = M_Warp_; + static constexpr ck_tile::index_t N_Warp = N_Warp_; + static constexpr ck_tile::index_t K_Warp = K_Warp_; + static constexpr ck_tile::index_t M_Warp_Tile = M_Warp_Tile_; + static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_; + static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kPadK = kPadK_; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -75,4 +129,9 @@ auto create_args(int argc, char* argv[]) } // host API -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); +template +float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); + +float gemm(const gemm_traits& traits, + const ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/instances/gemm_api.cpp b/example/ck_tile/03_gemm/instances/gemm_api.cpp new file mode 100644 index 0000000000..05fba01bd2 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_api.cpp @@ -0,0 +1,482 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_basic.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +using FP32 = float; +using FP16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; + +template +using trait_ = gemm_traits_; + +float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp16") == 0) + { + if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound RR + std::cout << "fp16 comp\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound RR + std::cout << "fp16 mem\n"; + return gemm_>(a, s); + } + } + else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound RC + std::cout << "fp16 comp RC\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound RC + std::cout << "fp16 mem RC\n"; + return gemm_>(a, s); + } + } + else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound CR + std::cout << "fp16 comp CR\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound CR + std::cout << "fp16 mem CR\n"; + return gemm_>(a, s); + } + } + else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound CC + std::cout << "fp16 comp CC\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound CC + std::cout << "fp16 mem CC\n"; + return gemm_>(a, s); + } + } + else + { + throw std::runtime_error("Wrong! ColumnMajor layout not supported for C Matrix!\n"); + } + } + else if(t.data_type.compare("bf16") == 0) + { + if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound RR + std::cout << "bf16 comp\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound RR + std::cout << "bf16 mem\n"; + return gemm_>(a, s); + } + } + else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound RC + std::cout << "bf16 comp RC\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound RC + std::cout << "bf16 mem RC\n"; + return gemm_>(a, s); + } + } + else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound CR + std::cout << "bf16 comp CR\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound CR + std::cout << "bf16 mem CR\n"; + return gemm_>(a, s); + } + } + else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound CC + std::cout << "bf16 comp CC\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound CC + std::cout << "bf16 mem CC\n"; + return gemm_>(a, s); + } + } + else + { + throw std::runtime_error("Wrong! ColumnMajor layout not supported for C Matrix!\n"); + } + } + else + { + throw std::runtime_error("Wrong! DataTypes not supported!\n"); + } + + return 1.0f; +} diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp new file mode 100644 index 0000000000..121b676b1c --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp new file mode 100644 index 0000000000..29d856c001 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp new file mode 100644 index 0000000000..76138c42d6 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp new file mode 100644 index 0000000000..130b3e2691 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp new file mode 100644 index 0000000000..43971b017f --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp new file mode 100644 index 0000000000..0ad95f8831 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp new file mode 100644 index 0000000000..6e2ec55c7c --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp new file mode 100644 index 0000000000..4ad3ed8a98 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp new file mode 100644 index 0000000000..2bdd5fc380 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include "gemm_basic.hpp" + +using A = ck_tile::GemmHostArgs; +using S = ck_tile::stream_config; + +template +using trait_ = gemm_traits_; + +template +float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile::GemmTilePartitioner; + + using GemmEpilogue = + ck_tile::Default2DEpilogue>; + using GemmTraits = ck_tile::TileGemmTraits; + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3< + ck_tile::GemmPipelineProblem>; + + constexpr int kBlockPerCu = 1; + + const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3< + ck_tile::UniversalGemmPipelineProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + if(has_hot_loop) + { + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + static_assert(BaseGemmPipeline::PrefetchStages > 3); + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + } + else + { + // Tail number always Full - #PrefetchStages + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "When there's no hot loop, this tail number \"" << tail_num + << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + return ave_time; +} diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp new file mode 100644 index 0000000000..f340a27001 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp new file mode 100644 index 0000000000..24aec06e12 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp new file mode 100644 index 0000000000..6ff10bfda8 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp new file mode 100644 index 0000000000..98fb82163d --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp new file mode 100644 index 0000000000..8a462bb8f8 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp new file mode 100644 index 0000000000..8e78d850af --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp new file mode 100644 index 0000000000..487dc07d69 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp new file mode 100644 index 0000000000..50823e96cd --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp new file mode 100644 index 0000000000..b78efae272 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include "gemm_basic.hpp" + +using A = ck_tile::GemmHostArgs; +using S = ck_tile::stream_config; + +template +using trait_ = gemm_traits_; + +template +float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile::GemmTilePartitioner; + + using GemmEpilogue = + ck_tile::Default2DEpilogue>; + using GemmTraits = ck_tile::TileGemmTraits; + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< + ck_tile::GemmPipelineProblem>; + + constexpr int kBlockPerCu = 1; + + const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< + ck_tile::UniversalGemmPipelineProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + if(has_hot_loop) + { + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + static_assert(BaseGemmPipeline::PrefetchStages > 3); + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + } + else + { + // Tail number always Full - #PrefetchStages + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "When there's no hot loop, this tail number \"" << tail_num + << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + return ave_time; +} diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 56d0348bd6..b6d9d4399d 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -28,8 +28,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = gemm_calc( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + gemm_traits traits{DataTypeTraits{}.name, + std::is_same_v, + std::is_same_v, + std::is_same_v}; + + float ave_time = + gemm(traits, args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = @@ -210,9 +215,6 @@ int run_gemm_example(int argc, char* argv[]) if(!result) return -1; - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); @@ -224,16 +226,14 @@ int run_gemm_example(int argc, char* argv[]) { return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } - // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not - // work. - // else if(a_layout == "C" && b_layout == "C") - // { - // return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); - // } - // else if(a_layout == "C" && b_layout == "R") - // { - // return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - // } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } else { throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 1a9e025a9b..418c390cc1 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -14,207 +14,6 @@ #include "ck_tile/host.hpp" #include "gemm_basic.hpp" -#define CK_TILE_PIPELINE_COMPUTE 1 -#define CK_TILE_PIPELINE_MEMORY 2 - -#ifndef CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE -#endif - -template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) -{ -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Memory friendly for Interwave scheduler - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 32; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 4; - constexpr ck_tile::index_t N_Warp = 1; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; - -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) - // Compute friendly for Intrawave scheduler - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; -#endif - - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - constexpr int kBlockPerCu = 1; - - // =============================================== - - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTilePartitioner; - - using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; - - using Traits = ck_tile::TileGemmTraits; -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3< -#endif - ck_tile::GemmPipelineProblem>; - - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3< -#endif - ck_tile::UniversalGemmPipelineProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; - - if(has_hot_loop) - { - // Tail pipeline One to Seven - if(tail_num == ck_tile::TailNumber::One) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Full) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 3) - { - if(tail_num == ck_tile::TailNumber::Three) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 4) - { - if(tail_num == ck_tile::TailNumber::Four) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 5) - { - if(tail_num == ck_tile::TailNumber::Five) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 6) - { - if(tail_num == ck_tile::TailNumber::Six) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 7) - { - if(tail_num == ck_tile::TailNumber::Seven) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - } - else - { - // Tail number always Full - #PrefetchStages - if(tail_num == ck_tile::TailNumber::Full) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "When there's no hot loop, this tail number \"" << tail_num - << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - } - - return ave_time; -} - #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } From 3cad16c41bd89988a8c24a53c535baf58d2ac7c6 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Fri, 10 Jan 2025 12:15:49 +0000 Subject: [PATCH 2/3] minor fixes --- example/ck_tile/03_gemm/gemm_basic.cpp | 23 +++++++++++++++++++ example/ck_tile/03_gemm/gemm_basic.hpp | 23 ------------------- .../ck_tile/03_gemm/instances/gemm_api.cpp | 2 -- example/ck_tile/03_gemm/universal_gemm.cpp | 23 +++++++++++++++++++ 4 files changed, 46 insertions(+), 25 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 71c508bd45..987e8fa074 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -124,6 +124,29 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_til } } +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index e659890f4b..bac2ed19d8 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -105,29 +105,6 @@ struct gemm_traits_ static constexpr bool kPadK = kPadK_; }; -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "3840", "m dimension") - .insert("n", "4096", "n dimension") - .insert("k", "2048", "k dimension") - .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "R", "B tensor data layout - Row by default") - .insert("c_layout", "R", "C tensor data layout - Row by default") - .insert("stride_a", "0", "Tensor A stride") - .insert("stride_b", "0", "Tensor B stride") - .insert("stride_c", "0", "Tensor C stride") - .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") - .insert("warmup", "50", "number of iterations before benchmark the kernel") - .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("split_k", "1", "splitK value"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - // host API template float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/instances/gemm_api.cpp b/example/ck_tile/03_gemm/instances/gemm_api.cpp index 05fba01bd2..d417b8a676 100644 --- a/example/ck_tile/03_gemm/instances/gemm_api.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_api.cpp @@ -477,6 +477,4 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile:: { throw std::runtime_error("Wrong! DataTypes not supported!\n"); } - - return 1.0f; } diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 418c390cc1..5fd71fd837 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -14,6 +14,29 @@ #include "ck_tile/host.hpp" #include "gemm_basic.hpp" +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } From c2945b966d7dceeede66fb979f92157c0517cca4 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Sun, 19 Jan 2025 13:05:45 +0000 Subject: [PATCH 3/3] add reviewers comments --- example/ck_tile/03_gemm/CMakeLists.txt | 5 +- .../03_gemm/{gemm_basic.hpp => gemm.hpp} | 18 +- example/ck_tile/03_gemm/gemm_basic.cpp | 25 +- .../ck_tile/03_gemm/instances/gemm_api.cpp | 683 ++++++++---------- ...universal_comp_bf16_bf16_bf16_km_kn_mn.cpp | 22 +- ...universal_comp_bf16_bf16_bf16_km_nk_mn.cpp | 22 +- ...universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp | 22 +- ...universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp | 22 +- ...mm_universal_comp_f16_f16_f16_km_kn_mn.cpp | 22 +- ...mm_universal_comp_f16_f16_f16_km_nk_mn.cpp | 22 +- ...mm_universal_comp_f16_f16_f16_mk_kn_mn.cpp | 22 +- ...mm_universal_comp_f16_f16_f16_mk_nk_mn.cpp | 22 +- .../gemm_universal_comp_instance_common.hpp | 41 +- ..._universal_mem_bf16_bf16_bf16_km_kn_mn.cpp | 22 +- ..._universal_mem_bf16_bf16_bf16_km_nk_mn.cpp | 22 +- ..._universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp | 22 +- ..._universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp | 22 +- ...emm_universal_mem_f16_f16_f16_km_kn_mn.cpp | 22 +- ...emm_universal_mem_f16_f16_f16_km_nk_mn.cpp | 22 +- ...emm_universal_mem_f16_f16_f16_mk_kn_mn.cpp | 22 +- ...emm_universal_mem_f16_f16_f16_mk_nk_mn.cpp | 22 +- .../gemm_universal_mem_instance_common.hpp | 41 +- example/ck_tile/03_gemm/run_gemm_example.inc | 24 + example/ck_tile/03_gemm/universal_gemm.cpp | 25 +- 24 files changed, 397 insertions(+), 817 deletions(-) rename example/ck_tile/03_gemm/{gemm_basic.hpp => gemm.hpp} (86%) diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index f682ef0ac9..4b939337dc 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,6 +1,3 @@ -# add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) -# add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp) - function (add_gemm_example TARGET_NAME MAIN_SRC) message("adding ${TARGET_NAME}") # not using add_example_executable() to add target, since we don't want this to have @@ -16,7 +13,7 @@ target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS}) set(COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template) target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) endfunction(add_gemm_example TARGET_NAME MAIN_SRC) diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm.hpp similarity index 86% rename from example/ck_tile/03_gemm/gemm_basic.hpp rename to example/ck_tile/03_gemm/gemm.hpp index bac2ed19d8..10a1934fc4 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm.hpp @@ -55,12 +55,13 @@ using CDataType = Types::CDataType; using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +/** \brief Struct used for specifying desired gemm details*/ struct gemm_traits { - std::string data_type; - bool is_a_rowmajor; - bool is_b_rowmajor; - bool is_c_rowmajor; + std::string data_type; /** Tensors datatype, can be set to either fp16 or bf16*/ + bool is_a_rowmajor; /** Whether A matrix is rowmajor */ + bool is_b_rowmajor; /** Whether B matrix is rowmajor */ + bool is_c_rowmajor; /** Whether C matrix is rowmajor */ }; template float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); +/** + * \brief Invoke gemm function + * + * \param traits Gemm traits which are used for choosing best instance. + * \param args Runtime gemm host arguments. + * \param s Stream configuration. + * \return Time of execution. + */ float gemm(const gemm_traits& traits, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 987e8fa074..3b2664736f 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -9,7 +9,7 @@ #include #include -#include "gemm_basic.hpp" +#include "gemm.hpp" template float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) @@ -124,29 +124,6 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_til } } -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "3840", "m dimension") - .insert("n", "4096", "n dimension") - .insert("k", "2048", "k dimension") - .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "R", "B tensor data layout - Row by default") - .insert("c_layout", "R", "C tensor data layout - Row by default") - .insert("stride_a", "0", "Tensor A stride") - .insert("stride_b", "0", "Tensor B stride") - .insert("stride_c", "0", "Tensor C stride") - .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") - .insert("warmup", "50", "number of iterations before benchmark the kernel") - .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("split_k", "1", "splitK value"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/instances/gemm_api.cpp b/example/ck_tile/03_gemm/instances/gemm_api.cpp index d417b8a676..bf61bddeb4 100644 --- a/example/ck_tile/03_gemm/instances/gemm_api.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_api.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include "gemm_basic.hpp" +#include "gemm.hpp" using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -10,45 +10,6 @@ using FP32 = float; using FP16 = ck_tile::half_t; using BF16 = ck_tile::bf16_t; -template -using trait_ = gemm_traits_; - float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::stream_config& s) { if(t.data_type.compare("fp16") == 0) @@ -57,204 +18,188 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile:: { if(a.M > 512) { - // universal gemm compute bound RR - std::cout << "fp16 comp\n"; - return gemm_>(a, s); + return gemm_>(a, s); } else { - // universal gemm memory bound RR - std::cout << "fp16 mem\n"; - return gemm_>(a, s); + return gemm_>(a, s); } } else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) { if(a.M > 512) { - // universal gemm compute bound RC - std::cout << "fp16 comp RC\n"; - return gemm_>(a, s); + return gemm_>(a, s); } else { - // universal gemm memory bound RC - std::cout << "fp16 mem RC\n"; - return gemm_>(a, s); + return gemm_>(a, s); } } else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) { if(a.M > 512) { - // universal gemm compute bound CR - std::cout << "fp16 comp CR\n"; - return gemm_>(a, s); + return gemm_>(a, s); } else { - // universal gemm memory bound CR - std::cout << "fp16 mem CR\n"; - return gemm_>(a, s); + return gemm_>(a, s); } } else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) { if(a.M > 512) { - // universal gemm compute bound CC - std::cout << "fp16 comp CC\n"; - return gemm_>(a, s); + return gemm_>(a, s); } else { - // universal gemm memory bound CC - std::cout << "fp16 mem CC\n"; - return gemm_>(a, s); + return gemm_>(a, s); } } else @@ -268,204 +213,188 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile:: { if(a.M > 512) { - // universal gemm compute bound RR - std::cout << "bf16 comp\n"; - return gemm_>(a, s); + return gemm_>(a, s); } else { - // universal gemm memory bound RR - std::cout << "bf16 mem\n"; - return gemm_>(a, s); + return gemm_>(a, s); } } else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) { if(a.M > 512) { - // universal gemm compute bound RC - std::cout << "bf16 comp RC\n"; - return gemm_>(a, s); + return gemm_>(a, s); } else { - // universal gemm memory bound RC - std::cout << "bf16 mem RC\n"; - return gemm_>(a, s); + return gemm_>(a, s); } } else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) { if(a.M > 512) { - // universal gemm compute bound CR - std::cout << "bf16 comp CR\n"; - return gemm_>(a, s); + return gemm_>(a, s); } else { - // universal gemm memory bound CR - std::cout << "bf16 mem CR\n"; - return gemm_>(a, s); + return gemm_>(a, s); } } else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) { if(a.M > 512) { - // universal gemm compute bound CC - std::cout << "bf16 comp CC\n"; - return gemm_>(a, s); + return gemm_>(a, s); } else { - // universal gemm memory bound CC - std::cout << "bf16 mem CC\n"; - return gemm_>(a, s); + return gemm_>(a, s); } } else diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp index 121b676b1c..fc350a5fd6 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp @@ -6,22 +6,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp index 29d856c001..eeb1c35132 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp @@ -6,22 +6,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp index 76138c42d6..6c2fe38914 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp @@ -5,22 +5,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp index 130b3e2691..3aa33ca83f 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp @@ -6,22 +6,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp index 43971b017f..ed695b3be9 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp @@ -6,22 +6,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp index 0ad95f8831..cb975f33a0 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp @@ -6,22 +6,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp index 6e2ec55c7c..bfc9fc6a97 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp @@ -5,22 +5,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp index 4ad3ed8a98..1be99be0b6 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp @@ -6,22 +6,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp index 2bdd5fc380..62df845eaa 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp @@ -2,50 +2,11 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include -#include "gemm_basic.hpp" +#include "gemm.hpp" using A = ck_tile::GemmHostArgs; using S = ck_tile::stream_config; -template -using trait_ = gemm_traits_; - template float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp index f340a27001..299924eb32 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp @@ -6,22 +6,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp index 24aec06e12..d28ce6e637 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp @@ -6,22 +6,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp index 6ff10bfda8..aa8a772eec 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp @@ -5,22 +5,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp index 98fb82163d..30871f99fa 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp @@ -6,22 +6,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp index 8a462bb8f8..611de23784 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp @@ -6,22 +6,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp index 8e78d850af..15b01460e3 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp @@ -6,22 +6,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp index 487dc07d69..b9b6c9b263 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp @@ -5,22 +5,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp index 50823e96cd..ee0703edb5 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp @@ -6,22 +6,6 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -template float gemm_>(const A&, const S&); +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp index b78efae272..2d3f7ec512 100644 --- a/example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp @@ -2,50 +2,11 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include -#include "gemm_basic.hpp" +#include "gemm.hpp" using A = ck_tile::GemmHostArgs; using S = ck_tile::stream_config; -template -using trait_ = gemm_traits_; - template float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index b6d9d4399d..80339b826b 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -2,6 +2,29 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + template float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, @@ -28,6 +51,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; + // TODO: Change datatypes in future to allow mixed precision gemms! gemm_traits traits{DataTypeTraits{}.name, std::is_same_v, std::is_same_v, diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index eb2b263a08..cd6c1dbfb0 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -10,30 +10,7 @@ #include #include "ck_tile/host.hpp" -#include "gemm_basic.hpp" - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "3840", "m dimension") - .insert("n", "4096", "n dimension") - .insert("k", "2048", "k dimension") - .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "R", "B tensor data layout - Row by default") - .insert("c_layout", "R", "C tensor data layout - Row by default") - .insert("stride_a", "0", "Tensor A stride") - .insert("stride_b", "0", "Tensor B stride") - .insert("stride_c", "0", "Tensor C stride") - .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") - .insert("warmup", "50", "number of iterations before benchmark the kernel") - .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("split_k", "1", "splitK value"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} +#include "gemm.hpp" #include "run_gemm_example.inc"