Skip to content

Commit

Permalink
Add tests for strided batched gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Oct 26, 2021
1 parent 535ae41 commit 38c28dd
Showing 1 changed file with 91 additions and 0 deletions.
91 changes: 91 additions & 0 deletions tests/cublas_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,50 @@ GEMM_BATCHED_OP_GEMMEX(CUDA_R_64F, double);
GEMM_BATCHED_OP_GEMMEX(CUDA_C_32F, cuComplex);
GEMM_BATCHED_OP_GEMMEX(CUDA_C_64F, cuDoubleComplex);

// -------------
// GemmStridedBatched
// -------------
// Primary template for the strided-batched GEMM dispatcher.
//
// Only declared here: a body exists solely for the (T, Op) pairs that are
// explicitly specialized further down in this file.
//   T  : element type shared by alpha, beta, A, B and C
//   Op : tag type choosing which cuBLAS entry point a specialization calls
// Returns the cublasStatus_t reported by the selected cuBLAS routine.
template <typename T, typename Op>
cublasStatus_t gemm_strided_batched(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const T* alpha,
    const T* A,
    int lda,
    long long strideA,
    const T* B,
    int ldb,
    long long strideB,
    const T* beta,
    T* C,
    int ldc,
    long long strideC,
    int batchCount);
// -----------------------------------------------------
// op_gemm
// -----------------------------------------------------
// Defines the explicit specialization gemm_strided_batched<elem_type, op_gemm>.
// The generated function hands its arguments, unchanged and in order, to
// cublas<prefix>gemmStridedBatched and returns that call's status.
//   prefix    : type letter token-pasted into the cuBLAS function name
//   elem_type : element type used for alpha, beta, A, B and C
#define GEMM_STRIDED_BATCHED_OP_GEMM(prefix, elem_type) \
  template <> \
  cublasStatus_t gemm_strided_batched<elem_type, op_gemm>( \
      cublasHandle_t handle, \
      cublasOperation_t transa, cublasOperation_t transb, \
      int m, int n, int k, \
      const elem_type* alpha, \
      const elem_type* A, int lda, long long int strideA, \
      const elem_type* B, int ldb, long long int strideB, \
      const elem_type* beta, \
      elem_type* C, int ldc, long long int strideC, \
      int batchCount) { \
    const cublasStatus_t status = cublas##prefix##gemmStridedBatched( \
        handle, transa, transb, m, n, k, \
        alpha, A, lda, strideA, B, ldb, strideB, \
        beta, C, ldc, strideC, batchCount); \
    return status; \
  }
// One specialization per (type letter, element type) pair.
GEMM_STRIDED_BATCHED_OP_GEMM(H, half);
GEMM_STRIDED_BATCHED_OP_GEMM(S, float);
GEMM_STRIDED_BATCHED_OP_GEMM(D, double);
GEMM_STRIDED_BATCHED_OP_GEMM(C, cuComplex);
GEMM_STRIDED_BATCHED_OP_GEMM(Z, cuDoubleComplex);
// -----------------------------------------------------
// op_gemmEx
// -----------------------------------------------------
// Defines the explicit specialization gemm_strided_batched<elem_type, op_gemmEx>.
// The generated function calls cublasGemmStridedBatchedEx, passing dtype_enum
// as the data type of A, B and C and again as the compute-type argument, with
// CUBLAS_GEMM_DEFAULT as the algorithm selector.
//   dtype_enum : CUDA data-type enumerator describing elem_type
//   elem_type  : element type used for alpha, beta, A, B and C
#define GEMM_STRIDED_BATCHED_OP_GEMMEX(dtype_enum, elem_type) \
  template <> \
  cublasStatus_t gemm_strided_batched<elem_type, op_gemmEx>( \
      cublasHandle_t handle, \
      cublasOperation_t transa, cublasOperation_t transb, \
      int m, int n, int k, \
      const elem_type* alpha, \
      const elem_type* A, int lda, long long int strideA, \
      const elem_type* B, int ldb, long long int strideB, \
      const elem_type* beta, \
      elem_type* C, int ldc, long long int strideC, \
      int batchCount) { \
    /* The Ex entry point takes type-erased matrix pointers. */ \
    const void* const a_ptr = static_cast<const void*>(A); \
    const void* const b_ptr = static_cast<const void*>(B); \
    void* const c_ptr = static_cast<void*>(C); \
    return cublasGemmStridedBatchedEx( \
        handle, transa, transb, m, n, k, \
        alpha, \
        a_ptr, dtype_enum, lda, strideA, \
        b_ptr, dtype_enum, ldb, strideB, \
        beta, \
        c_ptr, dtype_enum, ldc, strideC, \
        batchCount, dtype_enum, CUBLAS_GEMM_DEFAULT); \
  }
