From 210e6e30d85db02b8e27460c887e3fea625b340e Mon Sep 17 00:00:00 2001 From: mutsuki Date: Sun, 17 Oct 2021 18:34:59 +0900 Subject: [PATCH] Add tests for symm --- tests/cublas_test.cu | 68 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/cublas_test.cu b/tests/cublas_test.cu index fd06043..a1871bf 100644 --- a/tests/cublas_test.cu +++ b/tests/cublas_test.cu @@ -163,6 +163,35 @@ GEMM_OP_SYRK(D, double); GEMM_OP_SYRK(C, cuComplex); GEMM_OP_SYRK(Z, cuDoubleComplex); +// ------------- +// Symm +// ------------- +template +cublasStatus_t symm(cublasHandle_t handle, cublasSideMode_t size, + cublasFillMode_t uplo, + int m, int n, + const T *alpha, + const T *A, int lda, + const T *B, int ldb, + const T *beta , T *C, int ldc + ); +#define GEMM_OP_SYMM(short_type, type)\ +template <>\ +cublasStatus_t symm(cublasHandle_t handle, cublasSideMode_t side,\ + cublasFillMode_t uplo,\ + int m, int n,\ + const type *alpha, \ + const type *A, int lda,\ + const type *B, int ldb,\ + const type *beta, type *C, int ldc\ + ) {\ + return cublas##short_type##symm(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc);\ +} +GEMM_OP_SYMM(S, float); +GEMM_OP_SYMM(D, double); +GEMM_OP_SYMM(C, cuComplex); +GEMM_OP_SYMM(Z, cuDoubleComplex); + // ------------- // Gemm3m // ------------- @@ -378,6 +407,40 @@ void syrk_test() { cudaFree(mat_c); } +template +void symm_test() { + const std::size_t n = 1lu << 10; + const auto alpha = convert(1); + const auto beta = convert(0); + + T* mat_a; + T* mat_b; + T* mat_c; + + cudaMalloc(&mat_a, sizeof(T) * n * n); + cudaMalloc(&mat_b, sizeof(T) * n * n); + cudaMalloc(&mat_c, sizeof(T) * n * n); + + cublasHandle_t cublas_handle; + cublasCreate(&cublas_handle); + + symm( + cublas_handle, + CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_LOWER, + n, n, + &alpha, + mat_a, n, + mat_b, n, + &beta, + mat_c, n + ); + + cublasDestroy(cublas_handle); + cudaFree(mat_a); + cudaFree(mat_b); + cudaFree(mat_c); +} + template void gemm3m_test() { const std::size_t n = 1lu << 10; @@ -450,6 +513,11 @@ void test_all() { syrk_test(); syrk_test(); + symm_test(); + symm_test(); + symm_test(); + symm_test(); + gemm3m_test(); gemm3m_test(); }