Add support for cublasGemmStridedBatchedEx
enp1s0 committed Aug 25, 2022
1 parent 59dd8da commit 685889c
Showing 2 changed files with 28 additions and 10 deletions.
37 changes: 28 additions & 9 deletions src/cublas.cu
@@ -16,6 +16,7 @@
namespace {
mtk::cu_exp_statistics::result_t exp_stats(
const void* const ptr,
+const std::size_t offset,
const std::size_t m,
const std::size_t n,
const std::size_t ld,
@@ -25,22 +26,22 @@ mtk::cu_exp_statistics::result_t exp_stats(
mtk::cu_exp_statistics::result_t result;
switch (data_t) {
case CUDA_R_64F:
-result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const double*>(ptr), m, n, ld, cuda_stream);
+result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const double*>(ptr) + offset, m, n, ld, cuda_stream);
break;
case CUDA_R_32F:
-result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const float*>(ptr), m, n, ld, cuda_stream);
+result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const float*>(ptr) + offset, m, n, ld, cuda_stream);
break;
case CUDA_R_16F:
-result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const half*>(ptr), m, n, ld, cuda_stream);
+result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const half*>(ptr) + offset, m, n, ld, cuda_stream);
break;
case CUDA_C_64F:
-result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const double2*>(ptr), m, n, ld, cuda_stream);
+result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const double2*>(ptr) + offset, m, n, ld, cuda_stream);
break;
case CUDA_C_32F:
-result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const float2*>(ptr), m, n, ld, cuda_stream);
+result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const float2*>(ptr) + offset, m, n, ld, cuda_stream);
break;
case CUDA_C_16F:
-result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const half2*>(ptr), m, n, ld, cuda_stream);
+result = mtk::cu_exp_statistics::take_matrix_statistics(reinterpret_cast<const half2*>(ptr) + offset, m, n, ld, cuda_stream);
break;
default:
break;
@@ -264,8 +265,8 @@ cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa,
CULiP_exp_stats b_stats;
snprintf(a_stats.name, a_stats.name_length - 1, "A");
snprintf(b_stats.name, b_stats.name_length - 1, "B");
-a_stats.stats = exp_stats(A, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, cuda_stream, Atype);
-b_stats.stats = exp_stats(B, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream, Btype);
+a_stats.stats = exp_stats(A, 0, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, cuda_stream, Atype);
+b_stats.stats = exp_stats(B, 0, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream, Btype);
mtk::cu_exp_statistics::to_json(a_stats.stats);
mtk::cu_exp_statistics::to_json(b_stats.stats);
CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&a_stats);
@@ -456,7 +457,7 @@ cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle,
int batchCount,
cublasComputeType_t computeType,
cublasGemmAlgo_t algo) {
-const int profiling_flag = (CULiP_profiling_control_array[CULiP_cublasGemmBatchedEx] == 0) && CULiP_is_profiling_enabled(CULIP_CUBLAS_DISABLE_ENV_NAME);
+const int profiling_flag = (CULiP_profiling_control_array[CULiP_cublasGemmStridedBatchedEx] == 0) && CULiP_is_profiling_enabled(CULIP_CUBLAS_DISABLE_ENV_NAME);

// Get the function pointer
cublasStatus_t (*cublas_lib_func)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const void*, const void*, cudaDataType_t, int, long long int, const void*, cudaDataType_t, int, long long int, const void*, void*, cudaDataType_t, int, long long int, int, cublasComputeType_t, cublasGemmAlgo_t);
@@ -493,6 +494,24 @@ cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle,
CULiP_launch_function(cuda_stream, &CULiP_print_profile_result, (void*)&profile_result);
}

+const int exp_stats_flag = (CULiP_profiling_control_array[CULiP_cublasGemmStridedBatchedEx] == 0) && CULiP_is_profiling_enabled(CULIP_EXP_STATS_ENABLE_ENV_NAME, false);
+if (exp_stats_flag) {
+cudaStream_t cuda_stream;
+cublasGetStream(handle, &cuda_stream);
+CULiP_exp_stats a_stats;
+CULiP_exp_stats b_stats;
+snprintf(a_stats.name, a_stats.name_length - 1, "A");
+snprintf(b_stats.name, b_stats.name_length - 1, "B");
+for (std::uint32_t i = 0; i < batchCount; i++) {
+a_stats.stats += exp_stats(A, i * strideA, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, cuda_stream, Atype);
+b_stats.stats += exp_stats(B, i * strideB, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream, Btype);
+}
+mtk::cu_exp_statistics::to_json(a_stats.stats);
+mtk::cu_exp_statistics::to_json(b_stats.stats);
+CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&a_stats);
+CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&b_stats);
+}

return result;
}

1 change: 0 additions & 1 deletion src/cublas.gemm_strided_batched.template.h
@@ -66,6 +66,5 @@ cublasStatus_t CULIP_FUNC_NAME (cublasHandle_t handle,
CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&b_stats);
}


return result;
}
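For context on the new code path: the exp-stats branch added above walks the batchCount sub-matrices of A and B, offsetting the typed pointer by i * strideA and i * strideB before taking per-matrix exponent statistics. The sketch below is not part of this commit; it only illustrates the kind of cublasGemmStridedBatchedEx call the new wrapper intercepts. The matrix sizes, the CUDA_R_32F data types, the compute settings, and the way the profiler is interposed in front of cuBLAS are illustrative assumptions, and error checking is omitted for brevity.

// Minimal caller sketch (not part of this commit), CUDA C++.
#include <cublas_v2.h>
#include <cuda_runtime.h>

int main() {
	// Hypothetical problem sizes; one matrix every `stride` elements.
	const int m = 64, n = 64, k = 64, batchCount = 8;
	const long long strideA = static_cast<long long>(m) * k;
	const long long strideB = static_cast<long long>(k) * n;
	const long long strideC = static_cast<long long>(m) * n;

	// Contiguously packed batches of FP32 matrices.
	void *A, *B, *C;
	cudaMalloc(&A, sizeof(float) * strideA * batchCount);
	cudaMalloc(&B, sizeof(float) * strideB * batchCount);
	cudaMalloc(&C, sizeof(float) * strideC * batchCount);

	cublasHandle_t handle;
	cublasCreate(&handle);

	const float alpha = 1.0f, beta = 0.0f;
	// With the profiler interposed, this is the call that the
	// cublasGemmStridedBatchedEx wrapper above times and, if exponent
	// statistics are enabled, scans A and B once per batch on the same stream.
	cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
	                           m, n, k,
	                           &alpha,
	                           A, CUDA_R_32F, m, strideA,
	                           B, CUDA_R_32F, k, strideB,
	                           &beta,
	                           C, CUDA_R_32F, m, strideC,
	                           batchCount,
	                           CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);

	cudaDeviceSynchronize();
	cublasDestroy(handle);
	cudaFree(A);
	cudaFree(B);
	cudaFree(C);
	return 0;
}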
