Skip to content

Commit

Permalink
Add tests for trsm
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Oct 22, 2021
1 parent 28a23cb commit 22cc384
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions tests/cublas_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,35 @@ GEMM_OP_TRMM(D, double);
GEMM_OP_TRMM(C, cuComplex);
GEMM_OP_TRMM(Z, cuDoubleComplex);

// -------------
// Trsm
// -------------
// Type-generic entry point for the cuBLAS triangular-solve routines.
// The primary template is only declared here (no body); each supported
// element type gets an explicit specialization generated by GEMM_OP_TRSM.
// Every argument is passed through to cublas<t>trsm unchanged and the
// cuBLAS status code is returned to the caller as-is.
template <class T>
cublasStatus_t trsm(
    cublasHandle_t    handle,
    cublasSideMode_t  side,
    cublasFillMode_t  uplo,
    cublasOperation_t trans,
    cublasDiagType_t  diag,
    int               m,
    int               n,
    const T          *alpha,
    const T          *A,
    int               lda,
    T                *B,
    int               ldb);

// GEMM_OP_TRSM(blas_prefix, elem_t)
//   blas_prefix : letter pasted between "cublas" and "trsm" (S, D, C or Z)
//   elem_t      : element type the generated specialization is written for
// Expands to the explicit specialization trsm<elem_t>, whose body is a single
// forwarding call to cublas<blas_prefix>trsm.
#define GEMM_OP_TRSM(blas_prefix, elem_t)                          \
    template <>                                                    \
    cublasStatus_t trsm<elem_t>(                                   \
        cublasHandle_t    handle,                                  \
        cublasSideMode_t  side,                                    \
        cublasFillMode_t  uplo,                                    \
        cublasOperation_t trans,                                   \
        cublasDiagType_t  diag,                                    \
        int               m,                                       \
        int               n,                                       \
        const elem_t     *alpha,                                   \
        const elem_t     *A,                                       \
        int               lda,                                     \
        elem_t           *B,                                       \
        int               ldb) {                                   \
        return cublas##blas_prefix##trsm(                          \
            handle, side, uplo, trans, diag,                       \
            m, n, alpha, A, lda, B, ldb);                          \
    }
GEMM_OP_TRSM(S, float);
GEMM_OP_TRSM(D, double);
GEMM_OP_TRSM(C, cuComplex);
GEMM_OP_TRSM(Z, cuDoubleComplex);

// -------------
// Gemm3m
// -------------
Expand Down Expand Up @@ -633,6 +662,35 @@ void trmm_test() {
cudaFree(mat_c);
}

// Smoke test for the trsm<T> wrapper.
//
// Allocates two n x n device buffers of T (n = 1024), creates a cuBLAS
// handle and issues one left-sided, lower-triangular, non-transposed,
// non-unit-diagonal trsm call with alpha = convert<T>(1). The buffers are
// used exactly as cudaMalloc returns them and nothing is copied back, so
// this exercises the call path only; it does not validate numerical results.
//
// If an allocation or the handle creation fails, the trsm call is skipped
// and whatever was acquired is released. The function has no channel to
// report such a failure to its caller.
template <class T>
void trsm_test() {
    const std::size_t n = 1lu << 10;
    const auto alpha = convert<T>(1);

    // cuBLAS takes the dimensions as int; 1024 is representable, so make the
    // narrowing from std::size_t explicit instead of implicit at the call.
    const int dim = static_cast<int>(n);

    // Start from nullptr so a failed allocation can never leave an
    // indeterminate pointer that later reaches cuBLAS or cudaFree.
    T* mat_a = nullptr;
    T* mat_b = nullptr;
    cublasHandle_t cublas_handle = nullptr;

    // Short-circuit evaluation stops at the first failing setup step.
    const bool buffers_ready =
        cudaMalloc(&mat_a, sizeof(T) * n * n) == cudaSuccess &&
        cudaMalloc(&mat_b, sizeof(T) * n * n) == cudaSuccess;
    const bool handle_ready =
        buffers_ready &&
        cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS;

    if (handle_ready) {
        // The returned status is intentionally not inspected: the outputs of
        // this test are never examined, only the call itself matters.
        trsm<T>(
            cublas_handle,
            CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_LOWER,
            CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT,
            dim, dim,
            &alpha,
            mat_a, dim,
            mat_b, dim
        );

        cublasDestroy(cublas_handle);
    }

    // cudaFree accepts a null pointer, so this is safe on every path.
    cudaFree(mat_a);
    cudaFree(mat_b);
}

template <class T>
void gemm3m_test() {
const std::size_t n = 1lu << 10;
Expand Down Expand Up @@ -725,6 +783,11 @@ void test_all() {
trmm_test<cuComplex >();
trmm_test<cuDoubleComplex>();

trsm_test<double >();
trsm_test<float >();
trsm_test<cuComplex >();
trsm_test<cuDoubleComplex>();

gemm3m_test<cuComplex >();
gemm3m_test<cuDoubleComplex>();
}
Expand Down

0 comments on commit 22cc384

Please sign in to comment.