From 88582e727bf4afc2e47493fa3fe3b181c6c5c1b3 Mon Sep 17 00:00:00 2001
From: Zilin Zhu
Date: Wed, 17 Jul 2024 17:40:20 +0800
Subject: [PATCH 1/4] add cublasGemmGroupedBatchedEx

---
 csrc/grouped_gemm.cu | 157 +++++++++++++++++++++++++++++++++++++++++++
 setup.py             |   3 +
 2 files changed, 160 insertions(+)

diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index 3729862..67d2b8e 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -206,6 +206,152 @@ void cublas_handle_init()
   }
 }

+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 12500
+
+#define MAX_GROUPSIZE 1024
+
+cublasOperation_t transa_array[MAX_GROUPSIZE];
+cublasOperation_t transb_array[MAX_GROUPSIZE];
+int m_array[MAX_GROUPSIZE];
+int n_array[MAX_GROUPSIZE];
+int k_array[MAX_GROUPSIZE];
+float alpha_array[MAX_GROUPSIZE];
+float beta_array[MAX_GROUPSIZE];
+
+void * Aarray[MAX_GROUPSIZE];
+int lda_array[MAX_GROUPSIZE];
+void * Barray[MAX_GROUPSIZE];
+int ldb_array[MAX_GROUPSIZE];
+void * Carray[MAX_GROUPSIZE];
+int ldc_array[MAX_GROUPSIZE];
+
+// on device
+void **d_Aarray = nullptr;
+void **d_Barray = nullptr;
+void **d_Carray = nullptr;
+
+int group_size[MAX_GROUPSIZE];
+
+bool cublas_grouped_gemm_init = false;
+
+void cublas_grouped_gemm_global_var_init()
+{
+    cublas_grouped_gemm_init = true;
+
+    for (int i = 0; i < MAX_GROUPSIZE; i++)
+    {
+        alpha_array[i] = 1.0;
+        beta_array[i] = 0.0;
+        group_size[i] = 1;
+    }
+
+    CUDA_CALL(cudaMalloc(&d_Aarray, MAX_GROUPSIZE * sizeof(void *)));
+    CUDA_CALL(cudaMalloc(&d_Barray, MAX_GROUPSIZE * sizeof(void *)));
+    CUDA_CALL(cudaMalloc(&d_Carray, MAX_GROUPSIZE * sizeof(void *)));
+}
+
+void CublasGemmGroupedBatched(torch::Tensor a,
+                              torch::Tensor b,
+                              torch::Tensor c,
+                              torch::Tensor batch_sizes,
+                              bool trans_a, bool trans_b)
+{
+    if (!cublas_grouped_gemm_init)
+        cublas_grouped_gemm_global_var_init();
+
+    int group_count = batch_sizes.size(0);
+
+    c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
+    c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
+    c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
+
+    int a_rows, a_cols, b_rows, b_cols, c_rows, c_cols;
+
+    for (int i = 0; i < group_count; i++)
+    {
+        if (trans_a) {
+            a_rows = batch_sizes.data_ptr()[i];
+            a_cols = a.size(1);
+
+            // b.dims() == 2 here
+            b_rows = batch_sizes.data_ptr()[i];
+            b_cols = b.size(1);
+
+            c_rows = a_cols;
+            c_cols = b_cols;
+        } else {
+            a_rows = batch_sizes.data_ptr()[i];
+            a_cols = a.size(1);
+
+            // b.dims() == 3 here
+            b_rows = b.size(1);
+            b_cols = b.size(2);
+
+            c_rows = a_rows;
+            c_cols = trans_b ? b_rows : b_cols;
+        }
+
+        transa_array[i] = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
+        transb_array[i] = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+        int m = trans_b ? b_rows : b_cols;
+        int k = trans_b ? b_cols : b_rows;
+        int n = trans_a ? a_cols : a_rows;
+        m_array[i] = m;
+        n_array[i] = n;
+        k_array[i] = k;
+
+        lda_array[i] = trans_a ? n : k;
+        ldb_array[i] = trans_b ? k : m;
+        ldc_array[i] = c_cols;
+
+        Aarray[i] = a_ptr;
+        Barray[i] = b_ptr;
+        Carray[i] = c_ptr;
+
+        a_ptr += a_rows * a_cols;
+        b_ptr += b_rows * b_cols;
+        c_ptr += c_rows * c_cols;
+    }
+
+    CUDA_CALL(cudaMemcpyAsync(d_Aarray, Aarray,
+                              sizeof(void *) * group_count,
+                              cudaMemcpyHostToDevice,
+                              c10::cuda::getCurrentCUDAStream()));
+    CUDA_CALL(cudaMemcpyAsync(d_Barray, Barray,
+                              sizeof(void *) * group_count,
+                              cudaMemcpyHostToDevice,
+                              c10::cuda::getCurrentCUDAStream()));
+    CUDA_CALL(cudaMemcpyAsync(d_Carray, Carray,
+                              sizeof(void *) * group_count,
+                              cudaMemcpyHostToDevice,
+                              c10::cuda::getCurrentCUDAStream()));
+
+    CUBLAS_CALL(cublasGemmGroupedBatchedEx(
+        at::cuda::getCurrentCUDABlasHandle(),
+        transb_array,
+        transa_array,
+        m_array,
+        n_array,
+        k_array,
+        alpha_array,
+        d_Barray,
+        CUDA_R_16BF,
+        ldb_array,
+        d_Aarray,
+        CUDA_R_16BF,
+        lda_array,
+        beta_array,
+        d_Carray,
+        CUDA_R_16BF,
+        ldc_array,
+        group_count,
+        group_size,
+        CUBLAS_COMPUTE_32F));
+}
+
+#endif
+
 inline void cublas_current_wait_streams(cudaStream_t stream)
 {
   for (int s = 0; s < NUM_STREAM; s++)
@@ -259,6 +405,12 @@ void CublasGroupedGemm(torch::Tensor a,
                        torch::Tensor c,
                        torch::Tensor batch_sizes,
                        bool trans_b) {
+
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 12500
+  CublasGemmGroupedBatched(a, b, c, batch_sizes, false, trans_b);
+  return;
+#endif
+
   if (!cublas_init) cublas_handle_init();

@@ -289,6 +441,11 @@ void CublasGroupedGemmVariableK(torch::Tensor a,
                                 torch::Tensor b,
                                 torch::Tensor c,
                                 torch::Tensor batch_sizes) {
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 12500
+  CublasGemmGroupedBatched(a, b, c, batch_sizes, true, false);
+  return;
+#endif
+
   if (!cublas_init) cublas_handle_init();

diff --git a/setup.py b/setup.py
index 8798172..36d1acf 100644
--- a/setup.py
+++ b/setup.py
@@ -37,6 +37,9 @@
         f"-DGROUPED_GEMM_DEVICE_CAPABILITY={device_capability}",
     ])

+if "CUBLAS_VERSION" in os.environ:
+    nvcc_flags.append(f"-DCUBLAS_VERSION={os.environ['CUBLAS_VERSION']}")
+
 ext_modules = [
     CUDAExtension(
         "grouped_gemm_backend",
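
Note on [PATCH 1/4]: cuBLAS works on column-major matrices while the torch tensors here are row-major, so each per-group product C = A * B is issued as C^T = B^T * A^T in cuBLAS terms. That is why the call passes the trans/leading-dimension/pointer arrays for b ahead of those for a, and why m_array holds the output's column count while n_array holds its row count. A single-group sketch of the same trick with plain cublasGemmEx, for reference only (the helper name and shape parameters below are made up, they are not part of the patch):

    // Illustration only: C (rows x cols) = A (rows x k) * B (k x cols),
    // all row-major bf16, expressed through column-major cuBLAS.
    #include <cublas_v2.h>
    #include <cuda_bf16.h>

    void single_group_gemm(cublasHandle_t handle,
                           const __nv_bfloat16 *A,   // rows x k, row-major
                           const __nv_bfloat16 *B,   // k x cols, row-major
                           __nv_bfloat16 *C,         // rows x cols, row-major
                           int rows, int cols, int k)
    {
        const float alpha = 1.0f, beta = 0.0f;
        // Column-major view: C^T (cols x rows) = B^T (cols x k) * A^T (k x rows),
        // so B is passed in the first operand slot and m/n are the output's
        // column/row counts, mirroring the grouped call in the patch.
        cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                     /*m=*/cols, /*n=*/rows, /*k=*/k,
                     &alpha,
                     B, CUDA_R_16BF, /*ld=*/cols,
                     A, CUDA_R_16BF, /*ld=*/k,
                     &beta,
                     C, CUDA_R_16BF, /*ld=*/cols,
                     CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
    }
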
From eec3111b878fcaf8632941241f46ae057658d16d Mon Sep 17 00:00:00 2001
From: Zilin Zhu
Date: Thu, 18 Jul 2024 11:52:06 +0800
Subject: [PATCH 2/4] use shared trans_array

---
 csrc/grouped_gemm.cu | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index 67d2b8e..c5dd6a4 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -210,8 +210,8 @@ void cublas_handle_init()

 #define MAX_GROUPSIZE 1024

-cublasOperation_t transa_array[MAX_GROUPSIZE];
-cublasOperation_t transb_array[MAX_GROUPSIZE];
+cublasOperation_t trans_array_T[MAX_GROUPSIZE];
+cublasOperation_t trans_array_N[MAX_GROUPSIZE];
 int m_array[MAX_GROUPSIZE];
 int n_array[MAX_GROUPSIZE];
 int k_array[MAX_GROUPSIZE];
@@ -243,6 +243,8 @@ void cublas_grouped_gemm_global_var_init()
         alpha_array[i] = 1.0;
         beta_array[i] = 0.0;
         group_size[i] = 1;
+        trans_array_T[i] = CUBLAS_OP_T;
+        trans_array_N[i] = CUBLAS_OP_N;
     }

     CUDA_CALL(cudaMalloc(&d_Aarray, MAX_GROUPSIZE * sizeof(void *)));
@@ -291,9 +293,6 @@ void CublasGemmGroupedBatched(torch::Tensor a,
             c_cols = trans_b ? b_rows : b_cols;
         }

-        transa_array[i] = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
-        transb_array[i] = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
-
         int m = trans_b ? b_rows : b_cols;
         int k = trans_b ? b_cols : b_rows;
         int n = trans_a ? a_cols : a_rows;
@@ -329,8 +328,8 @@ void CublasGemmGroupedBatched(torch::Tensor a,

     CUBLAS_CALL(cublasGemmGroupedBatchedEx(
         at::cuda::getCurrentCUDABlasHandle(),
-        transb_array,
-        transa_array,
+        trans_b ? trans_array_T : trans_array_N,
+        trans_a ? trans_array_T : trans_array_N,
         m_array,
         n_array,
         k_array,
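
Note on [PATCH 2/4]: every group in one call shares the same trans_a/trans_b, so the per-group cublasOperation_t entries never differ within a call. Filling one all-CUBLAS_OP_T array and one all-CUBLAS_OP_N array once at init and then passing trans_b ? trans_array_T : trans_array_N avoids rewriting group_count entries on every forward pass. A minimal standalone sketch of that one-time setup (array size and names below are placeholders, not the patch's own definitions):

    // Sketch of the shared-array idea: fill once, reuse on every call.
    #include <algorithm>
    #include <iterator>
    #include <cublas_v2.h>

    constexpr int kMaxGroups = 1024;
    cublasOperation_t all_T[kMaxGroups];
    cublasOperation_t all_N[kMaxGroups];

    void init_trans_arrays()
    {
        // Later calls simply pick one of the two arrays, e.g.
        // trans_b ? all_T : all_N, instead of writing per-group entries.
        std::fill(std::begin(all_T), std::end(all_T), CUBLAS_OP_T);
        std::fill(std::begin(all_N), std::end(all_N), CUBLAS_OP_N);
    }
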
From f242b12a3d6084bea33a6ee32df44fe9b23201b4 Mon Sep 17 00:00:00 2001
From: Zilin Zhu
Date: Thu, 18 Jul 2024 11:56:53 +0800
Subject: [PATCH 3/4] use cudaMallocAsync

---
 csrc/grouped_gemm.cu | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index c5dd6a4..841b802 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -247,9 +247,18 @@ void cublas_grouped_gemm_global_var_init()
         trans_array_N[i] = CUBLAS_OP_N;
     }

-    CUDA_CALL(cudaMalloc(&d_Aarray, MAX_GROUPSIZE * sizeof(void *)));
-    CUDA_CALL(cudaMalloc(&d_Barray, MAX_GROUPSIZE * sizeof(void *)));
-    CUDA_CALL(cudaMalloc(&d_Carray, MAX_GROUPSIZE * sizeof(void *)));
+    CUDA_CALL(cudaMallocAsync(
+        &d_Aarray,
+        MAX_GROUPSIZE * sizeof(void *),
+        c10::cuda::getCurrentCUDAStream()));
+    CUDA_CALL(cudaMallocAsync(
+        &d_Barray,
+        MAX_GROUPSIZE * sizeof(void *),
+        c10::cuda::getCurrentCUDAStream()));
+    CUDA_CALL(cudaMallocAsync(
+        &d_Carray,
+        MAX_GROUPSIZE * sizeof(void *),
+        c10::cuda::getCurrentCUDAStream()));
 }

 void CublasGemmGroupedBatched(torch::Tensor a,
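
Note on [PATCH 3/4]: cudaMallocAsync is stream-ordered: the allocation is queued on the given stream and served from the driver's memory pool, so it avoids the device-wide synchronization a plain cudaMalloc can cause. Its counterpart is cudaFreeAsync; the three pointer arrays here live for the whole process, so only the allocation side changes. A minimal sketch of the usual pairing (the helper below is illustrative, it is not in the patch):

    // Stream-ordered allocate/free pairing on one stream.
    #include <cuda_runtime.h>

    void stream_ordered_scratch(cudaStream_t stream, size_t bytes)
    {
        void *buf = nullptr;
        cudaMallocAsync(&buf, bytes, stream); // usable by later work on `stream`
        // ... enqueue kernels or copies that use buf on `stream` ...
        cudaFreeAsync(buf, stream);           // released once prior work completes
    }
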
From 68deba4fdf01302fd18a7fe9ef28dd2d5245caaf Mon Sep 17 00:00:00 2001
From: Zilin Zhu
Date: Thu, 18 Jul 2024 13:10:21 +0800
Subject: [PATCH 4/4] skip 0 size matmul

---
 csrc/grouped_gemm.cu | 44 ++++++++++++++++++++++++--------------------
 1 file changed, 24 insertions(+), 20 deletions(-)

diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index 841b802..6a2af86 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -270,28 +270,28 @@ void CublasGemmGroupedBatched(torch::Tensor a,
     if (!cublas_grouped_gemm_init)
         cublas_grouped_gemm_global_var_init();

-    int group_count = batch_sizes.size(0);
-
     c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
     c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
     c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();

     int a_rows, a_cols, b_rows, b_cols, c_rows, c_cols;

-    for (int i = 0; i < group_count; i++)
+    int group_count = 0;
+    for (int i = 0; i < batch_sizes.size(0); i++)
     {
+        int bs = batch_sizes.data_ptr()[i];
         if (trans_a) {
-            a_rows = batch_sizes.data_ptr()[i];
+            a_rows = bs;
             a_cols = a.size(1);

             // b.dims() == 2 here
-            b_rows = batch_sizes.data_ptr()[i];
+            b_rows = bs;
             b_cols = b.size(1);

             c_rows = a_cols;
             c_cols = b_cols;
         } else {
-            a_rows = batch_sizes.data_ptr()[i];
+            a_rows = bs;
             a_cols = a.size(1);

             // b.dims() == 3 here
             b_rows = b.size(1);
             b_cols = b.size(2);

             c_rows = a_rows;
             c_cols = trans_b ? b_rows : b_cols;
         }

-        int m = trans_b ? b_rows : b_cols;
-        int k = trans_b ? b_cols : b_rows;
-        int n = trans_a ? a_cols : a_rows;
-        m_array[i] = m;
-        n_array[i] = n;
-        k_array[i] = k;
-
-        lda_array[i] = trans_a ? n : k;
-        ldb_array[i] = trans_b ? k : m;
-        ldc_array[i] = c_cols;
-
-        Aarray[i] = a_ptr;
-        Barray[i] = b_ptr;
-        Carray[i] = c_ptr;
+        if (bs != 0) {
+            int m = trans_b ? b_rows : b_cols;
+            int k = trans_b ? b_cols : b_rows;
+            int n = trans_a ? a_cols : a_rows;
+            m_array[group_count] = m;
+            n_array[group_count] = n;
+            k_array[group_count] = k;
+
+            lda_array[group_count] = trans_a ? n : k;
+            ldb_array[group_count] = trans_b ? k : m;
+            ldc_array[group_count] = c_cols;
+
+            Aarray[group_count] = a_ptr;
+            Barray[group_count] = b_ptr;
+            Carray[group_count] = c_ptr;
+
+            group_count++;
+        }

         a_ptr += a_rows * a_cols;
         b_ptr += b_rows * b_cols;
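
Note on [PATCH 4/4]: groups with a zero batch size are no longer handed to cuBLAS. Per-group arguments are written at a dense index that advances only for non-empty groups, while the pointer bookkeeping still runs for every group so the later, non-empty ones keep their correct offsets. A condensed sketch of that compaction (container types and names below are illustrative, not the patch's own):

    // Compaction of per-group arguments; n_array is assumed to be pre-sized
    // to batch_sizes.size().
    #include <vector>

    int compact_group_args(const std::vector<int> &batch_sizes,
                           std::vector<int> &n_array)
    {
        int group_count = 0;
        for (int bs : batch_sizes) {
            if (bs != 0) {
                n_array[group_count] = bs; // row count of this group's A and C
                ++group_count;             // dense index handed to cuBLAS
            }
            // per-group data pointers would still advance here even when
            // bs == 0, so non-empty groups keep the right offsets
        }
        return group_count; // passed as group_count to the grouped GEMM call
    }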