Skip to content

Commit

Permalink
Merge branch '28-hemm' into 'master'
Browse files Browse the repository at this point in the history
Add hemm

See merge request mutsuki/CULiP!31
  • Loading branch information
enp1s0 committed Nov 11, 2021
2 parents 324d90b + 3e782d0 commit 349f376
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 5 deletions.
13 changes: 8 additions & 5 deletions docs/cublas.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
- `cublasCgemmStridedBatched`
- `cublasZgemmStridedBatched`
- `cublasGemmStridedExBatched`
- SYMM
- `cublasDsymm`
- `cublasSsymm`
- `cublasCsymm`
- `cublasZsymm`
- GEMM3M
- `cublasHgemm3m`
- `cublasCgemm3m`
Expand Down Expand Up @@ -64,8 +69,6 @@
- `cublasStrsmBatched`
- `cublasCtrsmBatched`
- `cublasZtrsmBatched`
- SYMM
- `cublasDsymm`
- `cublasSsymm`
- `cublasCsymm`
- `cublasZsymm`
- HEMM
- `cublasChemm`
- `cublasZhemm`
2 changes: 2 additions & 0 deletions include/CULiP/cublas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ enum CULiP_cublas_control_t {
CULiP_cublasStrsmBatched,
CULiP_cublasCtrsmBatched,
CULiP_cublasZtrsmBatched,
CULiP_cublasChemm,
CULiP_cublasZhemm,
CULiP_cublas_enum_length
};

Expand Down
20 changes: 20 additions & 0 deletions src/cublas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -779,4 +779,24 @@ cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle,
#undef CULIP_FUNC_NAME
#undef CULIP_FUNC_ENUM_NAME
#undef CULIP_TYPE

// -------------------------------------------------
// HEMM
// -------------------------------------------------

#define CULIP_FUNC_NAME cublasChemm
#define CULIP_FUNC_ENUM_NAME CULiP_cublasChemm
#define CULIP_TYPE cuComplex
#include "cublas.hemm.template.h"
#undef CULIP_FUNC_NAME
#undef CULIP_FUNC_ENUM_NAME
#undef CULIP_TYPE

#define CULIP_FUNC_NAME cublasZhemm
#define CULIP_FUNC_ENUM_NAME CULiP_cublasZhemm
#define CULIP_TYPE cuDoubleComplex
#include "cublas.hemm.template.h"
#undef CULIP_FUNC_NAME
#undef CULIP_FUNC_ENUM_NAME
#undef CULIP_TYPE
} // extern "C"
40 changes: 40 additions & 0 deletions src/cublas.hemm.template.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
cublasStatus_t CULIP_FUNC_NAME(cublasHandle_t handle,
cublasSideMode_t side, cublasFillMode_t uplo,
int m, int n,
const CULIP_TYPE *alpha, const CULIP_TYPE *A, int lda,
const CULIP_TYPE *B, int ldb, const CULIP_TYPE *beta, CULIP_TYPE *C,
int ldc) {
const int profiling_flag = (CULiP_profiling_control_array[CULIP_FUNC_ENUM_NAME] == 0) && CULiP_is_profiling_enabled(CULIP_CUBLAS_DISABLE_ENV_NAME);

// Get the function pointer
cublasStatus_t (*cublas_lib_func)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const CULIP_TYPE*, const CULIP_TYPE*, int, const CULIP_TYPE*, int, const CULIP_TYPE*, CULIP_TYPE*, int);
*(void**)(&cublas_lib_func) = CULiP_get_function_pointer(CULIP_CUBLAS_LIBRARY_NAME, CULIP_CUBLAS_ENV_NAME, __func__, &CULiP_cublas_lib_handle_cache);

cudaStream_t cuda_stream;
struct CULiP_profile_result profile_result;

if (profiling_flag) {
// Get current cuda stream
cublasGetStream(handle, &cuda_stream);

// Profile result structure
snprintf(profile_result.function_name, profile_result.function_name_length - 1, "%s-%s-%s-m%d-n%d", __func__, CULiP_get_cublasSideMode_t_string(side), CULiP_get_cublasFillMode_t_string(uplo), m, n);

// Record start rimestamp
CULiP_launch_function(cuda_stream, &CULiP_record_timestamp, (void*)&profile_result.start_timestamp);
}

// Call the function
const cublasStatus_t result = (*cublas_lib_func)(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc);
CULIBPROFILER_DEBUG_PRINT(printf("[CULiP Debug][%s] executed\n", __func__));

if (profiling_flag) {
// Record end rimestamp
CULiP_launch_function(cuda_stream, &CULiP_record_timestamp, (void*)&profile_result.end_timestamp);

// Print result
CULiP_launch_function(cuda_stream, &CULiP_print_profile_result, (void*)&profile_result);
}

return result;
}
60 changes: 60 additions & 0 deletions tests/cublas_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,29 @@ GEMM_OP_TRSM_BATCHED(D, double);
GEMM_OP_TRSM_BATCHED(C, cuComplex);
GEMM_OP_TRSM_BATCHED(Z, cuDoubleComplex);

// -----------------------------------------------------
// hemm
// -----------------------------------------------------
template <class T>
cublasStatus_t hemm(cublasHandle_t handle,
cublasSideMode_t side, cublasFillMode_t uplo,
int m, int n,
const T *alpha, const T *A, int lda,
const T *B, int ldb, const T *beta, T *C,
int ldc);
#define GEMM_OP_HEMM(short_type, type)\
template <>\
cublasStatus_t hemm<type>(cublasHandle_t handle,\
cublasSideMode_t side, cublasFillMode_t uplo, \
int m, int n, \
const type *alpha, const type *A, int lda,\
const type *B, int ldb, const type *beta, type *C,\
int ldc) {\
return cublas##short_type##hemm(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc);\
}
GEMM_OP_HEMM(C, cuComplex);
GEMM_OP_HEMM(Z, cuDoubleComplex);

// -------------
// Gemm3m
// -------------
Expand Down Expand Up @@ -846,6 +869,40 @@ void trsm_batched_test() {
cudaFreeHost(mat_b_array);
}

template <class T>
void hemm_test() {
const std::size_t n = 1lu << 10;
const auto alpha = convert<T>(1);
const auto beta = convert<T>(0);

T* mat_a;
T* mat_b;
T* mat_c;

cudaMalloc(&mat_a, sizeof(T) * n * n);
cudaMalloc(&mat_b, sizeof(T) * n * n);
cudaMalloc(&mat_c, sizeof(T) * n * n);

cublasHandle_t cublas_handle;
cublasCreate(&cublas_handle);

hemm<T>(
cublas_handle,
CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_LOWER,
n, n,
&alpha,
mat_a, n,
mat_b, n,
&beta,
mat_c, n
);

cublasDestroy(cublas_handle);
cudaFree(mat_a);
cudaFree(mat_b);
cudaFree(mat_c);
}

template <class T>
void gemm3m_test() {
const std::size_t n = 1lu << 10;
Expand Down Expand Up @@ -959,6 +1016,9 @@ void test_all() {
trsm_batched_test<cuComplex >();
trsm_batched_test<cuDoubleComplex>();

hemm_test<cuComplex >();
hemm_test<cuDoubleComplex>();

gemm3m_test<cuComplex >();
gemm3m_test<cuDoubleComplex>();
}
Expand Down

0 comments on commit 349f376

Please sign in to comment.