Skip to content

Commit

Permalink
Add tests for syr2k
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Oct 18, 2021
1 parent ea3e8bc commit 939ff0d
Showing 1 changed file with 69 additions and 0 deletions.
69 changes: 69 additions & 0 deletions tests/cublas_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,35 @@ GEMM_OP_SYRK(D, double);
GEMM_OP_SYRK(C, cuComplex);
GEMM_OP_SYRK(Z, cuDoubleComplex);

// -------------
// Syr2k
// -------------
template <class T>
cublasStatus_t syr2k(cublasHandle_t handle,
cublasFillMode_t uplo, cublasOperation_t trans,
int n, int k,
const T *alpha,
const T *A, int lda,
const T *B, int ldb,
const T *beta , T *C, int ldc
);
#define GEMM_OP_SYR2K(short_type, type)\
template <>\
cublasStatus_t syr2k<type>(cublasHandle_t handle, cublasFillMode_t uplo,\
cublasOperation_t trans,\
int n, int k,\
const type *alpha, \
const type *A, int lda,\
const type *B, int ldb,\
const type *beta, type *C, int ldc\
) {\
return cublas##short_type##syr2k(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc);\
}
GEMM_OP_SYR2K(S, float);
GEMM_OP_SYR2K(D, double);
GEMM_OP_SYR2K(C, cuComplex);
GEMM_OP_SYR2K(Z, cuDoubleComplex);

// -------------
// Symm
// -------------
Expand Down Expand Up @@ -407,6 +436,41 @@ void syrk_test() {
cudaFree(mat_c);
}

template <class T>
void syr2k_test() {
const std::size_t n = 1lu << 10;
const auto alpha = convert<T>(1);
const auto beta = convert<T>(0);

T* mat_a;
T* mat_b;
T* mat_c;

cudaMalloc(&mat_a, sizeof(T) * n * n);
cudaMalloc(&mat_b, sizeof(T) * n * n);
cudaMalloc(&mat_c, sizeof(T) * n * n);

cublasHandle_t cublas_handle;
cublasCreate(&cublas_handle);

syr2k<T>(
cublas_handle,
CUBLAS_FILL_MODE_LOWER,
CUBLAS_OP_N,
n, n,
&alpha,
mat_a, n,
mat_b, n,
&beta,
mat_c, n
);

cublasDestroy(cublas_handle);
cudaFree(mat_a);
cudaFree(mat_b);
cudaFree(mat_c);
}

template <class T>
void symm_test() {
const std::size_t n = 1lu << 10;
Expand Down Expand Up @@ -518,6 +582,11 @@ void test_all() {
symm_test<cuComplex >();
symm_test<cuDoubleComplex>();

syr2k_test<double >();
syr2k_test<float >();
syr2k_test<cuComplex >();
syr2k_test<cuDoubleComplex>();

gemm3m_test<cuComplex >();
gemm3m_test<cuDoubleComplex>();
}
Expand Down

0 comments on commit 939ff0d

Please sign in to comment.