// One specialization per (data-type enumerator, element type) pair.
GEMM_STRIDED_BATCHED_OP_GEMMEX(CUDA_R_16F, half);
GEMM_STRIDED_BATCHED_OP_GEMMEX(CUDA_R_32F, float);
GEMM_STRIDED_BATCHED_OP_GEMMEX(CUDA_R_64F, double);
GEMM_STRIDED_BATCHED_OP_GEMMEX(CUDA_C_32F, cuComplex);
GEMM_STRIDED_BATCHED_OP_GEMMEX(CUDA_C_64F, cuDoubleComplex);

// -------------
// Gemv
// -------------
Expand Down Expand Up @@ -426,6 +470,42 @@ void gemm_batched_test() {
cudaFreeHost(mat_c_array);
}

// Smoke test: issues a single strided-batched GEMM for element type T through
// the entry point selected by Op.
//
// Problem shape: batch_size (1 << 10) square n x n matrices (n = 1 << 7) per
// operand, no transposition, leading dimension n, consecutive matrices n * n
// elements apart, alpha = 1 and beta = 0.
//
// The operand buffers are never filled and the output is not compared with a
// reference, so this exercises the call path only, not numerical results. The
// status returned by the GEMM call is likewise not inspected.
//
// Every resource acquisition is checked. If a buffer or the cuBLAS handle
// cannot be obtained, the GEMM is skipped and whatever was acquired is
// released, so that indeterminate pointers or an invalid handle never reach
// cuBLAS or the free routines.
template <class T, class Op>
void gemm_strided_batched_test() {
  const int n = 1lu << 7;
  const int batch_size = 1u << 10;
  const auto alpha = convert<T>(1);
  const auto beta = convert<T>(0);

  // sizeof(T) comes first so the whole product is evaluated in size_t.
  const auto buffer_bytes = sizeof(T) * n * n * batch_size;
  // Element distance between consecutive matrices of one operand.
  const long long int stride = static_cast<long long int>(n) * n;

  T* mat_a = nullptr;
  T* mat_b = nullptr;
  T* mat_c = nullptr;

  // Page-locked host buffers are handed straight to cuBLAS as operands.
  // Each later allocation is attempted only when the earlier ones succeeded,
  // so have_c implies that all three buffers are valid.
  const bool have_a = cudaMallocHost(&mat_a, buffer_bytes) == cudaSuccess;
  const bool have_b = have_a && cudaMallocHost(&mat_b, buffer_bytes) == cudaSuccess;
  const bool have_c = have_b && cudaMallocHost(&mat_c, buffer_bytes) == cudaSuccess;

  if (have_c) {
    cublasHandle_t cublas_handle;
    if (cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS) {
      gemm_strided_batched<T, Op>(
          cublas_handle,
          CUBLAS_OP_N, CUBLAS_OP_N,
          n, n, n,
          &alpha,
          mat_a, n, stride,
          mat_b, n, stride,
          &beta,
          mat_c, n, stride,
          batch_size);

      cublasDestroy(cublas_handle);
    }
  }

  // Release only the buffers that were actually allocated.
  if (have_a) {
    cudaFreeHost(mat_a);
  }
  if (have_b) {
    cudaFreeHost(mat_b);
  }
  if (have_c) {
    cudaFreeHost(mat_c);
  }
}

template <class T>
void gemv_test() {
const std::size_t n = 1lu << 10;
Expand Down Expand Up @@ -748,6 +828,17 @@ void test_all() {
gemm_batched_test<cuComplex , op_gemmEx>();
gemm_batched_test<cuDoubleComplex, op_gemmEx>();

gemm_strided_batched_test<double , op_gemm >();
gemm_strided_batched_test<float , op_gemm >();
gemm_strided_batched_test<half , op_gemm >();
gemm_strided_batched_test<cuComplex , op_gemm >();
gemm_strided_batched_test<cuDoubleComplex, op_gemm >();
gemm_strided_batched_test<double , op_gemmEx>();
gemm_strided_batched_test<float , op_gemmEx>();
gemm_strided_batched_test<half , op_gemmEx>();
gemm_strided_batched_test<cuComplex , op_gemmEx>();
gemm_strided_batched_test<cuDoubleComplex, op_gemmEx>();

gemv_test<double >();
gemv_test<float >();
gemv_test<cuComplex >();
Expand Down

0 comments on commit 38c28dd

Please sign in to comment.