Skip to content

Commit

Permalink
Add tests for symm
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Oct 17, 2021
1 parent 6d9be7d commit 210e6e3
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions tests/cublas_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,35 @@ GEMM_OP_SYRK(D, double);
GEMM_OP_SYRK(C, cuComplex);
GEMM_OP_SYRK(Z, cuDoubleComplex);

// -------------
// Symm
// -------------
// Type-generic entry point for the cuBLAS SYMM routines.
// Only the primary template is declared here; the per-type definitions are
// the explicit specializations produced by GEMM_OP_SYMM, each of which passes
// its arguments straight through to the matching cublas<t>symm function and
// returns that function's status.
template <class T>
cublasStatus_t symm(cublasHandle_t handle, cublasSideMode_t side,
                    cublasFillMode_t uplo,
                    int m, int n,
                    const T *alpha,
                    const T *A, int lda,
                    const T *B, int ldb,
                    const T *beta, T *C, int ldc);
// Stamps out the explicit specialization symm<scalar_t>, which hands every
// argument unchanged to the cuBLAS routine cublas<prefix>symm and returns
// the status that routine reports.
#define GEMM_OP_SYMM(prefix, scalar_t)                                       \
  template <>                                                                \
  cublasStatus_t symm<scalar_t>(                                             \
      cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo,   \
      int m, int n,                                                          \
      const scalar_t *alpha,                                                 \
      const scalar_t *A, int lda,                                            \
      const scalar_t *B, int ldb,                                            \
      const scalar_t *beta,                                                  \
      scalar_t *C, int ldc) {                                                \
    const cublasStatus_t status = cublas##prefix##symm(                      \
        handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc);      \
    return status;                                                           \
  }

// The prefix letter is pasted into the cuBLAS function name, so these
// specializations call cublasSsymm, cublasDsymm, cublasCsymm and cublasZsymm.
GEMM_OP_SYMM(S, float);
GEMM_OP_SYMM(D, double);
GEMM_OP_SYMM(C, cuComplex);
GEMM_OP_SYMM(Z, cuDoubleComplex);

// -------------
// Gemm3m
// -------------
Expand Down Expand Up @@ -378,6 +407,40 @@ void syrk_test() {
cudaFree(mat_c);
}

// Smoke test for symm<T>: issues one n x n SYMM call (A applied from the
// left, lower fill mode) with alpha = 1 and beta = 0 on device buffers.
//
// The function returns void, so a failure cannot be handed back to the
// caller. If any allocation, fill, or handle creation fails, the SYMM call is
// skipped instead of being issued with invalid arguments, and everything that
// was acquired is still released.
// NOTE(review): no status-reporting facility is visible in this part of the
// file; if one exists, the statuses checked below (and the symm return value)
// should be routed through it so failures are not silent.
template <class T>
void symm_test() {
  const std::size_t n = 1lu << 10;
  const std::size_t bytes = sizeof(T) * n * n;
  // cuBLAS takes int dimensions / leading dimensions; convert once, explicitly.
  const int dim = static_cast<int>(n);
  const auto alpha = convert<T>(1);
  const auto beta = convert<T>(0);

  // Start from nullptr so the cudaFree calls below are no-ops for any buffer
  // whose allocation failed or was never attempted.
  T* mat_a = nullptr;
  T* mat_b = nullptr;
  T* mat_c = nullptr;

  bool buffers_ready = cudaMalloc(&mat_a, bytes) == cudaSuccess &&
                       cudaMalloc(&mat_b, bytes) == cudaSuccess &&
                       cudaMalloc(&mat_c, bytes) == cudaSuccess;

  // Give both input operands defined contents (all-zero bytes) rather than
  // reading freshly allocated memory. mat_c is only written: beta is 0.
  buffers_ready = buffers_ready &&
                  cudaMemset(mat_a, 0, bytes) == cudaSuccess &&
                  cudaMemset(mat_b, 0, bytes) == cudaSuccess;

  if (buffers_ready) {
    cublasHandle_t cublas_handle;
    if (cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS) {
      symm<T>(
          cublas_handle,
          CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_LOWER,
          dim, dim,
          &alpha,
          mat_a, dim,
          mat_b, dim,
          &beta,
          mat_c, dim);

      cublasDestroy(cublas_handle);
    }
  }

  cudaFree(mat_a);
  cudaFree(mat_b);
  cudaFree(mat_c);
}

template <class T>
void gemm3m_test() {
const std::size_t n = 1lu << 10;
Expand Down Expand Up @@ -450,6 +513,11 @@ void test_all() {
syrk_test<cuComplex >();
syrk_test<cuDoubleComplex>();

symm_test<double >();
symm_test<float >();
symm_test<cuComplex >();
symm_test<cuDoubleComplex>();

gemm3m_test<cuComplex >();
gemm3m_test<cuDoubleComplex>();
}
Expand Down

0 comments on commit 210e6e3

Please sign in to comment.