refactor - quant algo
rtp-llm authored and baowending.bwd committed Apr 20, 2024
1 parent a2c17e6 commit 65097e0
Showing 5 changed files with 22 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/fastertransformer/layers/FfnLayer.cc
@@ -123,7 +123,7 @@ void FfnLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_tensors, c

print_bsd(layer_id, "moe gate", moe_gates_buf_, 1, m, expert_num_);

- if (quant_algo_.int8Mode() == 1) {
+ if (quant_algo_.int8Mode()) {

moe_plugin_->enqueue(input_tensor,
moe_gates_buf_,
@@ -189,7 +189,7 @@ void FfnLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_tensors, c
constexpr int m_padded = 0;
#endif

- const bool is_quant_mode = quant_algo_.int8Mode() == 1 || quant_algo_.int4Mode();
+ const bool is_quant_mode = quant_algo_.int8Mode() || quant_algo_.int4Mode();
const int cur_inter_size = is_quant_mode ? inter_padding_size : inter_size;
// gemm uses inter_size; int8 uses inter_padding_size
gemm_runner_->Gemm(m,
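The FfnLayer hunks above now treat int8Mode() as a plain boolean and OR it with int4Mode() to decide whether the GEMM runs with the padded intermediate size. A minimal standalone sketch of that selection, assuming a stand-in flags struct (only the two accessors exercised by these hunks are modeled; every other name here is illustrative):

#include <cstdint>

// Stand-in for the quant flags queried in FfnLayer; only the two accessors
// used in the hunks above are modeled.
struct QuantFlags {
    bool int8 = false;
    bool int4 = false;
    bool int8Mode() const { return int8; }
    bool int4Mode() const { return int4; }
};

// Mirrors the selection in FfnLayer<T>::forward: quantized weight layouts run
// the GEMM with the padded intermediate size, the plain path uses the logical one.
inline int64_t selectInterSize(const QuantFlags& q, int64_t inter_size, int64_t inter_padding_size) {
    const bool is_quant_mode = q.int8Mode() || q.int4Mode();
    return is_quant_mode ? inter_padding_size : inter_size;
}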
ParallelAttentionWrapper.cc
@@ -384,7 +384,7 @@ void ParallelAttentionWrapper<T>::DenseGemm(const int h_token_nu
qkv_buf_3_input = qkv_buf_2_;
}
print_bsd(layer_id, "attn before o", qkv_buf_3_input, h_token_num, 1, local_hidden_units_rt);
- if(params_.quant_algo_->sq_int8_){
+ if(quant_algo_.smoothQuantInt8()){
FT_CHECK_WITH_INFO(attention_weights->attention_output_weight.smoother != nullptr, "smoother is needed in sq dynamic token");
invokePerTokenQuantization(reinterpret_cast<int8_t*>(qkv_buf_), qkv_buf_2_, h_token_num, local_hidden_units_rt, dense_gemm_dynamic_scale_, attention_weights->attention_output_weight.smoother, stream_);
qkv_buf_3_input = qkv_buf_;
@@ -550,9 +550,9 @@ void ParallelAttentionWrapper<T>::SelfAttention(TensorMap* output
relative_attention_bias_stride,
input_tensors->getPtr<T>("linear_bias_slopes", nullptr),
input_tensors->getPtr<bool>("masked_tokens", nullptr),
- params_.quant_algo_->int8_mode_ == 2 ? attention_weights->query_weight.scale_out : nullptr,
- params_.quant_algo_->int8_mode_ == 2 ? attention_weights->attention_output_weight.scale : nullptr,
- params_.quant_algo_->int8_mode_,
+ nullptr,
+ nullptr,
+ 0,
multi_block_mode_,
max_seq_len_tile_,
partial_out_,
@@ -665,7 +665,7 @@ void ParallelAttentionWrapper<T>::ContextAttention(TensorMap* out
params_.logn_seq_len_,
params_.use_logn_attn_,
attention_weights->query_weight.scale_out,
- params_.quant_algo_->int8_mode_,
+ 0,
stream_);

}
@@ -889,7 +889,7 @@ void ParallelAttentionWrapper<T>::ContextAttention(TensorMap* out
local_head_num,
params_.size_per_head_,
attention_weights->attention_output_weight.scale,
- params_.quant_algo_->int8_mode_,
+ 0,
stream_);
sync_check_cuda_error();
}
@@ -903,7 +903,7 @@ void ParallelAttentionWrapper<T>::ContextAttention(TensorMap* out
params_.size_per_head_,
padding_offset,
attention_weights->attention_output_weight.scale,
- params_.quant_algo_->int8_mode_,
+ 0,
stream_);
}
POP_RANGE;
@@ -925,7 +925,7 @@ void ParallelAttentionWrapper<T>::Attention(TensorMap* output_ten
int max_context_seq_len_with_prefix = 0;
const int layer_id = input_tensors->getVal<int>("layer_id");
const int h_token_num = input_tensors->at("input_query").shape()[0];
- const float* attn_dynamic_scale = params_.quant_algo_->sq_int8_ ? input_tensors->at("attn_dynamic_scale").getPtr<float>() : nullptr;
+ const float* attn_dynamic_scale = quant_algo_.smoothQuantInt8() ? input_tensors->at("attn_dynamic_scale").getPtr<float>() : nullptr;

// lora
int* lora_ids = input_tensors->getPtr<int>("lora_ids", nullptr);
@@ -991,16 +991,14 @@ ParallelAttentionWrapper<T>::ParallelAttentionWrapper(const GptInitParameter& gp
local_hidden_units_(gpt_init_parameter.hidden_size_ / tensor_para.world_size_),
is_qk_buf_float_(is_qk_buf_float),
lora_gemm_(std::make_shared<LoraGemm<T>>(stream, allocator, cublas_wrapper)),
+ quant_algo_(quant_algo),
gemm_runner_(std::make_shared<GemmRunner<T>>(stream, allocator, cublas_wrapper, quant_algo)),
local_layer_head_num_(getLocalParameter(gpt_init_parameter.layer_head_num_, tensor_para.world_size_)),
local_layer_head_num_kv_(getLocalParameter(gpt_init_parameter.layer_head_num_kv_, tensor_para.world_size_)),
q_scaling_(gpt_init_parameter.q_scaling_),
tensor_para_(tensor_para) {
multi_block_mode_ = UseMultiBlockMode();

- if (params_.quant_algo_->int8_mode_ == 2) {
-     abort();
- }
FT_LOG_DEBUG(__PRETTY_FUNCTION__);

tensorrt_llm::kernels::Data_type data_type;
@@ -1100,7 +1098,7 @@ void ParallelAttentionWrapper<T>::allocateBuffer(
}
}

- if(params_.quant_algo_->sq_int8_){
+ if(quant_algo_.smoothQuantInt8()){
dense_gemm_dynamic_scale_ = (float*)allocator_->reMalloc(dense_gemm_dynamic_scale_, sizeof(float)*h_token_num, false);
}

@@ -1138,7 +1136,7 @@ void ParallelAttentionWrapper<T>::freeBuffer()
allocator_->free((void**)(&block_counter_));
}

- if(params_.quant_algo_->sq_int8_){
+ if(quant_algo_.smoothQuantInt8()){
allocator_->free((void**)(&dense_gemm_dynamic_scale_));
}

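The allocateBuffer()/freeBuffer() hunks above gate the per-token dynamic-scale buffer on the same smoothQuantInt8() predicate in both directions. A condensed, self-contained sketch of that symmetric pattern, assuming plain std::realloc/std::free in place of the repository's allocator_->reMalloc / allocator_->free pair:

#include <cstddef>
#include <cstdlib>

// Illustrative only: the struct and the realloc/free calls stand in for the
// allocator used in the real code.
struct SmoothQuantScales {
    float* dense_gemm_dynamic_scale = nullptr;

    void allocate(bool smooth_quant_int8, std::size_t h_token_num) {
        if (smooth_quant_int8) {
            dense_gemm_dynamic_scale = static_cast<float*>(
                std::realloc(dense_gemm_dynamic_scale, sizeof(float) * h_token_num));
        }
    }

    void release(bool smooth_quant_int8) {
        if (smooth_quant_int8) {
            std::free(dense_gemm_dynamic_scale);
            dense_gemm_dynamic_scale = nullptr;
        }
    }
};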
ParallelAttentionWrapper.h
@@ -26,6 +26,7 @@ class ParallelAttentionWrapper: public BaseAttentionLayer<T> {
bool is_qk_buf_float_;
std::shared_ptr<LoraGemm<T>> lora_gemm_;
std::shared_ptr<GemmRunner<T>> gemm_runner_;
+ tc::QuantAlgo quant_algo_;

bool multi_block_mode_ = false;
// for sparse
14 changes: 7 additions & 7 deletions src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
@@ -93,7 +93,7 @@ void ParallelGpt<T>::allocateBuffer(size_t total_batch_size, size_t h_token_num,
// only allocate additional buffers when adapters are present
decoder_layer_output_ = reinterpret_cast<T*>(
allocator_->reMalloc(decoder_layer_output_, sizeof(T) * h_token_num * hidden_units, false));
- if (params_.quant_algo_->sq_int8_) {
+ if (quant_algo_.smoothQuantInt8()) {
attention_query_dynamic_scale_ = reinterpret_cast<float*>(
allocator_->reMalloc(attention_query_dynamic_scale_, sizeof(float) * h_token_num, false));
ffn_intermediate_dynamic_scale_ = reinterpret_cast<float*>(
@@ -157,7 +157,7 @@ void ParallelGpt<T>::freeBuffer()
allocator_->free((void**)(&prefix_lengths_));
allocator_->free((void**)(&block_pointers_));
allocator_->free((void**)(&block_scale_pointers_));
- if (params_.quant_algo_->sq_int8_) {
+ if (quant_algo_.smoothQuantInt8()) {
allocator_->free((void**)(&attention_query_dynamic_scale_));
allocator_->free((void**)(&ffn_intermediate_dynamic_scale_));
}
@@ -389,7 +389,7 @@ void ParallelGpt<T>::forward(TensorMap*
}
}

- const auto activation_in_type = params_.quant_algo_->sq_int8_ ? TYPE_INT8 : data_type;
+ const auto activation_in_type = quant_algo_.smoothQuantInt8() ? TYPE_INT8 : data_type;
const auto activation_out_type = data_type;

size_t context_h_token_num = h_token_num - batch_size;
@@ -448,7 +448,7 @@ void ParallelGpt<T>::forward(TensorMap*
attention_query_dynamic_scale_,
reinterpret_cast<int8_t*>(decoder_normed_input_),
stream_);
- if (params_.quant_algo_->sq_int8_) {
+ if (quant_algo_.smoothQuantInt8()) {
print_bsd(l, "pre ln", reinterpret_cast<int8_t*>(decoder_normed_input_), 1, h_token_num, hidden_units);
} else {
print_bsd(l, "pre ln", decoder_normed_input_, 1, h_token_num, hidden_units);
@@ -502,7 +502,7 @@ void ParallelGpt<T>::forward(TensorMap*
{"lora_ids", input_tensors->at("lora_ids")},
{"lora_input_lengths", input_tensors->at("lora_input_lengths")}};

- if (params_.quant_algo_->sq_int8_) {
+ if (quant_algo_.smoothQuantInt8()) {
FT_CHECK_WITH_INFO(attention_query_dynamic_scale_!=nullptr, "attention_query_dynamic_scale_ should not be nullptr");
attention_input_tensors.insert(
"attn_dynamic_scale", Tensor{MEMORY_GPU, TYPE_FP32, {h_token_num, 1}, attention_query_dynamic_scale_});
@@ -598,7 +598,7 @@ void ParallelGpt<T>::forward(TensorMap*
int ffn_batch_size_lora = batch_size + context_batch_size;
const int* lora_input_lengths = input_tensors->getPtr<int>("lora_input_lengths", nullptr);;

- if (params_.quant_algo_->sq_int8_) {
+ if (quant_algo_.smoothQuantInt8()) {
print_bsd(l,
"before ffn",
params_.layernorm_type_ == LayerNormType::pre_layernorm ?
@@ -628,7 +628,7 @@ void ParallelGpt<T>::forward(TensorMap*
{"lora_ids", input_tensors->at("lora_ids")},
{"lora_input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {total_batch_size}, lora_input_lengths}},
{"batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {(size_t)1}, &ffn_batch_size_lora}}});
- if(params_.quant_algo_->sq_int8_){
+ if(quant_algo_.smoothQuantInt8()){
FT_CHECK_WITH_INFO(ffn_intermediate_dynamic_scale_ != nullptr, "ffn_dynamic_scale should not be nullptr");
ffn_input_tensors.insert("ffn_dynamic_scale", Tensor{MEMORY_GPU, TYPE_FP32, {h_token_num, 1}, ffn_intermediate_dynamic_scale_});
}
2 changes: 1 addition & 1 deletion src/fastertransformer/utils/quantization.h
@@ -26,7 +26,7 @@ class QuantAlgo{

public:
QuantAlgo() = default;
- QuantAlgo(int int8_mode, bool int4_mode, bool use_zeros, int64_t group_size, bool sq_int8):
+ QuantAlgo(bool int8_mode, bool int4_mode, bool use_zeros, int64_t group_size, bool sq_int8):
int8_mode_(int8_mode),
int4_mode_(int4_mode),
use_zeros_(use_zeros),
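Taken together, the refactor replaces raw-field reads through params_.quant_algo_ (sq_int8_, int8_mode_ == 1 or == 2) with accessor calls on a QuantAlgo held by value (int8Mode(), int4Mode(), smoothQuantInt8()). A minimal sketch of a QuantAlgo exposing those accessors, inferred from the constructor above and the call sites in this diff; the accessor bodies and the group_size_ member name are assumptions, not the verbatim header:

#include <cstdint>

// Minimal sketch, not the repository header: accessors matching the call
// sites touched in this commit.
class QuantAlgo {
public:
    QuantAlgo() = default;
    QuantAlgo(bool int8_mode, bool int4_mode, bool use_zeros, int64_t group_size, bool sq_int8):
        int8_mode_(int8_mode),
        int4_mode_(int4_mode),
        use_zeros_(use_zeros),
        group_size_(group_size),
        sq_int8_(sq_int8) {}

    bool int8Mode() const { return int8_mode_; }        // ORed with int4Mode() to pick padded GEMM sizes
    bool int4Mode() const { return int4_mode_; }
    bool smoothQuantInt8() const { return sq_int8_; }   // gates the per-token dynamic-scale paths

private:
    bool    int8_mode_ = false;
    bool    int4_mode_ = false;
    bool    use_zeros_ = false;
    int64_t group_size_ = 0;
    bool    sq_int8_   = false;
};

// Call sites then read, e.g.:
//   if (quant_algo_.smoothQuantInt8()) { /* allocate per-token dynamic scales */ }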
