Skip to content

Commit

Permalink
Add support for cutoff to gemmEx*
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Sep 16, 2022
1 parent e189652 commit 9ae7c04
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions src/cublas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,40 @@ mtk::cu_exp_statistics::result_t exp_stats(
}
return result;
}

void cutoff(
void* const ptr,
const std::size_t offset,
const std::size_t m,
const std::size_t n,
const std::size_t ld,
double threshold,
const cudaStream_t cuda_stream,
cudaDataType_t data_t
) {
switch (data_t) {
case CUDA_R_64F:
mtk::cu_cutoff::cutoff_small_abs_values(reinterpret_cast<double*>(ptr) + offset, m, n, ld, threshold, cuda_stream);
break;
case CUDA_R_32F:
mtk::cu_cutoff::cutoff_small_abs_values(reinterpret_cast<float*>(ptr) + offset, m, n, ld, threshold, cuda_stream);
break;
case CUDA_R_16F:
mtk::cu_cutoff::cutoff_small_abs_values(reinterpret_cast<half*>(ptr) + offset, m, n, ld, threshold, cuda_stream);
break;
case CUDA_C_64F:
mtk::cu_cutoff::cutoff_small_abs_values(reinterpret_cast<double2*>(ptr) + offset, m, n, ld, threshold, cuda_stream);
break;
case CUDA_C_32F:
mtk::cu_cutoff::cutoff_small_abs_values(reinterpret_cast<float2*>(ptr) + offset, m, n, ld, threshold, cuda_stream);
break;
case CUDA_C_16F:
mtk::cu_cutoff::cutoff_small_abs_values(reinterpret_cast<half2*>(ptr) + offset, m, n, ld, threshold, cuda_stream);
break;
default:
break;
}
}
} // unnamed namespace

extern "C" {
Expand Down Expand Up @@ -275,6 +309,22 @@ cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa,
CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&b_stats);
}

const int cutoff_flag = (CULiP_profiling_control_array[CULiP_cublasGemmEx] == 0) && CULiP_is_profiling_enabled(CULIP_CUTOFF_THRESHOLD_ENV_NAME, false);
if (cutoff_flag) {
double threshold;
try {
const auto env_str = getenv(CULIP_CUTOFF_THRESHOLD_ENV_NAME);
threshold = std::stod(env_str);

cudaStream_t cuda_stream;
cublasGetStream(handle, &cuda_stream);
cutoff(const_cast<void*>(A), 0, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, threshold, cuda_stream, Atype);
cutoff(const_cast<void*>(B), 0, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, threshold, cuda_stream, Btype);
} catch(const std::exception& e) {
CULIBPROFILER_DEBUG_PRINT(printf("[CULiP Warning] invalid threshold (%s)\n", env_str));
}
}

return result;
}

Expand Down Expand Up @@ -514,6 +564,24 @@ cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle,
CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&b_stats);
}

const int cutoff_flag = (CULiP_profiling_control_array[CULiP_cublasGemmStridedBatchedEx] == 0) && CULiP_is_profiling_enabled(CULIP_CUTOFF_THRESHOLD_ENV_NAME, false);
if (cutoff_flag) {
double threshold;
try {
const auto env_str = getenv(CULIP_CUTOFF_THRESHOLD_ENV_NAME);
threshold = std::stod(env_str);

cudaStream_t cuda_stream;
cublasGetStream(handle, &cuda_stream);
for (std::uint32_t i = 0; i < batchCount; i++) {
cutoff(const_cast<void*>(A), i * strideA, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, threshold, cuda_stream, Atype);
cutoff(const_cast<void*>(B), i * strideB, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, threshold, cuda_stream, Btype);
}
} catch(const std::exception& e) {
CULIBPROFILER_DEBUG_PRINT(printf("[CULiP Warning] invalid threshold (%s)\n", env_str));
}
}

return result;
}

Expand Down

0 comments on commit 9ae7c04

Please sign in to comment.