Skip to content

Commit

Permalink
Merge branch '35-spmv' into 'master'
Browse files Browse the repository at this point in the history
Add spmv

See merge request mutsuki/CULiP!37
  • Loading branch information
enp1s0 committed Dec 7, 2021
2 parents c506fee + 5bdce65 commit 1f89c95
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 1 deletion.
5 changes: 4 additions & 1 deletion docs/cublas.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
- `cublasCgeru`
- `cublasZgerc`
- `cublasZgeru`
- GBMV
- SBMV
- `cublasDsbmv`
- `cublasSsbmv`
- SPMV
- `cublasDspmv`
- `cublasSspmv`

## Level 3

Expand Down
2 changes: 2 additions & 0 deletions include/CULiP/cublas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ enum CULiP_cublas_control_t {
CULiP_cublasZgerc,
CULiP_cublasDsbmv,
CULiP_cublasSsbmv,
CULiP_cublasDspmv,
CULiP_cublasSspmv,

// LEVEL 3
CULiP_cublasDgemm,
Expand Down
20 changes: 20 additions & 0 deletions src/cublas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,26 @@ cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle,
#undef CULIP_FUNC_ENUM_NAME
#undef CULIP_TYPE

// -------------------------------------------------
// SPMV
// -------------------------------------------------
// cublas.spmv.template.h is written in terms of three macros:
//   CULIP_FUNC_NAME      - name of the function the template defines
//   CULIP_FUNC_ENUM_NAME - index used with CULiP_profiling_control_array
//   CULIP_TYPE           - scalar element type of alpha/AP/x/beta/y
// Each section below defines them, includes the template once, and then
// undefines them so the next section starts from a clean state.

// Single precision variant
#define CULIP_FUNC_NAME cublasSspmv
#define CULIP_FUNC_ENUM_NAME CULiP_cublasSspmv
#define CULIP_TYPE float
#include "cublas.spmv.template.h"
#undef CULIP_FUNC_NAME
#undef CULIP_FUNC_ENUM_NAME
#undef CULIP_TYPE

// Double precision variant
#define CULIP_FUNC_NAME cublasDspmv
#define CULIP_FUNC_ENUM_NAME CULiP_cublasDspmv
#define CULIP_TYPE double
#include "cublas.spmv.template.h"
#undef CULIP_FUNC_NAME
#undef CULIP_FUNC_ENUM_NAME
#undef CULIP_TYPE

// -------------------------------------------------
// SYRK
// -------------------------------------------------
Expand Down
39 changes: 39 additions & 0 deletions src/cublas.spmv.template.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Profiling wrapper with the same signature as the cuBLAS spmv routine it is
// named after. This file is a template: the includer must define
// CULIP_FUNC_NAME, CULIP_FUNC_ENUM_NAME and CULIP_TYPE before including it.
//
// All arguments are forwarded unchanged to the function pointer obtained from
// CULiP_get_function_pointer, and that call's status is returned as-is.
// When profiling is active, a timestamp callback is enqueued on the handle's
// stream before and after the call, followed by a callback that prints the
// result.
cublasStatus_t CULIP_FUNC_NAME(cublasHandle_t handle,
		cublasFillMode_t uplo, int n,
		const CULIP_TYPE *alpha, const CULIP_TYPE *AP,
		const CULIP_TYPE *x, int incx, const CULIP_TYPE *beta, CULIP_TYPE *y,
		int incy) {
	// Profiling is attempted only when this function's control-array entry is 0
	// and CULiP_is_profiling_enabled() reports true. Not const: it is cleared
	// below if the stream needed for profiling cannot be obtained.
	int profiling_flag = (CULiP_profiling_control_array[CULIP_FUNC_ENUM_NAME] == 0) && CULiP_is_profiling_enabled(CULIP_CUBLAS_DISABLE_ENV_NAME);

	// Get the function pointer
	// (looked up by this wrapper's own name via __func__)
	cublasStatus_t (*cublas_lib_func)(cublasHandle_t, cublasFillMode_t, int, const CULIP_TYPE*, const CULIP_TYPE*, const CULIP_TYPE*, int, const CULIP_TYPE*, CULIP_TYPE*, int);
	*(void**)(&cublas_lib_func) = CULiP_get_function_pointer(CULIP_CUBLAS_LIBRARY_NAME, CULIP_CUBLAS_ENV_NAME, __func__, &CULiP_cublas_lib_handle_cache);

	cudaStream_t cuda_stream = 0;
	// NOTE(review): the address of this stack object is handed to
	// CULiP_launch_function below. If that helper defers the callback past the
	// return of this function, the pointer would dangle - confirm against its
	// implementation.
	struct CULiP_profile_result profile_result;

	if (profiling_flag) {
		// Get current cuda stream.
		// If the query fails, cuda_stream does not name a valid stream, so
		// profiling is skipped for this call rather than launching callbacks on
		// it. The real cuBLAS call below is still made and reports its own status.
		if (cublasGetStream(handle, &cuda_stream) != CUBLAS_STATUS_SUCCESS) {
			profiling_flag = 0;
		}
	}

	if (profiling_flag) {
		// Profile result structure
		snprintf(profile_result.function_name, profile_result.function_name_length - 1, "%s-%s-n%d", __func__, CULiP_get_cublasFillMode_t_string(uplo), n);

		// Record start timestamp
		CULiP_launch_function(cuda_stream, &CULiP_record_timestamp, (void*)&profile_result.start_timestamp);
	}

	// Call the function
	const cublasStatus_t result = (*cublas_lib_func)(handle, uplo, n, alpha, AP, x, incx, beta, y, incy);
	CULIBPROFILER_DEBUG_PRINT(printf("[CULiP Debug][%s] executed\n", __func__));

	if (profiling_flag) {
		// Record end timestamp
		CULiP_launch_function(cuda_stream, &CULiP_record_timestamp, (void*)&profile_result.end_timestamp);

		// Print result
		CULiP_launch_function(cuda_stream, &CULiP_print_profile_result, (void*)&profile_result);
	}

	return result;
}
58 changes: 58 additions & 0 deletions tests/cublas_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,27 @@ cublasStatus_t sbmv<type>(cublasHandle_t handle, cublasFillMode_t uplo,\
GEMM_OP_SBMV(S, float);
GEMM_OP_SBMV(D, double);

// -------------
// Spmv
// -------------
// Type-dispatching front end for cublasSspmv / cublasDspmv.
// Only the explicit specializations generated by GEMM_OP_SPMV are defined.
template <class T>
cublasStatus_t spmv(cublasHandle_t handle, cublasFillMode_t uplo, int n,
		const T *alpha, const T *AP,
		const T *x, int incx,
		const T *beta,
		T *y, int incy);

// Emits the spmv<value_type> specialization, which forwards every argument
// unchanged to cublas<blas_prefix>spmv.
#define GEMM_OP_SPMV(blas_prefix, value_type) \
	template <> \
	cublasStatus_t spmv<value_type>(cublasHandle_t handle, cublasFillMode_t uplo, int n, \
			const value_type *alpha, const value_type *AP, \
			const value_type *x, int incx, \
			const value_type *beta, \
			value_type *y, int incy) { \
		return cublas##blas_prefix##spmv(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); \
	}
GEMM_OP_SPMV(S, float);
GEMM_OP_SPMV(D, double);

// -------------
// Syrk
// -------------
Expand Down Expand Up @@ -824,6 +845,40 @@ void sbmv_test() {
cudaFree(vec_y);
}

template <class T>
void spmv_test() {
const std::size_t n = 1lu << 10;
const auto alpha = convert<T>(1);
const auto beta = convert<T>(0);

T* mat_a;
T* vec_x;
T* vec_y;

cudaMalloc(&mat_a, sizeof(T) * n * (n + 1) / 2);
cudaMalloc(&vec_x, sizeof(T) * n);
cudaMalloc(&vec_y, sizeof(T) * n);

cublasHandle_t cublas_handle;
cublasCreate(&cublas_handle);

spmv<T>(
cublas_handle,
CUBLAS_FILL_MODE_UPPER,
n,
&alpha,
mat_a,
vec_x, 1,
&beta,
vec_y, 1
);

cublasDestroy(cublas_handle);
cudaFree(mat_a);
cudaFree(vec_x);
cudaFree(vec_y);
}

template <class T>
void ger_test() {
const std::size_t n = 1lu << 10;
Expand Down Expand Up @@ -1348,6 +1403,9 @@ void test_all() {
sbmv_test<double >();
sbmv_test<float >();

spmv_test<double >();
spmv_test<float >();

gemm_test<double , op_gemm >();
gemm_test<float , op_gemm >();
gemm_test<half , op_gemm >();
Expand Down

0 comments on commit 1f89c95

Please sign in to comment.