Add KV-Cache int8 quant support #10354

Open · wants to merge 15 commits into base: main
77 changes: 70 additions & 7 deletions csrc/attention/attention_kernels.cuh
@@ -24,6 +24,10 @@

#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include <string>
#include <cstdint>
#include "dtype_fp8.cuh"
#include "../quantization/int8_kvcache/quant_utils.cuh"

#ifdef USE_ROCM
#include <hip/hip_bf16.h>
@@ -85,6 +89,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_LAYER_LEVEL,
bool IS_BLOCK_SPARSE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
@@ -105,7 +110,11 @@ __device__ void paged_attention_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float k_scale, const float v_scale,
const int quant_group,
const float* __restrict__ k_scaling_factor,
const float* __restrict__ v_scaling_factor,
const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y;
@@ -280,6 +289,26 @@ __device__ void paged_attention_kernel(
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// int8 kv-cache: dequantize the key vector before the QK dot product.
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) {
if constexpr (IS_LAYER_LEVEL) {
// Layer-level quantization: one scalar k_scale covers the whole layer.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = int8::scaled_vec_conversion_int8<K_vec, Quant_vec>(
k_vec_quant, k_scale, 0);
} else {
// Group-level quantization: look up the per-group scale for this element.
const int64_t tgt_ks_idx = floor((kv_head_idx * HEAD_SIZE) / quant_group)
+ floor((physical_block_offset * x
+ offset1 * BLOCK_SIZE * x
+ offset2) / (quant_group * BLOCK_SIZE));
float k_scale_int8 =
*reinterpret_cast<const float*>(k_scaling_factor + tgt_ks_idx);
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = int8::scaled_vec_conversion_int8<K_vec, Quant_vec>(
k_vec_quant, k_scale_int8, 0);
}
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
@@ -410,6 +439,27 @@ __device__ void paged_attention_kernel(

if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
// int8 kv-cache: dequantize the value vector before accumulation.
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) {
if constexpr (IS_LAYER_LEVEL) {
// Layer-level quantization: one scalar v_scale covers the whole layer.
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = int8::scaled_vec_conversion_int8<V_vec, V_quant_vec>(v_quant_vec,
v_scale,
0);
} else {
// Group-level quantization: look up the per-group scale for this element.
const int64_t tgt_vs_idx = floor((kv_head_idx * HEAD_SIZE) / quant_group)
+ floor(offset / (quant_group * BLOCK_SIZE));
float v_scale_int8 =
*reinterpret_cast<const float*>(v_scaling_factor + tgt_vs_idx);
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = int8::scaled_vec_conversion_int8<V_vec, V_quant_vec>(v_quant_vec,
v_scale_int8,
0);
}
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
@@ -498,6 +548,7 @@ __device__ void paged_attention_kernel(
// Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_LAYER_LEVEL,
bool IS_BLOCK_SPARSE>
__global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
@@ -513,22 +564,29 @@ __global__ void paged_attention_v1_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float k_scale, const float v_scale,
const int quant_group,
const float* __restrict__ k_scaling_factor,
const float* __restrict__ v_scaling_factor,
const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE>(
KV_DTYPE, IS_LAYER_LEVEL, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
kv_head_stride, k_scale, v_scale,
quant_group, k_scaling_factor, v_scaling_factor,
tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_LAYER_LEVEL,
bool IS_BLOCK_SPARSE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
@@ -549,14 +607,19 @@ __global__ void paged_attention_v2_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float k_scale, const float v_scale,
const int quant_group,
const float* __restrict__ k_scaling_factor,
const float* __restrict__ v_scaling_factor,
const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
KV_DTYPE, IS_LAYER_LEVEL, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
kv_block_stride, kv_head_stride, k_scale, v_scale,
quant_group, k_scaling_factor, v_scaling_factor, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
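
Note: the INT8 path above supports two scale granularities. With `IS_LAYER_LEVEL`, the existing scalar `k_scale`/`v_scale` is reused; otherwise a per-group scale is fetched from `k_scaling_factor`/`v_scaling_factor` through the `tgt_ks_idx`/`tgt_vs_idx` arithmetic. Below is a minimal sketch of that lookup under the layout implied by the index computation; the helper names are illustrative and not part of this PR.

```cpp
#include <cstdint>

// Illustrative helpers only (not part of this PR): they mirror the tgt_ks_idx
// arithmetic above, assuming k_scaling_factor stores one float scale per
// quant_group consecutive channels of each KV head.
inline float lookup_k_scale(const float* k_scaling_factor, int64_t kv_head_idx,
                            int64_t head_size, int64_t block_size, int64_t x,
                            int64_t physical_block_offset, int64_t offset1,
                            int64_t offset2, int64_t quant_group) {
  const int64_t tgt_ks_idx =
      (kv_head_idx * head_size) / quant_group +
      (physical_block_offset * x + offset1 * block_size * x + offset2) /
          (quant_group * block_size);
  return k_scaling_factor[tgt_ks_idx];
}

// Dequantization is then a per-element multiply with the looked-up scale; the
// kernel's int8::scaled_vec_conversion_int8 applies it to a whole vector.
inline float dequant_int8(int8_t q, float scale) {
  return static_cast<float>(q) * scale;
}
```
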
4 changes: 4 additions & 0 deletions csrc/attention/dtype_float16.cuh
@@ -66,6 +66,10 @@ template <>
struct FloatVec<uint4> {
using Type = Float8_;
};
template <>
struct FloatVec<uint8_t> {
using Type = float;
};

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
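
Note: the new `FloatVec<uint8_t>` specialization gives the conversion helpers a float accumulator type for a single quantized byte, following the same trait pattern as the existing half-precision mappings. A self-contained sketch of how such a trait is consumed; the trait is re-declared here purely for illustration.

```cpp
#include <cstdint>
#include <type_traits>

// Standalone illustration of the FloatVec trait pattern: map a storage type to
// the float type used when accumulating or dequantizing it.
template <typename T>
struct FloatVec {};

template <>
struct FloatVec<uint8_t> {  // a single quantized INT8 element
  using Type = float;       // dequantizes/accumulates as one float
};

template <typename T>
using float_vec_t = typename FloatVec<T>::Type;

static_assert(std::is_same_v<float_vec_t<uint8_t>, float>,
              "uint8_t storage accumulates in float");
```
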
1 change: 1 addition & 0 deletions csrc/attention/dtype_fp8.cuh
@@ -15,6 +15,7 @@ enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
kInt8 = 3,
};

// fp8 vector types for quantization of kv cache
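
Note: with `kInt8` added to `Fp8KVCacheDataType`, kernels can select the cache format at compile time, as the attention kernel above does with `if constexpr`. A compressed, self-contained sketch of that dispatch pattern follows; the enum is re-declared and the load bodies are simplified placeholders.

```cpp
#include <cstdint>

// Stand-in for vllm::Fp8KVCacheDataType (dtype_fp8.cuh) so the sketch is
// self-contained; values match the enum above.
enum class Fp8KVCacheDataType { kAuto = 0, kFp8E4M3 = 1, kFp8E5M2 = 2, kInt8 = 3 };

// Compile-time branch on the cache format, in the style of the attention
// kernel; the bodies are simplified placeholders, not the kernel's code.
template <Fp8KVCacheDataType KV_DTYPE>
float load_k_element(const void* ptr, float scale) {
  if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
    // Cache already holds the compute dtype (float here for brevity).
    return *reinterpret_cast<const float*>(ptr);
  } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) {
    // INT8 cache: dequantize with the supplied scale.
    return static_cast<float>(*reinterpret_cast<const int8_t*>(ptr)) * scale;
  } else {
    // FP8 formats would go through the fp8 conversion helpers; omitted here.
    return 0.0f;
  }
}
```
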
60 changes: 46 additions & 14 deletions csrc/attention/paged_attention_v1.cu
@@ -33,28 +33,40 @@
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, IS_BLOCK_SPARSE>), \
KV_DTYPE, \
IS_LAYER_LEVEL, \
IS_BLOCK_SPARSE>), \
shared_mem_size); \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
NUM_THREADS, KV_DTYPE, \
IS_LAYER_LEVEL, \
IS_BLOCK_SPARSE> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale, v_scale, \
quant_group, k_scaling_factor_ptr, v_scaling_factor_ptr, \
tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);

// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_LAYER_LEVEL,
bool IS_BLOCK_SPARSE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
float v_scale,
int quant_group,
torch::Tensor& k_scaling_factor,
torch::Tensor& v_scaling_factor,
const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
@@ -78,6 +90,8 @@ void paged_attention_v1_launcher(
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
float* k_scaling_factor_ptr = reinterpret_cast<float*>(k_scaling_factor.data_ptr());
float* v_scaling_factor_ptr = reinterpret_cast<float*>(v_scaling_factor.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();

@@ -131,22 +145,36 @@
}
}

#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_LAYER_LEVEL, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_LAYER_LEVEL, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, \
quant_group, k_scaling_factor, v_scaling_factor, \
tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);

#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
case true: \
switch (is_layer_level) { \
case true: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true, true); \
break; \
case false: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false, true); \
break; \
} \
break; \
case false: \
switch (is_layer_level) { \
case true: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true, false); \
break; \
case false: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false, false);\
break; \
} \
}

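Note: the expanded CALL_V1_LAUNCHER_SPARSITY macro folds two runtime booleans, `is_block_sparse` and `is_layer_level`, into compile-time template parameters, so each dtype/block-size combination now gets four kernel instantiations. A minimal macro-free sketch of the same dispatch; the launcher here is a placeholder, not the real signature.

```cpp
#include <cstdio>

// Placeholder standing in for the templated paged_attention_v1_launcher.
template <bool IS_LAYER_LEVEL, bool IS_BLOCK_SPARSE>
void launch_v1_sketch() {
  std::printf("layer_level=%d block_sparse=%d\n", IS_LAYER_LEVEL, IS_BLOCK_SPARSE);
}

// quant_group == 0 selects layer-level scalar scales; quant_group > 0 selects
// the per-group scaling-factor tensors (mirrors is_layer_level below).
void dispatch_v1_sketch(int quant_group, int blocksparse_vert_stride) {
  const bool is_layer_level = (quant_group == 0);
  const bool is_block_sparse = (blocksparse_vert_stride > 1);
  if (is_block_sparse) {
    if (is_layer_level) launch_v1_sketch<true, true>();
    else                launch_v1_sketch<false, true>();
  } else {
    if (is_layer_level) launch_v1_sketch<true, false>();
    else                launch_v1_sketch<false, false>();
  }
}
```
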
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
@@ -181,10 +209,14 @@ void paged_attention_v1(
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t quant_group,
torch::Tensor& k_scaling_factor,
torch::Tensor& v_scaling_factor,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
const bool is_layer_level = (quant_group == 0);

DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V1_LAUNCHER_BLOCK_SIZE)