From 939ff0d0e2952e557db38bd8566760e35a4c3bbe Mon Sep 17 00:00:00 2001 From: mutsuki Date: Mon, 18 Oct 2021 14:35:59 +0900 Subject: [PATCH] Add tests for syr2k --- tests/cublas_test.cu | 69 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/cublas_test.cu b/tests/cublas_test.cu index a1871bf..b831fc3 100644 --- a/tests/cublas_test.cu +++ b/tests/cublas_test.cu @@ -163,6 +163,35 @@ GEMM_OP_SYRK(D, double); GEMM_OP_SYRK(C, cuComplex); GEMM_OP_SYRK(Z, cuDoubleComplex); +// ------------- +// Syr2k +// ------------- +template +cublasStatus_t syr2k(cublasHandle_t handle, + cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, + const T *alpha, + const T *A, int lda, + const T *B, int ldb, + const T *beta , T *C, int ldc + ); +#define GEMM_OP_SYR2K(short_type, type)\ +template <>\ +cublasStatus_t syr2k(cublasHandle_t handle, cublasFillMode_t uplo,\ + cublasOperation_t trans,\ + int n, int k,\ + const type *alpha, \ + const type *A, int lda,\ + const type *B, int ldb,\ + const type *beta, type *C, int ldc\ + ) {\ + return cublas##short_type##syr2k(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc);\ +} +GEMM_OP_SYR2K(S, float); +GEMM_OP_SYR2K(D, double); +GEMM_OP_SYR2K(C, cuComplex); +GEMM_OP_SYR2K(Z, cuDoubleComplex); + // ------------- // Symm // ------------- @@ -407,6 +436,41 @@ void syrk_test() { cudaFree(mat_c); } +template +void syr2k_test() { + const std::size_t n = 1lu << 10; + const auto alpha = convert(1); + const auto beta = convert(0); + + T* mat_a; + T* mat_b; + T* mat_c; + + cudaMalloc(&mat_a, sizeof(T) * n * n); + cudaMalloc(&mat_b, sizeof(T) * n * n); + cudaMalloc(&mat_c, sizeof(T) * n * n); + + cublasHandle_t cublas_handle; + cublasCreate(&cublas_handle); + + syr2k( + cublas_handle, + CUBLAS_FILL_MODE_LOWER, + CUBLAS_OP_N, + n, n, + &alpha, + mat_a, n, + mat_b, n, + &beta, + mat_c, n + ); + + cublasDestroy(cublas_handle); + cudaFree(mat_a); + cudaFree(mat_b); + cudaFree(mat_c); +} + template void symm_test() { const std::size_t n = 1lu << 10; @@ -518,6 +582,11 @@ void test_all() { symm_test(); symm_test(); + syr2k_test(); + syr2k_test(); + syr2k_test(); + syr2k_test(); + gemm3m_test(); gemm3m_test(); }