From 38c28ddb0df99e7c1d7c8f9df13bbbf4822ada0e Mon Sep 17 00:00:00 2001
From: mutsuki
Date: Tue, 26 Oct 2021 20:59:47 +0900
Subject: [PATCH] Add tests for strided batched gemm

---
Reviewer notes (commentary area after the separator; not part of the
commit message and not part of the diff):

 * NOTE(review): this copy of the patch had been flattened onto a few
   physical lines. Line breaks below are restored so that each hunk has
   the number of added lines its header declares (44 + 36 + 11 = 91).
   Leading indentation inside the lines is a reconstruction; compare
   whitespace against the original commit before applying.
 * NOTE(review): text that would sit inside angle brackets looks
   stripped in this copy: bare "template", "reinterpret_cast(A)",
   "convert(1)", and ten textually identical
   "gemm_strided_batched_test();" calls. Presumably the original carried
   template parameter / argument lists there (two specialisations per
   element type are defined, one per "op_*" banner) -- TODO confirm
   against the original commit; nothing has been filled in here.
 * The added gemm_strided_batched_test() discards the values returned by
   cudaMallocHost, cublasCreate, gemm_strided_batched, cublasDestroy and
   cudaFreeHost, and never writes to or reads from mat_a / mat_b / mat_c,
   so as written it only exercises the call path.
 * NOTE(review): the GemmEx wrapper passes cuda_data_type a fourth time,
   in the compute-type position of cublasGemmStridedBatchedEx -- verify
   that this matches the cuBLAS version the tests are built against.

 tests/cublas_test.cu | 91 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 91 insertions(+)

diff --git a/tests/cublas_test.cu b/tests/cublas_test.cu
index f1000cf..a325584 100644
--- a/tests/cublas_test.cu
+++ b/tests/cublas_test.cu
@@ -92,6 +92,50 @@ GEMM_BATCHED_OP_GEMMEX(CUDA_R_64F, double);
 GEMM_BATCHED_OP_GEMMEX(CUDA_C_32F, cuComplex);
 GEMM_BATCHED_OP_GEMMEX(CUDA_C_64F, cuDoubleComplex);
 
+// -------------
+// GemmStridedBatched
+// -------------
+template
+cublasStatus_t gemm_strided_batched(cublasHandle_t handle, cublasOperation_t transa,
+        cublasOperation_t transb, int m, int n, int k,
+        const T *alpha, const T *A, int lda, long long int strideA,
+        const T *B, int ldb, long long int strideB, const T *beta, T *C,
+        int ldc, long long int strideC, int batchCount);
+// -----------------------------------------------------
+// op_gemm
+// -----------------------------------------------------
+#define GEMM_STRIDED_BATCHED_OP_GEMM(short_type, type)\
+template <>\
+cublasStatus_t gemm_strided_batched(cublasHandle_t handle, cublasOperation_t transa,\
+        cublasOperation_t transb, int m, int n, int k,\
+        const type *alpha, const type *A, int lda, long long int strideA,\
+        const type *B, int ldb, long long int strideB, const type *beta, type *C,\
+        int ldc, long long int strideC, int batchCount) {\
+    return cublas##short_type##gemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount);\
+}
+GEMM_STRIDED_BATCHED_OP_GEMM(H, half);
+GEMM_STRIDED_BATCHED_OP_GEMM(S, float);
+GEMM_STRIDED_BATCHED_OP_GEMM(D, double);
+GEMM_STRIDED_BATCHED_OP_GEMM(C, cuComplex);
+GEMM_STRIDED_BATCHED_OP_GEMM(Z, cuDoubleComplex);
+// -----------------------------------------------------
+// op_gemmEx
+// -----------------------------------------------------
+#define GEMM_STRIDED_BATCHED_OP_GEMMEX(cuda_data_type, type)\
+template <>\
+cublasStatus_t gemm_strided_batched(cublasHandle_t handle, cublasOperation_t transa,\
+        cublasOperation_t transb, int m, int n, int k,\
+        const type *alpha, const type *A, int lda, long long int strideA,\
+        const type *B, int ldb, long long int strideB, const type *beta, type *C,\
+        int ldc, long long int strideC, int batchCount) {\
+    return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, reinterpret_cast(A), cuda_data_type, lda, strideA, reinterpret_cast(B), cuda_data_type, ldb, strideB, beta, reinterpret_cast(C), cuda_data_type, ldc, strideC, batchCount, cuda_data_type, CUBLAS_GEMM_DEFAULT);\
+}
+GEMM_STRIDED_BATCHED_OP_GEMMEX(CUDA_R_16F, half);
+GEMM_STRIDED_BATCHED_OP_GEMMEX(CUDA_R_32F, float);
+GEMM_STRIDED_BATCHED_OP_GEMMEX(CUDA_R_64F, double);
+GEMM_STRIDED_BATCHED_OP_GEMMEX(CUDA_C_32F, cuComplex);
+GEMM_STRIDED_BATCHED_OP_GEMMEX(CUDA_C_64F, cuDoubleComplex);
+
 // -------------
 // Gemv
 // -------------
@@ -426,6 +470,42 @@ void gemm_batched_test() {
     cudaFreeHost(mat_c_array);
 }
 
+template
+void gemm_strided_batched_test() {
+    const int n = 1lu << 7;
+    const int batch_size = 1u << 10;
+    const auto alpha = convert(1);
+    const auto beta = convert(0);
+
+    T* mat_a;
+    T* mat_b;
+    T* mat_c;
+
+    cudaMallocHost(&mat_a, sizeof(T) * n * n * batch_size);
+    cudaMallocHost(&mat_b, sizeof(T) * n * n * batch_size);
+    cudaMallocHost(&mat_c, sizeof(T) * n * n * batch_size);
+
+    cublasHandle_t cublas_handle;
+    cublasCreate(&cublas_handle);
+
+    gemm_strided_batched(
+        cublas_handle,
+        CUBLAS_OP_N, CUBLAS_OP_N,
+        n, n, n,
+        &alpha,
+        mat_a, n, n * n,
+        mat_b, n, n * n,
+        &beta,
+        mat_c, n, n * n,
+        batch_size
+        );
+
+    cublasDestroy(cublas_handle);
+    cudaFreeHost(mat_a);
+    cudaFreeHost(mat_b);
+    cudaFreeHost(mat_c);
+}
+
 template
 void gemv_test() {
     const std::size_t n = 1lu << 10;
@@ -748,6 +828,17 @@ void test_all() {
     gemm_batched_test();
     gemm_batched_test();
 
+    gemm_strided_batched_test();
+    gemm_strided_batched_test();
+    gemm_strided_batched_test();
+    gemm_strided_batched_test();
+    gemm_strided_batched_test();
+    gemm_strided_batched_test();
+    gemm_strided_batched_test();
+    gemm_strided_batched_test();
+    gemm_strided_batched_test();
+    gemm_strided_batched_test();
+
     gemv_test();
     gemv_test();
     gemv_test();