diff --git a/src/cublas.cu b/src/cublas.cu index 8050316..1e05859 100644 --- a/src/cublas.cu +++ b/src/cublas.cu @@ -445,4 +445,40 @@ cublasStatus_t cublasGemmBatchedEx(cublasHandle_t handle, #undef CULIP_FUNC_NAME #undef CULIP_FUNC_ENUM_NAME #undef CULIP_TYPE + +// ------------------------------------------------- +// SYMM +// ------------------------------------------------- + +#define CULIP_FUNC_NAME cublasSsymm +#define CULIP_FUNC_ENUM_NAME CULiP_cublasSsymm +#define CULIP_TYPE float +#include "cublas.symm.template.h" +#undef CULIP_FUNC_NAME +#undef CULIP_FUNC_ENUM_NAME +#undef CULIP_TYPE + +#define CULIP_FUNC_NAME cublasDsymm +#define CULIP_FUNC_ENUM_NAME CULiP_cublasDsymm +#define CULIP_TYPE double +#include "cublas.symm.template.h" +#undef CULIP_FUNC_NAME +#undef CULIP_FUNC_ENUM_NAME +#undef CULIP_TYPE + +#define CULIP_FUNC_NAME cublasCsymm +#define CULIP_FUNC_ENUM_NAME CULiP_cublasCsymm +#define CULIP_TYPE cuComplex +#include "cublas.symm.template.h" +#undef CULIP_FUNC_NAME +#undef CULIP_FUNC_ENUM_NAME +#undef CULIP_TYPE + +#define CULIP_FUNC_NAME cublasZsymm +#define CULIP_FUNC_ENUM_NAME CULiP_cublasZsymm +#define CULIP_TYPE cuDoubleComplex +#include "cublas.symm.template.h" +#undef CULIP_FUNC_NAME +#undef CULIP_FUNC_ENUM_NAME +#undef CULIP_TYPE } // extern "C" diff --git a/src/cublas.symm.template.h b/src/cublas.symm.template.h new file mode 100644 index 0000000..a5bc289 --- /dev/null +++ b/src/cublas.symm.template.h @@ -0,0 +1,39 @@ +cublasStatus_t CULIP_FUNC_NAME(cublasHandle_t handle, cublasSideMode_t side, + cublasFillMode_t uplo, int m, int n, + const CULIP_TYPE *alpha, const CULIP_TYPE *A, int lda, + const CULIP_TYPE *B, int ldb, const CULIP_TYPE *beta, CULIP_TYPE *C, + int ldc) { + const int profiling_flag = (CULiP_profiling_control_array[CULIP_FUNC_ENUM_NAME] == 0) && CULiP_is_profiling_enabled(CULIP_CUBLAS_DISABLE_ENV_NAME); + + // Get the function pointer + cublasStatus_t (*cublas_lib_func)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const CULIP_TYPE*, const CULIP_TYPE*, int, const CULIP_TYPE*, int, const CULIP_TYPE*, CULIP_TYPE*, int); + *(void**)(&cublas_lib_func) = CULiP_get_function_pointer(CULIP_CUBLAS_LIBRARY_NAME, CULIP_CUBLAS_ENV_NAME, __func__, &CULiP_cublas_lib_handle_cache); + + cudaStream_t cuda_stream; + struct CULiP_profile_result profile_result; + + if (profiling_flag) { + // Get current cuda stream + cublasGetStream(handle, &cuda_stream); + + // Profile result structure + snprintf(profile_result.function_name, profile_result.function_name_length - 1, "%s-%s%s-m%d-n%d", __func__, CULiP_get_cublasSideMode_t_string(side), CULiP_get_cublasFillMode_t_string(uplo), m, n); + + // Record start rimestamp + CULiP_launch_function(cuda_stream, &CULiP_record_timestamp, (void*)&profile_result.start_timestamp); + } + + // Call the function + const cublasStatus_t result = (*cublas_lib_func)(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); + CULIBPROFILER_DEBUG_PRINT(printf("[CULiP Debug][%s] executed\n", __func__)); + + if (profiling_flag) { + // Record end rimestamp + CULiP_launch_function(cuda_stream, &CULiP_record_timestamp, (void*)&profile_result.end_timestamp); + + // Print result + CULiP_launch_function(cuda_stream, &CULiP_print_profile_result, (void*)&profile_result); + } + + return result; +}