From 685889c6a93eea1a8d787db66b18265edd585cdd Mon Sep 17 00:00:00 2001
From: mutsuki
Date: Thu, 25 Aug 2022 21:06:33 +0900
Subject: [PATCH] Add support for `cublasGemmStridedBatchedEx`

---
 src/cublas.cu                              | 37 ++++++++++++++++------
 src/cublas.gemm_strided_batched.template.h |  1 -
 2 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/src/cublas.cu b/src/cublas.cu
index f155026..e96ff0b 100644
--- a/src/cublas.cu
+++ b/src/cublas.cu
@@ -16,6 +16,7 @@ namespace {
 
 mtk::cu_exp_statistics::result_t exp_stats(
 		const void* const ptr,
+		const std::size_t offset,
 		const std::size_t m,
 		const std::size_t n,
 		const std::size_t ld,
@@ -25,22 +26,22 @@ mtk::cu_exp_statistics::result_t exp_stats(
 	mtk::cu_exp_statistics::result_t result;
 	switch (data_t) {
 	case CUDA_R_64F:
-		result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const double*>(ptr), m, n, ld, cuda_stream);
+		result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const double*>(ptr) + offset, m, n, ld, cuda_stream);
 		break;
 	case CUDA_R_32F:
-		result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const float*>(ptr), m, n, ld, cuda_stream);
+		result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const float*>(ptr) + offset, m, n, ld, cuda_stream);
 		break;
 	case CUDA_R_16F:
-		result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const half*>(ptr), m, n, ld, cuda_stream);
+		result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const half*>(ptr) + offset, m, n, ld, cuda_stream);
 		break;
 	case CUDA_C_64F:
-		result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const cuDoubleComplex*>(ptr), m, n, ld, cuda_stream);
+		result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const cuDoubleComplex*>(ptr) + offset, m, n, ld, cuda_stream);
 		break;
 	case CUDA_C_32F:
-		result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const cuComplex*>(ptr), m, n, ld, cuda_stream);
+		result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const cuComplex*>(ptr) + offset, m, n, ld, cuda_stream);
 		break;
 	case CUDA_C_16F:
-		result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const half2*>(ptr), m, n, ld, cuda_stream);
+		result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const half2*>(ptr) + offset, m, n, ld, cuda_stream);
 		break;
 	default:
 		break;
@@ -264,8 +265,8 @@ cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa,
 		CULiP_exp_stats a_stats;
 		CULiP_exp_stats b_stats;
 		snprintf(a_stats.name, a_stats.name_length - 1, "A");
 		snprintf(b_stats.name, b_stats.name_length - 1, "B");
-		a_stats.stats = exp_stats(A, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, cuda_stream, Atype);
-		b_stats.stats = exp_stats(B, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream, Btype);
+		a_stats.stats = exp_stats(A, 0, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, cuda_stream, Atype);
+		b_stats.stats = exp_stats(B, 0, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream, Btype);
 		mtk::cu_exp_statistics::to_json(a_stats.stats);
 		mtk::cu_exp_statistics::to_json(b_stats.stats);
 		CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&a_stats);
@@ -456,7 +457,7 @@ cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle,
 		int batchCount,
 		cublasComputeType_t computeType,
 		cublasGemmAlgo_t algo) {
-	const int profiling_flag = (CULiP_profiling_control_array[CULiP_cublasGemmBatchedEx] == 0) && CULiP_is_profiling_enabled(CULIP_CUBLAS_DISABLE_ENV_NAME);
+	const int profiling_flag = (CULiP_profiling_control_array[CULiP_cublasGemmStridedBatchedEx] == 0) && CULiP_is_profiling_enabled(CULIP_CUBLAS_DISABLE_ENV_NAME);
 
 	// Get the function pointer
 	cublasStatus_t (*cublas_lib_func)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const void*, const void*, cudaDataType_t, int, long long int, const void*, cudaDataType_t, int, long long int, const void*, void*, cudaDataType_t, int, long long int, int, cublasComputeType_t, cublasGemmAlgo_t);
@@ -493,6 +494,24 @@ cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle,
 		CULiP_launch_function(cuda_stream, &CULiP_print_profile_result, (void*)&profile_result);
 	}
 
+	const int exp_stats_flag = (CULiP_profiling_control_array[CULiP_cublasGemmStridedBatchedEx] == 0) && CULiP_is_profiling_enabled(CULIP_EXP_STATS_ENABLE_ENV_NAME, false);
+	if (exp_stats_flag) {
+		cudaStream_t cuda_stream;
+		cublasGetStream(handle, &cuda_stream);
+		CULiP_exp_stats a_stats;
+		CULiP_exp_stats b_stats;
+		snprintf(a_stats.name, a_stats.name_length - 1, "A");
+		snprintf(b_stats.name, b_stats.name_length - 1, "B");
+		for (std::uint32_t i = 0; i < batchCount; i++) {
+			a_stats.stats += exp_stats(A, i * strideA, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, cuda_stream, Atype);
+			b_stats.stats += exp_stats(B, i * strideB, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream, Btype);
+		}
+		mtk::cu_exp_statistics::to_json(a_stats.stats);
+		mtk::cu_exp_statistics::to_json(b_stats.stats);
+		CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&a_stats);
+		CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&b_stats);
+	}
+
 	return result;
 }
diff --git a/src/cublas.gemm_strided_batched.template.h b/src/cublas.gemm_strided_batched.template.h
index 207ec36..f6444b8 100644
--- a/src/cublas.gemm_strided_batched.template.h
+++ b/src/cublas.gemm_strided_batched.template.h
@@ -66,6 +66,5 @@ cublasStatus_t CULIP_FUNC_NAME (cublasHandle_t handle,
 		CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&b_stats);
 	}
 
-
 	return result;
 }
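Caller-side sketch for review context (not part of the patch): a minimal strided-batched GEMM that would exercise the new exp-stats path when CULiP is interposed into the application and the exp-stats environment variable named by CULIP_EXP_STATS_ENABLE_ENV_NAME is set. The sizes, FP32 types, and the uninitialized device buffers below are illustrative assumptions, not values taken from the patch; only the call shape matters here.

// build (assumed): nvcc example.cu -lcublas
#include <cublas_v2.h>
#include <cuda_runtime.h>

int main() {
	const int m = 64, n = 64, k = 64, batch = 8;
	const long long int strideA = (long long int)m * k;
	const long long int strideB = (long long int)k * n;
	const long long int strideC = (long long int)m * n;

	// Device buffers for all batches; contents are irrelevant for this sketch.
	float *A, *B, *C;
	cudaMalloc(reinterpret_cast<void**>(&A), sizeof(float) * strideA * batch);
	cudaMalloc(reinterpret_cast<void**>(&B), sizeof(float) * strideB * batch);
	cudaMalloc(reinterpret_cast<void**>(&C), sizeof(float) * strideC * batch);
	const float alpha = 1.f, beta = 0.f;

	cublasHandle_t handle;
	cublasCreate(&handle);

	// With the patch applied and exp-stats enabled, the intercepted wrapper walks
	// all `batch` matrices of A and B (offsets i * strideA / i * strideB), accumulates
	// the exponent statistics, and prints them on the GEMM's CUDA stream.
	cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
	                           m, n, k,
	                           &alpha,
	                           A, CUDA_R_32F, m, strideA,
	                           B, CUDA_R_32F, k, strideB,
	                           &beta,
	                           C, CUDA_R_32F, m, strideC,
	                           batch,
	                           CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
	cudaDeviceSynchronize();

	cublasDestroy(handle);
	cudaFree(A); cudaFree(B); cudaFree(C);
	return 0;
}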