Skip to content

Commit

Permalink
Merge branch '30-her2k' into 'master'
Browse files Browse the repository at this point in the history
Add her2k

See merge request mutsuki/CULiP!33
  • Loading branch information
enp1s0 committed Nov 15, 2021
2 parents 8827610 + 14c1991 commit b1c8576
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 2 deletions.
3 changes: 3 additions & 0 deletions docs/cublas.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,6 @@
- HERK
- `cublasCherk`
- `cublasZherk`
- HER2K
- `cublasCher2k`
- `cublasZher2k`
6 changes: 4 additions & 2 deletions include/CULiP/cublas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ enum CULiP_cublas_control_t {
CULiP_cublasStrsmBatched,
CULiP_cublasCtrsmBatched,
CULiP_cublasZtrsmBatched,
CULiP_cublasCherk,
CULiP_cublasZherk,
CULiP_cublasChemm,
CULiP_cublasZhemm,
CULiP_cublasCherk,
CULiP_cublasZherk,
CULiP_cublasCher2k,
CULiP_cublasZher2k,
CULiP_cublas_enum_length
};

Expand Down
24 changes: 24 additions & 0 deletions src/cublas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -822,4 +822,28 @@ cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle,
#undef CULIP_FUNC_NAME
#undef CULIP_FUNC_ENUM_NAME
#undef CULIP_TYPE

// -------------------------------------------------
// HERK
// -------------------------------------------------
#define CULIP_FUNC_NAME cublasCher2k
#define CULIP_FUNC_ENUM_NAME CULiP_cublasCher2k
#define CULIP_TYPE cuComplex
#define CULIP_REAL_TYPE float
#include "cublas.her2k.template.h"
#undef CULIP_FUNC_NAME
#undef CULIP_FUNC_ENUM_NAME
#undef CULIP_TYPE
#undef CULIP_REAL_TYPE

#define CULIP_FUNC_NAME cublasZher2k
#define CULIP_FUNC_ENUM_NAME CULiP_cublasZher2k
#define CULIP_TYPE cuDoubleComplex
#define CULIP_REAL_TYPE double
#include "cublas.her2k.template.h"
#undef CULIP_FUNC_NAME
#undef CULIP_FUNC_ENUM_NAME
#undef CULIP_TYPE
#undef CULIP_REAL_TYPE

} // extern "C"
42 changes: 42 additions & 0 deletions src/cublas.her2k.template.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
cublasStatus_t CULIP_FUNC_NAME(cublasHandle_t handle,
cublasFillMode_t uplo, cublasOperation_t trans,
int n, int k,
const CULIP_TYPE *alpha,
const CULIP_TYPE *A, int lda,
const CULIP_TYPE *B, int ldb,
const CULIP_REAL_TYPE *beta,
CULIP_TYPE *C, int ldc) {
const int profiling_flag = (CULiP_profiling_control_array[CULIP_FUNC_ENUM_NAME] == 0) && CULiP_is_profiling_enabled(CULIP_CUBLAS_DISABLE_ENV_NAME);

// Get the function pointer
cublasStatus_t (*cublas_lib_func)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const CULIP_TYPE*, const CULIP_TYPE*, int, const CULIP_TYPE*, int, const CULIP_REAL_TYPE*, CULIP_TYPE*, int);
*(void**)(&cublas_lib_func) = CULiP_get_function_pointer(CULIP_CUBLAS_LIBRARY_NAME, CULIP_CUBLAS_ENV_NAME, __func__, &CULiP_cublas_lib_handle_cache);

cudaStream_t cuda_stream;
struct CULiP_profile_result profile_result;

if (profiling_flag) {
// Get current cuda stream
cublasGetStream(handle, &cuda_stream);

// Profile result structure
snprintf(profile_result.function_name, profile_result.function_name_length - 1, "%s-%s-%s-n%d-k%d", __func__, CULiP_get_cublasFillMode_t_string(uplo), CULiP_get_cublasOperation_t_string(trans), n, k);

// Record start rimestamp
CULiP_launch_function(cuda_stream, &CULiP_record_timestamp, (void*)&profile_result.start_timestamp);
}

// Call the function
const cublasStatus_t result = (*cublas_lib_func)(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
CULIBPROFILER_DEBUG_PRINT(printf("[CULiP Debug][%s] executed\n", __func__));

if (profiling_flag) {
// Record end rimestamp
CULiP_launch_function(cuda_stream, &CULiP_record_timestamp, (void*)&profile_result.end_timestamp);

// Print result
CULiP_launch_function(cuda_stream, &CULiP_print_profile_result, (void*)&profile_result);
}

return result;
}
65 changes: 65 additions & 0 deletions tests/cublas_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,33 @@ cublasStatus_t herk<type, real_type>(cublasHandle_t handle,\
GEMM_OP_HERK(C, cuComplex, float);
GEMM_OP_HERK(Z, cuDoubleComplex, double);

// -----------------------------------------------------
// her2k
// -----------------------------------------------------
template <class T, class RT>
cublasStatus_t her2k(cublasHandle_t handle,
cublasFillMode_t uplo, cublasOperation_t trans,
int m, int n,
const T *alpha,
const T *A, int lda,
const T *B, int ldb,
const RT *beta, T *C,
int ldc);
#define GEMM_OP_HER2K(short_type, type, real_type)\
template <>\
cublasStatus_t her2k<type, real_type>(cublasHandle_t handle,\
cublasFillMode_t uplo, cublasOperation_t trans, \
int m, int n, \
const type *alpha,\
const type *A, int lda,\
const type *B, int ldb,\
const real_type *beta, type *C,\
int ldc) {\
return cublas##short_type##her2k(handle, uplo, trans, m, n, alpha, A, lda, B, ldb, beta, C, ldc);\
}
GEMM_OP_HER2K(C, cuComplex, float);
GEMM_OP_HER2K(Z, cuDoubleComplex, double);

// -------------
// Gemm3m
// -------------
Expand Down Expand Up @@ -964,6 +991,41 @@ void herk_test() {
cudaFree(mat_c);
}

template <class T>
void her2k_test() {
using real_type = typename get_real_type<T>::type;
const std::size_t n = 1lu << 10;
const auto alpha = convert<T>(1);
const auto beta = convert<real_type>(0);

T* mat_a;
T* mat_b;
T* mat_c;

cudaMalloc(&mat_a, sizeof(T) * n * n);
cudaMalloc(&mat_b, sizeof(T) * n * n);
cudaMalloc(&mat_c, sizeof(T) * n * n);

cublasHandle_t cublas_handle;
cublasCreate(&cublas_handle);

her2k<T>(
cublas_handle,
CUBLAS_FILL_MODE_LOWER, CUBLAS_OP_N,
n, n,
&alpha,
mat_a, n,
mat_b, n,
&beta,
mat_c, n
);

cublasDestroy(cublas_handle);
cudaFree(mat_a);
cudaFree(mat_b);
cudaFree(mat_c);
}

template <class T>
void gemm3m_test() {
const std::size_t n = 1lu << 10;
Expand Down Expand Up @@ -1083,6 +1145,9 @@ void test_all() {
herk_test<cuComplex >();
herk_test<cuDoubleComplex>();

her2k_test<cuComplex >();
her2k_test<cuDoubleComplex>();

gemm3m_test<cuComplex >();
gemm3m_test<cuDoubleComplex>();
}
Expand Down

0 comments on commit b1c8576

Please sign in to comment.