Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[C] Separating cudnn common utils from fused_attn #1314

Merged
merged 2 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES
pycudnn.cpp
cudnn_utils.cpp
transformer_engine.cpp
common.cu
Expand Down
59 changes: 58 additions & 1 deletion transformer_engine/common/cudnn_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,70 @@
* See LICENSE for license information.
************************************************************************/

#include "../fused_attn/utils.h"
#include "cudnn_utils.h"

#include "./util/logging.h"
#include "transformer_engine/cudnn.h"

namespace transformer_engine {

// get cuDNN data type
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return CUDNN_DATA_INT32;
case DType::kInt64:
return CUDNN_DATA_INT64;
case DType::kFloat16:
return CUDNN_DATA_HALF;
case DType::kFloat32:
return CUDNN_DATA_FLOAT;
case DType::kBFloat16:
return CUDNN_DATA_BFLOAT16;
case DType::kFloat8E4M3:
return CUDNN_DATA_FP8_E4M3;
case DType::kFloat8E5M2:
return CUDNN_DATA_FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}

// get cuDNN data type
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return cudnn_frontend::DataType_t::INT32;
case DType::kInt64:
return cudnn_frontend::DataType_t::INT64;
case DType::kFloat16:
return cudnn_frontend::DataType_t::HALF;
case DType::kFloat32:
return cudnn_frontend::DataType_t::FLOAT;
case DType::kBFloat16:
return cudnn_frontend::DataType_t::BFLOAT16;
case DType::kFloat8E4M3:
return cudnn_frontend::DataType_t::FP8_E4M3;
case DType::kFloat8E5M2:
return cudnn_frontend::DataType_t::FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}

void nvte_cudnn_handle_init() {
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
}

} // namespace transformer_engine

namespace cudnn_frontend {

// This is needed to define the symbol `cudnn_dlhandle`
// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING
// to enable dynamic loading.
void *cudnn_dlhandle = nullptr;

} // namespace cudnn_frontend
46 changes: 46 additions & 0 deletions transformer_engine/common/cudnn_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#ifndef TRANSFORMER_ENGINE_CUDNN_UTILS_H_
#define TRANSFORMER_ENGINE_CUDNN_UTILS_H_

#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>

#include <cstdint>
#include <mutex>

#include "transformer_engine/transformer_engine.h"

namespace transformer_engine {

cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);

cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);

class cudnnExecutionPlanManager {
public:
static cudnnExecutionPlanManager &Instance() {
static thread_local cudnnExecutionPlanManager instance;
return instance;
}

cudnnHandle_t GetCudnnHandle() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] { cudnnCreate(&handle_); });
return handle_;
}

~cudnnExecutionPlanManager() {}

private:
cudnnHandle_t handle_ = nullptr;
};

} // namespace transformer_engine

#endif
1 change: 1 addition & 0 deletions transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "transformer_engine/fused_attn.h"

#include "../common.h"
#include "../cudnn_utils.h"
#include "../util/cuda_runtime.h"
#include "../util/system.h"
#include "fused_attn_f16_arbitrary_seqlen.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <vector>

#include "../common.h"
#include "../cudnn_utils.h"
#include "../util/cuda_runtime.h"
#include "../util/system.h"
#include "fused_attn_f16_arbitrary_seqlen.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <vector>

#include "../common.h"
#include "../cudnn_utils.h"
#include "fused_attn_f16_max512_seqlen.h"
#include "utils.h"

Expand Down
1 change: 1 addition & 0 deletions transformer_engine/common/fused_attn/fused_attn_fp8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
************************************************************************/

#include "../common.h"
#include "../cudnn_utils.h"
#include "../util/system.h"
#include "fused_attn_fp8.h"
#include "utils.h"
Expand Down
47 changes: 1 addition & 46 deletions transformer_engine/common/fused_attn/utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <cmath>

#include "../common.h"
#include "../cudnn_utils.h"
#include "transformer_engine/fused_attn.h"
#include "utils.h"

Expand Down Expand Up @@ -495,50 +496,4 @@ size_t get_max_tokens(size_t num_tokens) {
}

} // namespace fused_attn

// get cuDNN data type
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return CUDNN_DATA_INT32;
case DType::kInt64:
return CUDNN_DATA_INT64;
case DType::kFloat16:
return CUDNN_DATA_HALF;
case DType::kFloat32:
return CUDNN_DATA_FLOAT;
case DType::kBFloat16:
return CUDNN_DATA_BFLOAT16;
case DType::kFloat8E4M3:
return CUDNN_DATA_FP8_E4M3;
case DType::kFloat8E5M2:
return CUDNN_DATA_FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}

// get cuDNN data type
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return cudnn_frontend::DataType_t::INT32;
case DType::kInt64:
return cudnn_frontend::DataType_t::INT64;
case DType::kFloat16:
return cudnn_frontend::DataType_t::HALF;
case DType::kFloat32:
return cudnn_frontend::DataType_t::FLOAT;
case DType::kBFloat16:
return cudnn_frontend::DataType_t::BFLOAT16;
case DType::kFloat8E4M3:
return cudnn_frontend::DataType_t::FP8_E4M3;
case DType::kFloat8E5M2:
return cudnn_frontend::DataType_t::FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
} // namespace transformer_engine
23 changes: 1 addition & 22 deletions transformer_engine/common/fused_attn/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,29 +140,8 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at

size_t get_max_batch_size(size_t batch_size);
size_t get_max_tokens(size_t num_tokens);
} // namespace fused_attn

cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);

class cudnnExecutionPlanManager {
public:
static cudnnExecutionPlanManager &Instance() {
static thread_local cudnnExecutionPlanManager instance;
return instance;
}

cudnnHandle_t GetCudnnHandle() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] { cudnnCreate(&handle_); });
return handle_;
}

~cudnnExecutionPlanManager() {}

private:
cudnnHandle_t handle_ = nullptr;
};
} // namespace fused_attn
} // namespace transformer_engine

#endif
14 changes: 0 additions & 14 deletions transformer_engine/common/pycudnn.cpp

This file was deleted.

Loading