diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 90a5e54736cf3..41d9e682572a6 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -1,7 +1,7 @@ import os import zipfile -MAX_SIZE_MB = 100 +MAX_SIZE_MB = 150 def print_top_10_largest_files(zip_file): diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index c04e05a994894..ce508e4748aba 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -40,5 +40,5 @@ docker run \ -e HF_TOKEN \ --name ${container_name} \ ${container_name} \ - /bin/bash -c $(echo $1 | sed "s/^'//" | sed "s/'$//") + /bin/bash -c "${@}" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e49a5650c44ea..2eeba904a209d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -48,7 +48,7 @@ steps: - pytest -v -s test_pynccl.py - label: Engine Test - mirror_hardwares: [amd] + #mirror_hardwares: [amd] command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py - label: Entrypoints Test @@ -73,13 +73,13 @@ steps: parallelism: 4 - label: Models Test - mirror_hardwares: [amd] + #mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py + - pytest -v -s models --ignore=models/test_llava.py - label: Llava Test - mirror_hardwares: [amd] + #mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - pytest -v -s models/test_llava.py @@ -101,7 +101,7 @@ steps: command: pytest -v -s worker - label: Speculative decoding tests - mirror_hardwares: [amd] + #mirror_hardwares: [amd] command: pytest -v -s spec_decode - label: LoRA Test %N diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index ea02b6b1e9c9e..174c756ae74a3 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -14,6 +14,8 @@ steps: automatic: - exit_status: -1 # Agent was lost limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 - wait - group: "AMD Tests" @@ -24,7 +26,7 @@ steps: - label: "AMD: {{ step.label }}" agents: queue: amd - command: bash .buildkite/run-amd-test.sh "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'" + command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}" env: DOCKER_BUILDKIT: "1" {% endif %} @@ -53,6 +55,8 @@ steps: automatic: - exit_status: -1 # Agent was lost limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 plugins: - kubernetes: podSpec: diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index ac60ce0fed14a..9c35ede5f6781 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -58,6 +58,9 @@ jobs: - name: Setup ccache uses: hendrikmuhs/ccache-action@v1.2 + with: + create-symlink: true + key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} - name: Set up Linux Env if: ${{ runner.os == 'Linux' }} diff --git a/CMakeLists.txt b/CMakeLists.txt index f817f3382c5e1..1c7dfe0c048b0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,7 +167,7 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" - "csrc/quantization/fp8/fp8_cuda_kernels.cu" + "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" "csrc/pybind.cpp") @@ -219,7 +219,8 @@ set(VLLM_PUNICA_EXT_SRC "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" - "csrc/punica/punica_ops.cc") + "csrc/punica/punica_ops.cu" + "csrc/punica/punica_pybind.cpp") # # Copy GPU compilation flags+update for punica @@ -243,6 +244,9 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA") endif() endforeach() message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") +elseif(${VLLM_GPU_LANG} STREQUAL "HIP") + set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES}) + message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") endif() if (VLLM_PUNICA_GPU_ARCHES) @@ -277,11 +281,6 @@ add_custom_target(default) if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") add_dependencies(default _C) -endif() - -if(VLLM_GPU_LANG STREQUAL "CUDA") - message(STATUS "Enabling moe extension.") - add_dependencies(default _moe_C) # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and @@ -292,3 +291,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") add_dependencies(default _punica_C) endif() endif() + +if(VLLM_GPU_LANG STREQUAL "CUDA") + message(STATUS "Enabling moe extension.") + add_dependencies(default _moe_C) +endif() diff --git a/Dockerfile b/Dockerfile index 90be3a30f89b1..ddca95c0e8786 100644 --- a/Dockerfile +++ b/Dockerfile @@ -87,23 +87,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip cache remove vllm_nccl* #################### EXTENSION Build IMAGE #################### -#################### FLASH_ATTENTION Build IMAGE #################### -FROM dev as flash-attn-builder -# max jobs used for build -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} -# flash attention version -ARG flash_attn_version=v2.5.8 -ENV FLASH_ATTN_VERSION=${flash_attn_version} - -WORKDIR /usr/src/flash-attention-v2 - -# Download the wheel or build it if a pre-compiled release doesn't exist -RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ - --no-build-isolation --no-deps --no-cache-dir - -#################### FLASH_ATTENTION Build IMAGE #################### - #################### vLLM installation IMAGE #################### # image with vLLM installed FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base @@ -122,10 +105,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ pip install dist/*.whl --verbose - -RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ - --mount=type=cache,target=/root/.cache/pip \ - pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir #################### vLLM installation IMAGE #################### diff --git a/Dockerfile.rocm b/Dockerfile.rocm index d04bb9915e2ab..eefad79e79d83 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -94,6 +94,9 @@ COPY . . RUN python3 -m pip install --upgrade pip numba +# make sure punica kernels are built (for LoRA) +ENV VLLM_INSTALL_PUNICA_KERNELS=1 + RUN --mount=type=cache,target=/root/.cache/pip \ pip install -U -r requirements-rocm.txt \ && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 7c71673e36f29..00c81e4d00ad8 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) "Failed to determine torch nvcc compiler flags") if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) - list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") + list(APPEND GPU_FLAGS "-DENABLE_FP8") endif() if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) list(REMOVE_ITEM GPU_FLAGS @@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) list(APPEND GPU_FLAGS "-DUSE_ROCM" - "-DENABLE_FP8_E4M3" + "-DENABLE_FP8" "-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_OPERATORS__" "-fno-gpu-rdc") diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 8b1b5e098015f..41b337dd91d36 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -19,21 +19,17 @@ #include #include #include +#include #include "attention_dtypes.h" #include "attention_utils.cuh" -#if defined(ENABLE_FP8_E5M2) -#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" -#elif defined(ENABLE_FP8_E4M3) -#include "../quantization/fp8/amd_detail/quant_utils.cuh" -#endif - -#include - #ifdef USE_ROCM #include + #include "../quantization/fp8/amd/quant_utils.cuh" typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" #endif #ifndef USE_ROCM @@ -92,7 +88,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int PARTITION_SIZE = 0> // Zero means no partitioning. __device__ void paged_attention_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -157,9 +153,7 @@ __device__ void paged_attention_kernel( constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using Quant_vec = typename Vec::Type; -#endif constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; @@ -223,21 +217,14 @@ __device__ void paged_attention_kernel( const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - if constexpr (IS_FP8_KV_CACHE) { -#if defined(ENABLE_FP8_E5M2) - Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - // Vector conversion from Quant_vec to K_vec. - k_vecs[j] = fp8_e5m2_unscaled::vec_conversion(k_vec_quant); -#elif defined(ENABLE_FP8_E4M3) - Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - // Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k - // cache vec to k vec in higher precision (FP16, BFloat16, etc.) - k_vecs[j] = fp8_e4m3::scaled_vec_conversion(k_vec_quant, kv_scale); -#else - assert(false); -#endif - } else { + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } else { + // Vector conversion from Quant_vec to K_vec. + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert(k_vec_quant, kv_scale); } } @@ -312,9 +299,7 @@ __device__ void paged_attention_kernel( constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; using L_vec = typename Vec::Type; -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using V_quant_vec = typename Vec::Type; -#endif using Float_L_vec = typename FloatVec::Type; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; @@ -348,21 +333,13 @@ __device__ void paged_attention_kernel( if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec; - if constexpr (IS_FP8_KV_CACHE) { -#if defined(ENABLE_FP8_E5M2) + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + v_vec = *reinterpret_cast(v_ptr + offset); + } else { V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. - v_vec = fp8_e5m2_unscaled::vec_conversion(v_quant_vec); -#elif defined(ENABLE_FP8_E4M3) - V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); - // Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert - // FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.) - v_vec = fp8_e4m3::scaled_vec_conversion(v_quant_vec, kv_scale); -#else - assert(false); -#endif - } else { - v_vec = *reinterpret_cast(v_ptr + offset); + v_vec = fp8::scaled_convert(v_quant_vec, kv_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, @@ -448,7 +425,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_KV_CACHE> + vllm::Fp8KVCacheDataType KV_DTYPE> __global__ void paged_attention_v1_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -464,7 +441,7 @@ __global__ void paged_attention_v1_kernel( const int kv_block_stride, const int kv_head_stride, const float kv_scale) { - paged_attention_kernel( + paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); @@ -477,7 +454,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int PARTITION_SIZE> __global__ void paged_attention_v2_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -496,7 +473,7 @@ __global__ void paged_attention_v2_kernel( const int kv_block_stride, const int kv_head_stride, const float kv_scale) { - paged_attention_kernel( + paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); @@ -606,9 +583,9 @@ __global__ void paged_attention_v2_reduce_kernel( #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ + KV_DTYPE>), shared_mem_size); \ vllm::paged_attention_v1_kernel<<>>( \ + KV_DTYPE><<>>( \ out_ptr, \ query_ptr, \ key_cache_ptr, \ @@ -629,7 +606,7 @@ template< typename T, typename CACHE_T, int BLOCK_SIZE, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128> void paged_attention_v1_launcher( torch::Tensor& out, @@ -706,36 +683,36 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v1_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + seq_lens, \ + max_seq_len, \ + alibi_slopes, \ kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v1( @@ -752,65 +729,44 @@ void paged_attention_v1( const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { - if (kv_cache_dtype == "auto") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else if (kv_cache_dtype == "fp8") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE) } -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - seq_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - seq_lens_ptr, \ +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + num_kv_heads, \ + scale, \ + block_tables_ptr, \ + seq_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride, \ + kv_scale); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + seq_lens_ptr, \ max_num_partitions); template< typename T, typename CACHE_T, int BLOCK_SIZE, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128, int PARTITION_SIZE = 512> void paged_attention_v2_launcher( @@ -897,39 +853,39 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v2_launcher( \ + out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + seq_lens, \ + max_seq_len, \ + alibi_slopes, \ kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v2( @@ -949,29 +905,7 @@ void paged_attention_v2( const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { - if (kv_cache_dtype == "auto") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else if (kv_cache_dtype == "fp8") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) } #undef WARP_SIZE diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh index d11dee91ebe87..2b32ce372a64f 100644 --- a/csrc/attention/dtype_fp8.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -3,14 +3,21 @@ #include "attention_generic.cuh" #include -#ifdef ENABLE_FP8_E5M2 +#ifdef ENABLE_FP8 +#ifndef USE_ROCM #include -#endif +#endif // USE_ROCM +#endif // ENABLE_FP8 namespace vllm { -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) -// fp8 vector types for quantization of kv cache +enum class Fp8KVCacheDataType { + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, +}; + +// fp8 vector types for quantization of kv cache template<> struct Vec { using Type = uint8_t; @@ -30,6 +37,5 @@ template<> struct Vec { using Type = uint2; }; -#endif // ENABLE_FP8_E5M2 } // namespace vllm diff --git a/csrc/cache.h b/csrc/cache.h index 4c142ce17f1b9..8c176c452425e 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -8,12 +8,12 @@ void swap_blocks( torch::Tensor& src, torch::Tensor& dst, - const std::map& block_mapping); + const torch::Tensor& block_mapping); void copy_blocks( std::vector& key_caches, std::vector& value_caches, - const std::map>& block_mapping); + const torch::Tensor& block_mapping); void reshape_and_cache( torch::Tensor& key, @@ -34,5 +34,7 @@ void reshape_and_cache_flash( // Just for unittest void convert_fp8( + torch::Tensor& dst_cache, torch::Tensor& src_cache, - torch::Tensor& dst_cache); + const float scale, + const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 42f884c76c620..e5b74da6ad068 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -4,10 +4,11 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#if defined(ENABLE_FP8_E5M2) -#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" -#elif defined(ENABLE_FP8_E4M3) -#include "quantization/fp8/amd_detail/quant_utils.cuh" + +#ifdef USE_ROCM +#include "quantization/fp8/amd/quant_utils.cuh" +#else +#include "quantization/fp8/nvidia/quant_utils.cuh" #endif #include @@ -23,7 +24,7 @@ void swap_blocks( torch::Tensor& src, torch::Tensor& dst, - const std::map& block_mapping) { + const torch::Tensor& block_mapping) { torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); cudaMemcpyKind memcpy_type; @@ -40,6 +41,11 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } + // NOTE(youkaichao): keep in mind that `block_mapping` should be + // a cpu tensor, otherwise every `item` call will require a gpu-cpu + // synchronization. + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + char *src_ptr = static_cast(src.data_ptr()); char *dst_ptr = static_cast(dst.data_ptr()); @@ -47,9 +53,10 @@ void swap_blocks( const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // NOTE(woosuk): This can be slow if the number of blocks is large. - for (const auto& pair : block_mapping) { - int64_t src_block_number = pair.first; - int64_t dst_block_number = pair.second; + const int64_t num_blocks = block_mapping.size(0); + for (size_t i = 0; i < num_blocks; i++) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); int64_t src_offset = src_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes; cudaMemcpyAsync( @@ -97,7 +104,7 @@ __global__ void copy_blocks_kernel( void copy_blocks( std::vector& key_caches, std::vector& value_caches, - const std::map>& block_mapping) { + const torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { @@ -114,17 +121,9 @@ void copy_blocks( key_cache_ptrs[layer_idx] = reinterpret_cast(key_caches[layer_idx].data_ptr()); value_cache_ptrs[layer_idx] = reinterpret_cast(value_caches[layer_idx].data_ptr()); } - // Create block mapping array. - std::vector block_mapping_vec; - for (const auto& pair : block_mapping) { - int64_t src_block_number = pair.first; - for (int64_t dst_block_number : pair.second) { - block_mapping_vec.push_back(src_block_number); - block_mapping_vec.push_back(dst_block_number); - } - } - int64_t* block_mapping_array = block_mapping_vec.data(); - int num_pairs = block_mapping_vec.size() / 2; + + // block_mapping is a 2D tensor with shape (num_pairs, 2). + int num_pairs = block_mapping.size(0); // Move the data structures to the GPU. // NOTE: This synchronizes the CPU and GPU. @@ -132,8 +131,6 @@ void copy_blocks( key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::Tensor value_cache_ptrs_tensor = torch::from_blob( value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); - torch::Tensor block_mapping_tensor = torch::from_blob( - block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device); // Launch the kernel. const int numel_per_block = key_caches[0][0].numel(); @@ -146,14 +143,14 @@ void copy_blocks( vllm::copy_blocks_kernel<<>>( key_cache_ptrs_tensor.data_ptr(), value_cache_ptrs_tensor.data_ptr(), - block_mapping_tensor.data_ptr(), + block_mapping.data_ptr(), numel_per_block); })); } namespace vllm { -template +template __global__ void reshape_and_cache_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] @@ -198,19 +195,12 @@ __global__ void reshape_and_cache_kernel( + block_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; - if constexpr (is_fp8_kv_cache) { -#if defined(ENABLE_FP8_E5M2) - key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_key); - value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_value); -#elif defined(ENABLE_FP8_E4M3) - key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion(tgt_key, kv_scale); - value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion(tgt_value, kv_scale); -#else - assert(false); -#endif - } else { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { key_cache[tgt_key_idx] = tgt_key; value_cache[tgt_value_idx] = tgt_value; + } else { + key_cache[tgt_key_idx] = fp8::scaled_convert(tgt_key, kv_scale); + value_cache[tgt_value_idx] = fp8::scaled_convert(tgt_value, kv_scale); } } } @@ -252,19 +242,22 @@ __global__ void reshape_and_cache_flash_kernel( } } // namespace vllm -#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ - vllm::reshape_and_cache_kernel<<>>( \ - reinterpret_cast(key.data_ptr()), \ - reinterpret_cast(value.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - slot_mapping.data_ptr(), \ - key_stride, \ - value_stride, \ - num_heads, \ - head_size, \ - block_size, \ - x, \ +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_kernel<<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), \ + key_stride, \ + value_stride, \ + num_heads, \ + head_size, \ + block_size, \ + x, \ kv_scale); void reshape_and_cache( @@ -289,25 +282,8 @@ void reshape_and_cache( dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (kv_cache_dtype == "auto") { - if (key.dtype() == at::ScalarType::Float) { - CALL_RESHAPE_AND_CACHE(float, float, false); - } else if (key.dtype() == at::ScalarType::Half) { - CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false); - } else if (key.dtype() == at::ScalarType::BFloat16) { - CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); - } - } else if (kv_cache_dtype == "fp8") { - if (key.dtype() == at::ScalarType::Float) { - CALL_RESHAPE_AND_CACHE(float, uint8_t, true); - } else if (key.dtype() == at::ScalarType::Half) { - CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); - } else if (key.dtype() == at::ScalarType::BFloat16) { - CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE) } void reshape_and_cache_flash( @@ -357,35 +333,34 @@ void reshape_and_cache_flash( namespace vllm { -template +template __global__ void convert_fp8_kernel( const Tin* __restrict__ src_cache, Tout* __restrict__ dst_cache, + const float kv_scale, const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; -#if defined(ENABLE_FP8_E5M2) - dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion(src_cache[idx]); -#elif defined(ENABLE_FP8_E4M3) - dst_cache[idx] = fp8_e4m3::vec_conversion(src_cache[idx]); -#else - assert(false); -#endif + dst_cache[idx] = fp8::scaled_convert(src_cache[idx], kv_scale); } } } // namespace vllm -#define CALL_CONVERT_FP8(Tout, Tin) \ - vllm::convert_fp8_kernel<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), \ +#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), \ + kv_scale, \ block_stride); +// Only for testing. void convert_fp8( + torch::Tensor& dst_cache, torch::Tensor& src_cache, - torch::Tensor& dst_cache) + const float kv_scale, + const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); @@ -403,17 +378,35 @@ void convert_fp8( dim3 block(std::min(block_stride, int64_t(512))); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (src_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(uint8_t, float); - } else if (src_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint8_t, uint16_t); - } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); - } else if (dst_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(float, uint8_t); - } else if (dst_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint16_t, uint8_t); - } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); + if (kv_cache_dtype == "auto") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } + } else { + TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); } } diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 7849a5df991b1..26e81685d623e 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -8,16 +8,16 @@ template void copy_blocks_cpu_impl( std::vector &key_caches, std::vector &value_caches, - const std::vector> mapping_pairs, + const torch::Tensor& mapping_pairs, const int element_num_per_block, const int layer_num) { - const size_t pair_num = mapping_pairs.size(); + const size_t pair_num = mapping_pairs.size(0); const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; #pragma omp parallel for collapse(2) for (int layer = 0; layer < layer_num; ++layer) { for (size_t pair = 0; pair < pair_num; ++pair) { - int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; + int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item(); int64_t target_offset = - element_num_per_block * mapping_pairs[pair].second; + element_num_per_block * mapping_pairs[pair][1].item(); scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); scalar_t *source_ptr = key_cache_ptr + source_offset; scalar_t *target_ptr = key_cache_ptr + target_offset; @@ -83,26 +83,18 @@ void reshape_and_cache_cpu_impl( void copy_blocks(std::vector &key_caches, std::vector &value_caches, - const std::map> &block_mapping) { - int num_layers = key_caches.size(); + const torch::Tensor& block_mapping) { + unsigned num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { return; } - std::vector> mapping_pairs; - mapping_pairs.reserve(block_mapping.size()); - for (const auto &pair : block_mapping) { - for (const auto &dst : pair.second) { - mapping_pairs.emplace_back(pair.first, dst); - } - } - const int element_num_per_block = key_caches[0][0].numel(); VLLM_DISPATCH_FLOATING_TYPES( key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) - copy_blocks_cpu_impl(key_caches, value_caches, mapping_pairs, + copy_blocks_cpu_impl(key_caches, value_caches, block_mapping, element_num_per_block, num_layers); CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) }); @@ -136,6 +128,6 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, } void swap_blocks(torch::Tensor &src, torch::Tensor &dst, - const std::map &block_mapping) { + const torch::Tensor&block_mapping) { TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") } diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index e9b3992204bb2..5dc1bde45ac5f 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -19,7 +19,6 @@ void rotary_embedding_impl( const int num_tokens) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); - constexpr int ELEM_SIZE = sizeof(scalar_t); const int embed_dim = rot_dim / 2; TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index c711d8d1b24b9..1ebb2e74a82fc 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -28,6 +28,12 @@ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #endif +#ifndef USE_ROCM + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta) +#else + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) +#endif + #ifndef USE_ROCM #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh index dad8805c750cb..8a3b8403b4a6f 100644 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -1,8 +1,14 @@ #pragma once #include +#ifndef USE_ROCM #include +#else +#include +#endif +#ifndef USE_ROCM #include +#endif #include #include #include @@ -11,6 +17,24 @@ namespace cg = cooperative_groups; +#ifdef USE_ROCM +template +__host__ __device__ +inline void* memcpy_blocking(void *dst, const void *src) { + // Does not handle the case of long datatypes + char *d = reinterpret_cast(dst); + const char *s = reinterpret_cast(src); + size_t i = 0; +#pragma unroll + for (i = 0; i < len; ++i) { + d[i] = s[i]; + } + return dst; +} +#endif + +#ifndef USE_ROCM + // nthrs = (32, 4) template +__global__ void +bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + if (idx < 0) { + return; + } + + size_t j = blockIdx.x; + constexpr size_t tile_size = tx * ty * vec_size; + constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size; + __shared__ float y_warpwise[ty]; + + float y = 0; + vec_t x_vec; + vec_t w_vec; + size_t tile_idx; + +#pragma unroll + for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { + x_vec.load(X + (batch_idx * feat_in) + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); + } + + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += VLLM_SHFL_DOWN_SYNC(sum, offset); + } + + __syncthreads(); + + if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { + y += sum; + } + } + + if (threadIdx.x == 0) { + y_warpwise[threadIdx.y] = y; + } + __syncthreads(); + + float y_write = 0.f; +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y_write += y_warpwise[i]; + } + + // write Y; + if (threadIdx.x == 0 && threadIdx.y == 0) { + size_t y_idx = batch_idx * full_y_size + y_offset + j; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(y_write)); + } +} + +#endif + // nthrs = (2, 16, 4) template @@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, float sum = 0.f; #pragma unroll for (size_t i = 0; i < vec_size; ++i) { +#ifndef USE_ROCM sum += float(w_vec[i]) * float(x_vec[i]) * scale; +#else + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; +#endif } cg::thread_block_tile g = cg::tiled_partition(block); @@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, sum = g.shfl(sum, 0); if (threadIdx.x == 0) { +#ifndef USE_ROCM Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + threadIdx.z * ty + threadIdx.y] += static_cast(sum); +#else + size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + + threadIdx.z * ty + threadIdx.y; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(sum)); +#endif } } @@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, scale); } } else { +#ifndef USE_ROCM static_assert(feat_in % (vec_size * 32) == 0 || feat_in % (vec_size * 16) == 0 || feat_in % (vec_size * 8) == 0); @@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, full_y_size, num_layers, layer_idx, scale); } +#else + constexpr size_t rocm_warp_size = warpSize; + +#define CHECK_INPUT_TILEABLE_BY(vec_size_) \ + feat_in % (rocm_warp_size * vec_size_) == 0 + +#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \ + if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \ + constexpr size_t vec_size_shrink = vec_size_; \ + constexpr int tx = tx_; \ + constexpr int ty = ty_; \ + dim3 nblks(feat_out, batch_size); \ + dim3 nthrs(tx, ty); \ + bgmv_shrink_kernel \ + <<>>(Y, X, W, indicies, y_offset, \ + full_y_size, num_layers, layer_idx, \ + scale); \ + } + + static_assert(CHECK_INPUT_TILEABLE_BY(32) || + CHECK_INPUT_TILEABLE_BY(16) || + CHECK_INPUT_TILEABLE_BY( 8) || + CHECK_INPUT_TILEABLE_BY( 4) || + CHECK_INPUT_TILEABLE_BY( 2) || + CHECK_INPUT_TILEABLE_BY( 1)); + + LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1) + +#undef CHECK_INPUT_TILEABLE_BY +#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM +#endif } } diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh index cf00d869cf635..2738892e6dc4a 100644 --- a/csrc/punica/bgmv/vec_dtypes.cuh +++ b/csrc/punica/bgmv/vec_dtypes.cuh @@ -1,8 +1,6 @@ #ifndef VEC_DTYPES_CUH_ #define VEC_DTYPES_CUH_ -#include -#include #ifdef FLASHINFER_USE_FP8 #include #endif @@ -10,6 +8,9 @@ #include +#include "../type_convert.h" +#include "../../cuda_compat.h" + #define FLASHINFER_INLINE \ inline __attribute__((always_inline)) __device__ __host__ diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cu similarity index 98% rename from csrc/punica/punica_ops.cc rename to csrc/punica/punica_ops.cu index 8797fde85744a..61de3b37937cc 100644 --- a/csrc/punica/punica_ops.cc +++ b/csrc/punica/punica_ops.cu @@ -1,12 +1,11 @@ -#include -#include #include #include #include +#include "type_convert.h" +#include "../cuda_compat.h" #include "bgmv/bgmv_config.h" -namespace { //====== utils ====== @@ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); } - -} // namespace - -//====== pybind ====== - -#define DEFINE_pybind(name) m.def(#name, &name, #name); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); - m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, - "dispatch_bgmv_low_level"); -} diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h new file mode 100644 index 0000000000000..937e2d1d25d4a --- /dev/null +++ b/csrc/punica/punica_ops.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, float scale); + +void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, + float scale, int64_t h_in, int64_t h_out, + int64_t y_offset); diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp new file mode 100644 index 0000000000000..9490ad59cdd5f --- /dev/null +++ b/csrc/punica/punica_pybind.cpp @@ -0,0 +1,13 @@ +#include + +#include "punica_ops.h" + +//====== pybind ====== + +#define DEFINE_pybind(name) m.def(#name, &name, #name); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); + m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, + "dispatch_bgmv_low_level"); +} diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h new file mode 100644 index 0000000000000..dff7ce49283d7 --- /dev/null +++ b/csrc/punica/type_convert.h @@ -0,0 +1,82 @@ +#ifndef CSRC__PUNICA__TYPE_CONVERT_H__ +#define CSRC__PUNICA__TYPE_CONVERT_H__ + +#ifndef USE_ROCM + +#include +#include + +#else + +#include +#include + +#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ + +typedef __half nv_half; +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { + return __hip_bfloat162{val, val}; +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { + return __hip_bfloat162{vall, valr}; +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T_dst convert_type(T_src val) { + return static_cast(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__half, float>(__half val) { + return __half2float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half convert_type(float val) { + return __float2half(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { + return __bfloat162float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 convert_type(float val) { + return __float2bfloat16(val); +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T vllm_add(T a, T b) { + return a + b; +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half vllm_add<__half>(__half a, __half b) { + return __hadd(a, b); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { + return __hadd(a, b); +} + +#undef __TYPE_CONVERT__HOST_DEVICE__ + +#endif // USE_ROCM + +#endif // CSRC__PUNICA__TYPE_CONVERT_H__ diff --git a/csrc/quantization/fp8/amd_detail/hip_float8.h b/csrc/quantization/fp8/amd/hip_float8.h similarity index 100% rename from csrc/quantization/fp8/amd_detail/hip_float8.h rename to csrc/quantization/fp8/amd/hip_float8.h diff --git a/csrc/quantization/fp8/amd_detail/hip_float8_impl.h b/csrc/quantization/fp8/amd/hip_float8_impl.h similarity index 100% rename from csrc/quantization/fp8/amd_detail/hip_float8_impl.h rename to csrc/quantization/fp8/amd/hip_float8_impl.h diff --git a/csrc/quantization/fp8/amd_detail/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh similarity index 81% rename from csrc/quantization/fp8/amd_detail/quant_utils.cuh rename to csrc/quantization/fp8/amd/quant_utils.cuh index 894160972d9f4..df0329f79d361 100644 --- a/csrc/quantization/fp8/amd_detail/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -5,12 +5,17 @@ #include #include +#include "../../../attention/dtype_fp8.cuh" #include "../../../attention/dtype_float32.cuh" #include "../../../attention/dtype_bfloat16.cuh" namespace vllm { -namespace fp8_e4m3 { +#ifdef USE_ROCM + +namespace fp8 { +#ifdef ENABLE_FP8 + template __inline__ __device__ Tout vec_conversion(const Tin& x) { @@ -512,6 +517,58 @@ __inline__ __device__ float4 scaled_vec_conversion(const uint3 float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); return res; } +#endif // ENABLE_FP8 +template +__inline__ __device__ Tout convert(const Tin &x) { +#ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x); + } +#endif + assert(false); } + +template +__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { +#ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale); + } +#endif + assert(false); +} + +// The following macro is used to dispatch the conversion function based on the +// data type of the key and value cache. The FN is a macro that calls a function +// with template. +#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } + +} // fp8 +#endif // USE_ROCM } // namespace vllm diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/common.cu similarity index 90% rename from csrc/quantization/fp8/fp8_cuda_kernels.cu rename to csrc/quantization/fp8/common.cu index 2477051eb60d7..b9c5d39277ca5 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/common.cu @@ -17,6 +17,15 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { return old; } +#define FP8_E4M3_MAX std::numeric_limits::max() + +template +__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar_t val, const float scale) { + float x = static_cast(val) / scale; + float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); + return static_cast(r); +} + // Compute the absolute maximum m of the input tensor and store // m / float8_e4m3::max() in *scale. Each thread block performs a // reduction tree and the memory in scale is atomically updated. @@ -67,7 +76,7 @@ __global__ void scaled_fp8_quant_kernel( int64_t num_elems) { int i = blockDim.x * blockIdx.x + threadIdx.x; while (i < num_elems) { - out[i] = static_cast(input[i] / *scale); + out[i] = scaled_fp8_conversion(input[i], *scale); i += blockDim.x * gridDim.x; } } diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh new file mode 100644 index 0000000000000..4eeacf7a6f9d9 --- /dev/null +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -0,0 +1,568 @@ +#pragma once + +#include "../../../attention/attention_dtypes.h" +#include +#include +#include +#include + +namespace vllm { +#ifndef USE_ROCM + +namespace fp8 { +#ifdef ENABLE_FP8 + +#if 0 // Disable the following code to reduce the binary size. +template +__inline__ __device__ Tout +vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t vec_conversion( + const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); + tmp.u16[0] = res.x; + tmp.u16[1] = res.y; + return tmp.u32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a, fp8_type); + tmp.u32[1] = + vec_conversion((uint16_t)(a >> 16U), fp8_type); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x, fp8_type); + tmp.u64[1] = vec_conversion(a.y, fp8_type); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>( + const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { + // Note there is no direct convert function from fp8 to bf16. + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type); + res.y = + vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x, fp8_type); + tmp2 = vec_conversion(a.y, fp8_type); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float +vec_conversion(const uint8_t &a, + const __nv_fp8_interpretation_t fp8_type) { + // fp8 -> half + uint16_t tmp = vec_conversion(a, fp8_type); + // half -> float + return half_to_float(tmp); +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + // fp8x2 -> half2 + uint32_t tmp = vec_conversion(a, fp8_type); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ res; + res.x = vec_conversion((uint16_t)a, fp8_type); + res.y = vec_conversion((uint16_t)(a >> 16U), fp8_type); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x, fp8_type); + tmp2 = vec_conversion(a.y, fp8_type); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + __half_raw tmp; + tmp.x = a; + __nv_fp8_storage_t res = + __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8( + __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type); + return (uint8_t)res; +#endif +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const float &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp = vec_conversion(a, fp8_type); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + +template <> +__inline__ __device__ uint32_t vec_conversion( + const float2 &a, const __nv_fp8_interpretation_t fp8_type) { + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +template <> +__inline__ __device__ uint2 vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val, fp8_type); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val, fp8_type); + + return b; +} + +template <> +__inline__ __device__ float4 vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +template <> +__inline__ __device__ uint4 vec_conversion( + const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { + uint4 b; + b.x = vec_conversion(a.x, fp8_type); + b.y = vec_conversion(a.y, fp8_type); + b.z = vec_conversion(a.z, fp8_type); + b.w = vec_conversion(a.w, fp8_type); + return b; +} + +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>( + const float2 &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 b; + from_float(b, a); + return b; +} + +template <> +__inline__ __device__ bf16_4_t vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t b; + from_float(b, a); + return b; +} + +template <> +__inline__ __device__ bf16_8_t vec_conversion( + const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_8_t b; + from_float(b, a); + return b; +} +#endif + +/* Scaled and vectorized conversions, for data exchange between high and low + precision domains Convention of the scale in API, e.g: FP8_data = + Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 + Dequant(FP8) * scale => HP + */ + +template +__inline__ __device__ Tout scaled_vec_conversion( + const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t scaled_vec_conversion( + const uint8_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type); + return float_to_half(half_to_float(tmp.x) * scale); +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t scaled_vec_conversion( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); + tmp.u16[0] = float_to_half(half_to_float(res.x) * scale); + tmp.u16[1] = float_to_half(half_to_float(res.y) * scale); + return tmp.u32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = + scaled_vec_conversion((uint16_t)a, scale, fp8_type); + tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), + scale, fp8_type); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 +scaled_vec_conversion(const uint2 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale, fp8_type); + tmp.u64[1] = scaled_vec_conversion(a.y, scale, fp8_type); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>( + const uint8_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // Note there is no direct convert function from fp8 to bf16. + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp * scale); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, + fp8_type); + res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), + scale, fp8_type); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, + fp8_type); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale, fp8_type); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t scaled_vec_conversion( + const uint2 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); + tmp2 = scaled_vec_conversion(a.y, scale, fp8_type); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float scaled_vec_conversion( + const uint8_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + uint16_t tmp = res.x; + + // half -> float + return half_to_float(tmp) * scale; +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 scaled_vec_conversion( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // fp8x2 -> half2 + uint32_t tmp = scaled_vec_conversion(a, scale, fp8_type); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale, fp8_type); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale, + fp8_type); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ scaled_vec_conversion( + const uint2 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); + tmp2 = scaled_vec_conversion(a.y, scale, fp8_type); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = + __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale, + __NV_SATFINITE, fp8_type); + return (uint8_t)res; +#endif +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const float &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = + __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp = scaled_vec_conversion(a, scale, fp8_type); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} +#endif // ENABLE_FP8 + +template +__inline__ __device__ Tout convert(const Tin &x) { +#if 0 // Disable the following code to reduce the binary size. + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x, __NV_E4M3); + } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { + return vec_conversion(x, __NV_E5M2); + } +#endif + assert(false); +} + +template +__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { +#ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale, __NV_E4M3); + } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { + return scaled_vec_conversion(x, scale, __NV_E5M2); + } +#endif + assert(false); +} + +// The following macro is used to dispatch the conversion function based on the +// data type of the key and value cache. The FN is a macro that calls a function +// with template. +#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else if (KV_DTYPE == "fp8_e5m2") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } + +} // namespace fp8 +#endif // not USE_ROCM +} // namespace vllm diff --git a/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh b/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh deleted file mode 100644 index 9bcab25db03cf..0000000000000 --- a/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh +++ /dev/null @@ -1,277 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include "../../attention/attention_dtypes.h" -#include "../../attention/dtype_float32.cuh" -#include "../../attention/dtype_float16.cuh" -#include "../../attention/dtype_bfloat16.cuh" - - -namespace vllm { -#ifdef ENABLE_FP8_E5M2 -namespace fp8_e5m2_unscaled { - -template -__inline__ __device__ Tout vec_conversion(const Tin& x) -{ - return x; -} - -// fp8 -> half -template<> -__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) -{ - __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); - return res.x; -} - -// fp8x2 -> half2 -template<> -__inline__ __device__ uint32_t vec_conversion(const uint16_t& a) -{ - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2); - tmp.u16[0] = res.x; - tmp.u16[1] = res.y; - return tmp.u32; -} - -// fp8x4 -> half2x2 -template<> -__inline__ __device__ uint2 vec_conversion(const uint32_t& a) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = vec_conversion((uint16_t)a); - tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); - return tmp.u32x2; -} - -// fp8x8 -> half2x4 -template<> -__inline__ __device__ uint4 vec_conversion(const uint2& a) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = vec_conversion(a.x); - tmp.u64[1] = vec_conversion(a.y); - return tmp.u64x2; -} - -// fp8 -> __nv_bfloat16 -template<> -__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) -{ - // Note there is no direct convert function from fp8 to bf16. - // fp8 -> half - __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); - // half -> float -> bf16 - float tmp = half_to_float(res.x); - return __float2bfloat16(tmp); -} - -// fp8x2 -> __nv_bfloat162 -template<> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) -{ - __nv_bfloat162 res; - res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); - res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); - return res; -} - -// fp8x4 -> bf16_4_t -template<> -__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) -{ - bf16_4_t res; - res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); - res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); - return res; -} - -// fp8x8 -> bf16_8_t -template<> -__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) -{ - bf16_4_t tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; -} - -// fp8 -> float -template<> -__inline__ __device__ float vec_conversion(const uint8_t& a) -{ - // fp8 -> half - uint16_t tmp = vec_conversion(a); - // half -> float - return half_to_float(tmp); -} - -// fp8x2 -> float2 -template<> -__inline__ __device__ float2 vec_conversion(const uint16_t& a) -{ - // fp8x2 -> half2 - uint32_t tmp = vec_conversion(a); - // half2 -> float2 - return half2_to_float2(tmp); -} - -// fp8x4 -> float4 -template<> -__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) -{ - Float4_ res; - res.x = vec_conversion((uint16_t)a); - res.y = vec_conversion((uint16_t)(a >> 16U)); - return res; -} - -// fp8x8 -> float8 -template<> -__inline__ __device__ Float8_ vec_conversion(const uint2& a) -{ - Float4_ tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; -} - - -// half -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) -{ - __half_raw tmp; - tmp.x = a; - __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -} - -// bf16 -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - assert(false); -#else - __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -#endif -} - -// float -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const float& a) -{ - __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -} - -// fp8x4 -> float4 -template<> -__inline__ __device__ float4 vec_conversion(const uint32_t& a) -{ - Float4_ tmp = vec_conversion(a); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; -} - - -template<> -__inline__ __device__ uint32_t vec_conversion(const float2& a) -{ - union { - half2 float16; - uint32_t uint32; - }; - - float16 = __float22half2_rn(a); - return uint32; -} - -template<> -__inline__ __device__ uint2 vec_conversion(const Float4_& a) -{ - uint2 b; - float2 val; - val.x = a.x.x; - val.y = a.x.y; - b.x = vec_conversion(val); - - val.x = a.y.x; - val.y = a.y.y; - b.y = vec_conversion(val); - - return b; -} - -template<> -__inline__ __device__ float4 vec_conversion(const Float4_& a) -{ - float4 b; - b.x = a.x.x; - b.y = a.x.y; - b.z = a.y.x; - b.w = a.y.y; - return b; -} - -template<> -__inline__ __device__ uint4 vec_conversion(const Float8_& a) -{ - uint4 b; - b.x = vec_conversion(a.x); - b.y = vec_conversion(a.y); - b.z = vec_conversion(a.z); - b.w = vec_conversion(a.w); - return b; -} - -template<> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) { - __nv_bfloat162 b; - from_float(b, a); - return b; -} - -template<> -__inline__ __device__ bf16_4_t vec_conversion(const Float4_ &a) { - bf16_4_t b; - from_float(b, a); - return b; -} - -template<> -__inline__ __device__ bf16_8_t vec_conversion(const Float8_ &a) { - bf16_8_t b; - from_float(b, a); - return b; -} - -} // namespace fp8_e5m2_unscaled -#endif // ENABLE_FP8_E5M2 -} // namespace vllm diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index fd0837f0cb39c..9c6bff000e916 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -115,7 +115,8 @@ template __device__ inline int lop3(int a, int b, int c) { return res; } -// Constructs destination register by taking bytes from 2 sources (based on mask) +// Constructs destination register by taking bytes from 2 sources (based on +// mask) template __device__ inline uint32_t prmt(uint32_t a) { uint32_t res; @@ -933,9 +934,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped partitioning - // minimizes the number of such reductions and our outputs are usually rather - // small, we perform this reduction serially in L2 cache. + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out @@ -1275,13 +1276,22 @@ typedef struct { thread_config_t tb_cfg; } exec_config_t; -thread_config_t thread_configs[] = { +thread_config_t small_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {64, 256, 256}, // Default (max cache usage) - {64, 128, 128}, // Reduce N, reduce warps - {128, 64, 128}, // Reduce N more, but increase K + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}, +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}, }; @@ -1397,11 +1407,21 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, int max_shared_mem) { int max_m_blocks = 4; while (max_m_blocks > 0) { - for (auto th_config : thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } } } @@ -1574,10 +1594,12 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, } CALL_IF(4, 32, 2, 256) CALL_IF(4, 16, 4, 256) + CALL_IF(4, 8, 8, 256) CALL_IF(4, 8, 4, 128) CALL_IF(4, 4, 8, 128) CALL_IF(8, 32, 2, 256) CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 8, 256) CALL_IF(8, 8, 4, 128) CALL_IF(8, 4, 8, 128) else { diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 0e76763a87b7c..ed569816200ee 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -10,3 +10,4 @@ pydantic torch py-cpuinfo transformers +openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args diff --git a/docs/source/models/performance.rst b/docs/source/models/performance.rst index 067757699f32a..589fce21056c2 100644 --- a/docs/source/models/performance.rst +++ b/docs/source/models/performance.rst @@ -7,7 +7,7 @@ Chunked Prefill --------------- vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests. -You can enable the feature by specifying +You can enable the feature by specifying ``--enable-chunked-prefill`` in the command line or setting ``enable_chunked_prefill=True`` in the LLM constructor. .. code-block:: python @@ -16,23 +16,29 @@ You can enable the feature by specifying # NOTE: 512 is the default max_num_batched_tokens for chunked prefill. # llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True, max_num_batched_tokens=512) -By default, vLLM scheduler prioritizes prefills and doesn't batch prefill and decode to the same batch. This policy optimizes the TTFT (time to thefirst token), but incurs slower ITL (inter token latency) and inefficient GPU utilization. +By default, vLLM scheduler prioritizes prefills and doesn't batch prefill and decode to the same batch. +This policy optimizes the TTFT (time to the first token), but incurs slower ITL (inter token latency) and inefficient GPU utilization. -Once chunked prefill is enabled, the policy is changed to +Once chunked prefill is enabled, the policy is changed to prioritize decode requests. +It batches all pending decode requests to the batch before scheduling any prefill. +When there are available token_budget (``max_num_batched_tokens``), it schedules pending prefills. +If a last pending prefill request cannot fit into ``max_num_batched_tokens``, it chunks it. -- prioritize decode requests. It batches all pending decode requests to the batch before scheduling any prefill. -- When there are available token_budget (`max_num_batched_tokens`), it schedules pending prefills. If a last pending prefill request cannot fit into `max_num_batched_tokens`, it chunks it. +This policy has two benefits: -This policy has two benefits. - -- It improves ITL (inter token latency) and generation decode because decode requests are prioritized. +- It improves ITL and generation decode because decode requests are prioritized. - It helps achieve better GPU utilization by locating compute-bound (prefill) and memory-bound (decode) requests to the same batch. -You can tune the performance by changing `max_num_batched_tokens`. -By default, it is set to 512, which has the best ITL on A100 in the initial benchmark. -Smaller batch size achieves better ITL because there are fewer prefills interrupting decodes. -Higher batch size achieves better TTFT as you can put more prefill to the batch. -If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the default scheduling policy (except that it still prioritizes decodes). -Note that the default batch size (512) is optimized for ITL, and it may have lower throughput than the default scheduler. We recommend you set `max_num_batched_tokens > 2048` for throughput. +You can tune the performance by changing ``max_num_batched_tokens``. +By default, it is set to 512, which has the best ITL on A100 in the initial benchmark (llama 70B and mixtral 8x22B). +Smaller ``max_num_batched_tokens`` achieves better ITL because there are fewer prefills interrupting decodes. +Higher ``max_num_batched_tokens`` achieves better TTFT as you can put more prefill to the batch. + +- If ``max_num_batched_tokens`` is the same as ``max_model_len``, that's almost the equivalent to the default scheduling policy (except that it still prioritizes decodes). +- Note that the default value (512) of ``max_num_batched_tokens`` is optimized for ITL, and it may have lower throughput than the default scheduler. + +We recommend you set ``max_num_batched_tokens > 2048`` for throughput. + +See related papers for more details (https://arxiv.org/pdf/2401.08671 or https://arxiv.org/pdf/2308.16369). -See related papers for more details (https://arxiv.org/pdf/2401.08671 or https://arxiv.org/pdf/2308.16369). +Please try out this feature and let us know your feedback via GitHub issues! \ No newline at end of file diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index c157d8ba998da..15a8761eb5738 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -108,5 +108,5 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) ```{argparse} :module: vllm.entrypoints.openai.cli_args :func: make_arg_parser -:prog: vllm-openai-server +:prog: -m vllm.entrypoints.openai.api_server ``` \ No newline at end of file diff --git a/examples/offline_inference_arctic.py b/examples/offline_inference_arctic.py new file mode 100644 index 0000000000000..1fec3c99eb47c --- /dev/null +++ b/examples/offline_inference_arctic.py @@ -0,0 +1,26 @@ +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="snowflake/snowflake-arctic-instruct", + quantization="deepspeedfp", + tensor_parallel_size=8, + trust_remote_code=True) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. + +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/production_monitoring/grafana.json b/examples/production_monitoring/grafana.json index 5e9bd5bd03869..273f7f5ac42cf 100644 --- a/examples/production_monitoring/grafana.json +++ b/examples/production_monitoring/grafana.json @@ -1,4 +1,41 @@ { + "__inputs": [ + { + "name": "DS_PROMETHEUS", + "label": "prometheus", + "description": "", + "type": "datasource", + "pluginId": "prometheus", + "pluginName": "Prometheus" + } + ], + "__elements": {}, + "__requires": [ + { + "type": "grafana", + "id": "grafana", + "name": "Grafana", + "version": "10.4.2" + }, + { + "type": "panel", + "id": "heatmap", + "name": "Heatmap", + "version": "" + }, + { + "type": "datasource", + "id": "prometheus", + "name": "Prometheus", + "version": "1.0.0" + }, + { + "type": "panel", + "id": "timeseries", + "name": "Time series", + "version": "" + } + ], "annotations": { "list": [ { @@ -25,14 +62,14 @@ "editable": true, "fiscalYearStartMonth": 0, "graphTooltip": 0, - "id": 29, + "id": null, "links": [], "liveNow": false, "panels": [ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "End to end request latency measured in seconds.", "fieldConfig": { @@ -41,6 +78,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -54,6 +92,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -111,7 +150,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -127,7 +166,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -144,7 +183,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -161,7 +200,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -178,7 +217,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "rate(vllm:e2e_request_latency_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:e2e_request_latency_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", @@ -195,7 +234,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Number of tokens processed per second", "fieldConfig": { @@ -204,6 +243,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -217,6 +257,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -273,7 +314,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -289,7 +330,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -310,7 +351,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Inter token latency in seconds.", "fieldConfig": { @@ -319,6 +360,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -332,6 +374,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -389,7 +432,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -405,7 +448,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -422,7 +465,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -439,7 +482,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -456,7 +499,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "rate(vllm:time_per_output_token_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:time_per_output_token_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", @@ -473,7 +516,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Number of requests in RUNNING, WAITING, and SWAPPED state", "fieldConfig": { @@ -482,6 +525,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -495,6 +539,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -552,7 +597,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -568,7 +613,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -585,7 +630,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -606,7 +651,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "P50, P90, P95, and P99 TTFT latency in seconds.", "fieldConfig": { @@ -615,6 +660,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -628,6 +674,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -685,7 +732,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -702,7 +749,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -718,7 +765,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -735,7 +782,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -752,7 +799,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "rate(vllm:time_to_first_token_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:time_to_first_token_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", @@ -769,7 +816,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Percentage of used cache blocks by vLLM.", "fieldConfig": { @@ -778,6 +825,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -791,6 +839,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -848,7 +897,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "vllm:gpu_cache_usage_perc{model_name=\"$model_name\"}", @@ -860,7 +909,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "vllm:cpu_cache_usage_perc{model_name=\"$model_name\"}", @@ -875,229 +924,232 @@ "type": "timeseries" }, { - "type": "heatmap", - "title": "Request Prompt Length", + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, "description": "Heatmap of request prompt length", + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, "gridPos": { - "x": 0, - "y": 24, + "h": 8, "w": 12, - "h": 8 - }, - "datasource": { - "uid": "prometheus", - "type": "prometheus" + "x": 0, + "y": 24 }, "id": 12, - "targets": [ - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "refId": "A", - "expr": "sum by(le) (increase(vllm:request_prompt_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", - "range": true, - "instant": false, - "editorMode": "builder", - "legendFormat": "{{le}}", - "useBackend": false, - "disableTextWrap": false, - "fullMetaSearch": false, - "includeNullMetadata": true, - "format": "heatmap" - } - ], "options": { "calculate": false, - "yAxis": { - "axisPlacement": "left", - "reverse": false, - "unit": "none", - "axisLabel": "Prompt Length" - }, - "rowsFrame": { - "layout": "auto", - "value": "Request count" + "cellGap": 1, + "cellValues": { + "unit": "none" }, "color": { - "mode": "scheme", + "exponent": 0.5, "fill": "dark-orange", + "min": 0, + "mode": "scheme", + "reverse": false, "scale": "exponential", - "exponent": 0.5, "scheme": "Spectral", - "steps": 64, - "reverse": false, - "min": 0 + "steps": 64 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" }, - "cellGap": 1, "filterValues": { "le": 1e-9 }, - "tooltip": { - "show": true, - "yHistogram": true - }, "legend": { "show": true }, - "exemplars": { - "color": "rgba(255,0,255,0.7)" + "rowsFrame": { + "layout": "auto", + "value": "Request count" }, - "cellValues": { + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": true + }, + "yAxis": { + "axisLabel": "Prompt Length", + "axisPlacement": "left", + "reverse": false, "unit": "none" } }, + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "sum by(le) (increase(vllm:request_prompt_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", + "format": "heatmap", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "{{le}}", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Request Prompt Length", + "type": "heatmap" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Heatmap of request generation length", "fieldConfig": { "defaults": { "custom": { - "scaleDistribution": { - "type": "linear" - }, "hideFrom": { + "legend": false, "tooltip": false, - "viz": false, - "legend": false + "viz": false + }, + "scaleDistribution": { + "type": "linear" } } }, "overrides": [] }, - "pluginVersion": "10.2.0" - }, - { - "datasource": { - "uid": "prometheus", - "type": "prometheus" - }, - "type": "heatmap", - "title": "Request Generation Length", - "description": "Heatmap of request generation length", "gridPos": { - "x": 12, - "y": 24, + "h": 8, "w": 12, - "h": 8 + "x": 12, + "y": 24 }, "id": 13, - "targets": [ - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "refId": "A", - "expr": "sum by(le) (increase(vllm:request_generation_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", - "range": true, - "instant": false, - "editorMode": "builder", - "legendFormat": "{{le}}", - "useBackend": false, - "disableTextWrap": false, - "fullMetaSearch": false, - "includeNullMetadata": true, - "format": "heatmap" - } - ], "options": { "calculate": false, - "yAxis": { - "axisPlacement": "left", - "reverse": false, - "unit": "none", - "axisLabel": "Generation Length" - }, - "rowsFrame": { - "layout": "auto", - "value": "Request count" + "cellGap": 1, + "cellValues": { + "unit": "none" }, "color": { - "mode": "scheme", + "exponent": 0.5, "fill": "dark-orange", + "min": 0, + "mode": "scheme", + "reverse": false, "scale": "exponential", - "exponent": 0.5, "scheme": "Spectral", - "steps": 64, - "reverse": false, - "min": 0 + "steps": 64 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" }, - "cellGap": 1, "filterValues": { "le": 1e-9 }, - "tooltip": { - "show": true, - "yHistogram": true - }, "legend": { "show": true }, - "exemplars": { - "color": "rgba(255,0,255,0.7)" + "rowsFrame": { + "layout": "auto", + "value": "Request count" }, - "cellValues": { + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": true + }, + "yAxis": { + "axisLabel": "Generation Length", + "axisPlacement": "left", + "reverse": false, "unit": "none" } }, - "fieldConfig": { - "defaults": { - "custom": { - "scaleDistribution": { - "type": "linear" - }, - "hideFrom": { - "tooltip": false, - "viz": false, - "legend": false - } - } - }, - "overrides": [] - }, - "pluginVersion": "10.2.0" + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "sum by(le) (increase(vllm:request_generation_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", + "format": "heatmap", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "{{le}}", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Request Generation Length", + "type": "heatmap" }, { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, + "description": "Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.", "fieldConfig": { "defaults": { + "color": { + "mode": "palette-classic" + }, "custom": { - "drawStyle": "line", - "lineInterpolation": "linear", + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", "barAlignment": 0, - "lineWidth": 1, + "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", - "spanNulls": false, + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, "insertNulls": false, - "showPoints": "auto", + "lineInterpolation": "linear", + "lineWidth": 1, "pointSize": 5, - "stacking": { - "mode": "none", - "group": "A" - }, - "axisPlacement": "auto", - "axisLabel": "", - "axisColorMode": "text", - "axisBorderShow": false, "scaleDistribution": { "type": "linear" }, - "axisCenteredZero": false, - "hideFrom": { - "tooltip": false, - "viz": false, - "legend": false + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" }, "thresholdsStyle": { "mode": "off" } }, - "color": { - "mode": "palette-classic" - }, "mappings": [], "thresholds": { "mode": "absolute", @@ -1123,22 +1175,22 @@ }, "id": 11, "options": { - "tooltip": { - "mode": "single", - "sort": "none" - }, "legend": { - "showLegend": true, + "calcs": [], "displayMode": "list", "placement": "bottom", - "calcs": [] + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" } }, "targets": [ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -1154,25 +1206,19 @@ } ], "title": "Finish Reason", - "description": "Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.", "type": "timeseries" } ], "refresh": "", - "schemaVersion": 37, - "style": "dark", + "schemaVersion": 39, "tags": [], "templating": { "list": [ { - "current": { - "selected": false, - "text": "vllm", - "value": "vllm" - }, + "current": {}, "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "definition": "label_values(model_name)", "hide": 0, @@ -1201,6 +1247,6 @@ "timezone": "", "title": "vLLM", "uid": "b281712d-8bff-41ef-9f3f-71ad43c05e9b", - "version": 2, + "version": 1, "weekStart": "" } diff --git a/requirements-common.txt b/requirements-common.txt index 3abb828116680..bd779d5acb68e 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -14,7 +14,7 @@ pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 prometheus-fastapi-instrumentator >= 7.0.0 tiktoken == 0.6.0 # Required for DBRX tokenizer -lm-format-enforcer == 0.9.8 +lm-format-enforcer == 0.10.1 outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 6548d7a6684b2..ba8c614d205d2 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.3.0 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 +vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0 diff --git a/setup.py b/setup.py index 3768daf9d6fab..0dc8818b44a9e 100644 --- a/setup.py +++ b/setup.py @@ -355,14 +355,18 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda(): requirements = _read_requirements("requirements-cuda.txt") - cuda_major = torch.version.cuda.split(".")[0] + cuda_major, cuda_minor = torch.version.cuda.split(".") modified_requirements = [] for req in requirements: if "vllm-nccl-cu12" in req: - modified_requirements.append( - req.replace("vllm-nccl-cu12", f"vllm-nccl-cu{cuda_major}")) - else: - modified_requirements.append(req) + req = req.replace("vllm-nccl-cu12", + f"vllm-nccl-cu{cuda_major}") + elif ("vllm-flash-attn" in req + and not (cuda_major == "12" and cuda_minor == "1")): + # vllm-flash-attn is built only for CUDA 12.1. + # Skip for other versions. + continue + modified_requirements.append(req) requirements = modified_requirements elif _is_hip(): requirements = _read_requirements("requirements-rocm.txt") @@ -381,12 +385,12 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) - if _install_punica(): - ext_modules.append(CMakeExtension(name="vllm._punica_C")) - if not _is_neuron(): ext_modules.append(CMakeExtension(name="vllm._C")) + if _install_punica(): + ext_modules.append(CMakeExtension(name="vllm._punica_C")) + package_data = { "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] } diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py index 64bcba67c3437..55b730812ea94 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/async_engine/test_chat_template.py @@ -60,13 +60,12 @@ class MockServingChat: tokenizer: MockTokenizer -@pytest.mark.asyncio -async def test_load_chat_template(): +def test_load_chat_template(): # Testing chatml template tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - await OpenAIServingChat._load_chat_template( - mock_serving_chat, chat_template=chatml_jinja_path) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=chatml_jinja_path) template_content = tokenizer.chat_template @@ -77,8 +76,7 @@ async def test_load_chat_template(): {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501 -@pytest.mark.asyncio -async def test_no_load_chat_template_filelike(): +def test_no_load_chat_template_filelike(): # Testing chatml template template = "../../examples/does_not_exist" tokenizer = MockTokenizer() @@ -86,35 +84,33 @@ async def test_no_load_chat_template_filelike(): mock_serving_chat = MockServingChat(tokenizer) with pytest.raises(ValueError, match="looks like a file path"): - await OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) -@pytest.mark.asyncio -async def test_no_load_chat_template_literallike(): +def test_no_load_chat_template_literallike(): # Testing chatml template template = "{{ messages }}" tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - await OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) template_content = tokenizer.chat_template assert template_content == template -@pytest.mark.asyncio @pytest.mark.parametrize( "model,template,add_generation_prompt,expected_output", MODEL_TEMPLATE_GENERATON_OUTPUT) -async def test_get_gen_prompt(model, template, add_generation_prompt, - expected_output): +def test_get_gen_prompt(model, template, add_generation_prompt, + expected_output): # Initialize the tokenizer tokenizer = get_tokenizer(tokenizer_name=model) mock_serving_chat = MockServingChat(tokenizer) - await OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) # Create a mock request object using keyword arguments mock_request = ChatCompletionRequest( diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index d75279dd9cfa9..7d8117447ca0a 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -3,9 +3,12 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ import os +import weakref import pytest +from vllm import LLM + MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", @@ -13,6 +16,16 @@ VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" +def test_vllm_gc_ed(): + """Verify vllm instance is GC'ed when it is deleted""" + llm = LLM("facebook/opt-125m") + weak_llm = weakref.ref(llm) + del llm + # If there's any circular reference to vllm, this fails + # because llm instance is not GC'ed. + assert weak_llm() is None + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) diff --git a/tests/conftest.py b/tests/conftest.py index 671326915b22b..1f2ad1cbd7298 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -272,6 +272,68 @@ def generate_greedy_logprobs( all_logprobs.append(seq_logprobs) return all_logprobs + def generate_greedy_logprobs_limit( + self, + prompts: List[str], + max_tokens: int, + num_logprobs: int, + ) -> List[Tuple[List[int], str]]: + all_logprobs = [] + all_output_ids = [] + all_output_strs = [] + + for prompt in prompts: + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + output = self.model.generate( + input_ids.cuda(), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + ) + + seq_logprobs = [] + for _, hidden_states in enumerate(output.hidden_states): + last_hidden_states = hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if getattr(self.model.get_output_embeddings(), "bias", + None) is not None: + logits += self.model.get_output_embeddings( + ).bias.unsqueeze(0) + logprobs = torch.nn.functional.log_softmax(logits, + dim=-1, + dtype=torch.float32) + seq_logprobs.append(logprobs) + + # convert to dict + seq_logprobs_lst = [] + for tok_idx, tok_logprobs in enumerate(seq_logprobs): + # drop prompt logprobs + if tok_idx == 0: + tok_logprobs = tok_logprobs[-1, :].reshape(1, -1) + topk = tok_logprobs.topk(num_logprobs) + + tok_logprobs_dct = {} + for token_id, logprob in zip(topk.indices[0], topk.values[0]): + tok_logprobs_dct[token_id.item()] = logprob.item() + + seq_logprobs_lst.append(tok_logprobs_dct) + + all_logprobs.append(seq_logprobs_lst) + seq_ids = output.sequences[0] + output_len = seq_ids.shape[0] - input_ids.shape[1] + output_ids = seq_ids[-output_len:] + all_output_ids.append(output_ids.tolist()) + all_output_strs.append(self.tokenizer.decode(output_ids)) + + outputs = zip(all_output_ids, all_output_strs, all_logprobs) + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] + def __del__(self): del self.model cleanup() diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py index 3481d6b4312c1..6fb95cfdfab81 100644 --- a/tests/core/block/test_block_table.py +++ b/tests/core/block/test_block_table.py @@ -410,8 +410,7 @@ def test_cow(block_size: int, sequence_len: int, append_len: int, expected_src = static_block_table.physical_block_ids[cow_block_id] expected_dst = appender_block_table.physical_block_ids[cow_block_id] - assert expected_src in cows - assert expected_dst in cows[expected_src] + assert (expected_src, expected_dst) in cows else: # Otherwise, there should be no copy-on-write. assert not cows @@ -490,8 +489,7 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int, expected_src = static_block_table.physical_block_ids[cow_block_id] expected_dst = appender_block_table.physical_block_ids[cow_block_id] - assert expected_src in cows - assert expected_dst in cows[expected_src] + assert (expected_src, expected_dst) in cows static_block_table.free() appender_block_table.free() diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 9f9a6180add78..22a9f0cf47d32 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -1,4 +1,5 @@ import time +from collections import defaultdict from typing import List import pytest @@ -141,8 +142,10 @@ def test_append_slot_cow(): child = prompt.fork(new_seq_id=2) # Allocate space for the sequence group. - seq_group = SequenceGroup("1", [prompt, child], SamplingParams(), - time.time(), time.perf_counter) + seq_group = SequenceGroup(request_id="1", + seqs=[prompt, child], + arrival_time=time.time(), + sampling_params=SamplingParams()) block_manager.allocate(seq_group) # Fork and append a new token id. We expect a COW to be scheduled. @@ -155,7 +158,10 @@ def test_append_slot_cow(): cows = block_manager.append_slots(child) assert cows - for src_block, dst_blocks in cows.items(): + dict_cows = defaultdict(list) + for src_block, dst_block in cows: + dict_cows[src_block].append(dst_block) + for src_block, dst_blocks in dict_cows.items(): assert src_block not in dst_blocks after_blocks = block_manager.get_num_free_gpu_blocks() @@ -215,7 +221,7 @@ def test_swap(): before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_out(seq_group) - assert list(mapping.keys()) == gpu_blocks + assert [x[0] for x in mapping] == gpu_blocks after_cpu_blocks = block_manager.get_num_free_cpu_blocks() after_gpu_blocks = block_manager.get_num_free_gpu_blocks() assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) @@ -228,7 +234,7 @@ def test_swap(): before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_in(seq_group) - assert list(mapping.keys()) == cpu_blocks + assert [x[0] for x in mapping] == cpu_blocks after_cpu_blocks = block_manager.get_num_free_cpu_blocks() after_gpu_blocks = block_manager.get_num_free_gpu_blocks() assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks @@ -299,8 +305,11 @@ def test_sliding_window_multi_seq(): assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks parent = Sequence(1, "one two three", [0, 1, 2], block_size) - seq_group = SequenceGroup("1", [parent], SamplingParams(), time.time(), - None) + seq_group = SequenceGroup(request_id="1", + seqs=[parent], + arrival_time=time.time(), + sampling_params=SamplingParams(), + lora_request=None) block_manager.allocate(seq_group) # assert the number of blocks allocated is correct diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 92498c0014666..3649e6b003a5d 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -355,8 +355,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 0 assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out != {} - assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out != [] + assert out.blocks_to_swap_in == [] # Add 1 more task. Swap should be prioritized over new prefill. _, seq_group = create_dummy_prompt("2", prompt_length=60) @@ -365,8 +365,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in != {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in != [] + assert out.blocks_to_swap_out == [] def test_running_prefill_prioritized_over_swap(): @@ -406,8 +406,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 0 assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out != {} - assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out != [] + assert out.blocks_to_swap_in == [] # Add 1 more task. Swap is not possible, so prefill is running. scheduler.block_manager.can_swap_in = MagicMock() @@ -419,8 +419,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in == {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in == [] + assert out.blocks_to_swap_out == [] assert out.scheduled_seq_groups[0].seq_group == seq_group2 # Now although swap is possible, running prefill is prioritized. @@ -429,8 +429,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in == {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in == [] + assert out.blocks_to_swap_out == [] assert not seq_group2.is_prefill() assert out.scheduled_seq_groups[0].seq_group == seq_group2 append_new_token(seq_group2, 1) @@ -440,8 +440,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 1 - assert out.blocks_to_swap_in == {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in == [] + assert out.blocks_to_swap_out == [] assert not seq_group2.is_prefill() assert out.scheduled_seq_groups[0].seq_group == seq_group2 append_new_token(seq_group2, 1) @@ -451,8 +451,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 1 assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in != {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in != [] + assert out.blocks_to_swap_out == [] def test_chunked_prefill_preempt(): @@ -493,8 +493,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 0 assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out == {} - assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out == [] + assert out.blocks_to_swap_in == [] # Make sure we can reschedule preempted request. _, out = schedule_and_update_computed_tokens(scheduler) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 1358dffec8104..6bcabc4f95fa9 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -293,8 +293,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 2 assert out.num_batched_tokens == 2 - assert out.blocks_to_swap_out != {} - assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out != [] + assert out.blocks_to_swap_in == [] append_new_token(out, 1) # Add 1 more task. Swap should be prioritized over prefill. @@ -305,8 +305,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 3 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 3 - assert out.blocks_to_swap_in != {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in != [] + assert out.blocks_to_swap_out == [] def initialize_scheduler(*, @@ -566,9 +566,9 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): # NOTE: When enable_chunk is False, num_seqs budget is not updated. # assert budget.num_curr_seqs == 1 # Both should be preempted, not swapped. - assert output.blocks_to_swap_out == {} + assert output.blocks_to_swap_out == [] # Nothing is copied. - assert output.blocks_to_copy == {} + assert output.blocks_to_copy == [] def test_decode_swap_beam_search(): @@ -599,7 +599,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): scheduler.block_manager.can_append_slots.side_effect = ( cannot_append_second_group) scheduler.block_manager.swap_out = MagicMock() - expected_swap_mapping = {"5": "7"} + expected_swap_mapping = [("5", "7")] scheduler.block_manager.swap_out.return_value = expected_swap_mapping remainig_running, output = scheduler._schedule_running( @@ -618,7 +618,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): # Both should be preempted, not swapped. assert output.blocks_to_swap_out == expected_swap_mapping # Nothing is copied. - assert output.blocks_to_copy == {} + assert output.blocks_to_copy == [] def test_schedule_decode_blocks_to_copy_update(): @@ -636,7 +636,7 @@ def test_schedule_decode_blocks_to_copy_update(): # The last request should be swapped out. scheduler.block_manager.append_slots = MagicMock() - scheduler.block_manager.append_slots.return_value = {2: [3]} + scheduler.block_manager.append_slots.return_value = [(2, 3)] budget = create_token_budget() remaining_running, output = scheduler._schedule_running( @@ -647,10 +647,10 @@ def test_schedule_decode_blocks_to_copy_update(): assert len(output.preempted) == 0 assert len(output.swapped_out) == 0 # Nothing is preempted. - assert output.blocks_to_swap_out == {} + assert output.blocks_to_swap_out == [] # Since append_slot returns the source -> dist mapping, it should # applied. - assert output.blocks_to_copy == {2: [3]} + assert output.blocks_to_copy == [(2, 3)] def test_schedule_swapped_simple(): @@ -658,7 +658,7 @@ def test_schedule_swapped_simple(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) @@ -674,9 +674,9 @@ def test_schedule_swapped_simple(): assert len(output.decode_seq_groups) == 1 assert len(output.prefill_seq_groups) == 0 # swap in is the reverse of swap out - blocks_to_swap_in_reverse = {} - for swapin, swapout in output.blocks_to_swap_in.items(): - blocks_to_swap_in_reverse[swapout] = swapin + blocks_to_swap_in_reverse = [] + for swapin, swapout in output.blocks_to_swap_in: + blocks_to_swap_in_reverse.append((swapout, swapin)) assert blocks_to_swap_out == blocks_to_swap_in_reverse @@ -685,7 +685,7 @@ def test_schedule_swapped_max_token_budget(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] for _ in range(2): _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) @@ -719,7 +719,7 @@ def test_schedule_swapped_max_seqs(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] for i in range(4): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) scheduler._allocate_and_set_running(seq_group) @@ -752,7 +752,7 @@ def test_schedule_swapped_max_loras(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = set() - blocks_to_swap_out = {} + blocks_to_swap_out = [] for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, @@ -781,7 +781,7 @@ def test_schedule_swapped_cannot_swap_in(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] for _ in range(2): _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) @@ -808,7 +808,7 @@ def test_infeasible_swap(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] for _ in range(2): _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) @@ -839,13 +839,13 @@ def test_schedule_swapped_blocks_to_copy(): _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) - blocks_to_swap_out = {} + blocks_to_swap_out = [] scheduler._swap_out(seq_group, blocks_to_swap_out) swapped.append(seq_group) # The last request should be swapped out. scheduler.block_manager.append_slots = MagicMock() - scheduler.block_manager.append_slots.return_value = {2: [3]} + scheduler.block_manager.append_slots.return_value = [(2, 3)] budget = create_token_budget() remaining_swapped, output = scheduler._schedule_swapped( @@ -853,7 +853,7 @@ def test_schedule_swapped_blocks_to_copy(): assert len(remaining_swapped) == 0 assert len(output.decode_seq_groups) == 1 assert len(output.prefill_seq_groups) == 0 - assert output.blocks_to_copy == {2: [3]} + assert output.blocks_to_copy == [(2, 3)] def test_scheduling_budget(): diff --git a/tests/core/utils.py b/tests/core/utils.py index 22c1d3826dff4..8fb13177a2d6c 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -22,10 +22,13 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) - seq_group = SequenceGroup( - request_id, [prompt], - SamplingParams(use_beam_search=use_beam_search, best_of=best_of), - time.time(), lora_request) + seq_group = SequenceGroup(request_id=request_id, + seqs=[prompt], + arrival_time=time.time(), + sampling_params=SamplingParams( + use_beam_search=use_beam_search, + best_of=best_of), + lora_request=lora_request) return prompt, seq_group diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index aa9e0537c6910..9a7a1f07e1b8d 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -77,14 +77,18 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, init_test_distributed_environment(1, tensor_parallel_size, rank, distributed_init_port) test_dict = { + # device tensor "a": torch.arange(8, dtype=torch.float32, device="cuda"), - "b": torch.arange(16, dtype=torch.int8, device="cuda"), + # CPU tensor + "b": torch.arange(16, dtype=torch.int8, device="cpu"), "c": "test", "d": [1, 2, 3], "e": { "a": 1, "b": 2 }, + # empty tensor + "f": torch.tensor([], dtype=torch.float32, device="cuda"), } if rank == 0: @@ -97,6 +101,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, assert recv_dict["c"] == test_dict["c"] assert recv_dict["d"] == test_dict["d"] assert recv_dict["e"] == test_dict["e"] + assert torch.allclose(recv_dict["f"], test_dict["f"]) @pytest.mark.skipif(torch.cuda.device_count() < 2, diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 3b1cd1773af19..308b874280f55 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -25,7 +25,7 @@ def graph_allreduce(world_size, rank, distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) - custom_all_reduce.init_custom_all_reduce() + custom_all_reduce.init_custom_ar() for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: with custom_all_reduce.capture(): @@ -61,7 +61,7 @@ def eager_allreduce(world_size, rank, distributed_init_port): distributed_init_port) sz = 1024 - custom_all_reduce.init_custom_all_reduce() + custom_all_reduce.init_custom_ar() fa = custom_all_reduce.get_handle() inp = torch.ones(sz, dtype=torch.float32, device=device) out = fa.all_reduce_unreg(inp) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index b6f461b76ed03..b3e30a0434423 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -1,15 +1,15 @@ import multiprocessing +import os import pytest import torch -import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils -from vllm.distributed.communication_op import tensor_model_parallel_all_reduce -from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, - ncclGetUniqueId) -from vllm.distributed.parallel_state import ( - ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group, - init_distributed_environment, with_pynccl_for_all_reduce) +from vllm.distributed.communication_op import ( # noqa + graph_capture_mode, tensor_model_parallel_all_reduce) +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment) from vllm.utils import update_environment_variables @@ -41,6 +41,9 @@ def worker_fn_wrapper(fn): # and update the environment variables in the function def wrapped_fn(env): update_environment_variables(env) + local_rank = os.environ['LOCAL_RANK'] + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) init_distributed_environment() fn() @@ -49,11 +52,13 @@ def wrapped_fn(env): @worker_fn_wrapper def worker_fn(): - comm = NCCLCommunicator() - tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) - comm.all_reduce(tensor) + pynccl_comm = PyNcclCommunicator() + tensor = torch.ones(16, 1024, 1024, + dtype=torch.float32).cuda(pynccl_comm.rank) + with pynccl_comm.change_state(enable=True): + pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() - assert result == comm.world_size + assert result == pynccl_comm.world_size @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -70,37 +75,35 @@ def multiple_tp_worker_fn(): torch.distributed.new_group(ranks=[2, 3], backend="gloo") ] group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] - comm = NCCLCommunicator(group=group, device=device) + pynccl_comm = PyNcclCommunicator(group=group, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - # two groups can communicate independently - if torch.distributed.get_rank() in [0, 1]: - comm.all_reduce(tensor) - comm.all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == 4 - else: - comm.all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == 2 + with pynccl_comm.change_state(enable=True): + # two groups can communicate independently + if torch.distributed.get_rank() in [0, 1]: + pynccl_comm.all_reduce(tensor) + pynccl_comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 4 + else: + pynccl_comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 2 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test.") def test_pynccl_multiple_tp(): # this tests pynccl for multiple tp groups, in a standalone way - # i.e. call `comm.all_reduce` directly + # i.e. call `pynccl_comm.all_reduce` directly distributed_run(multiple_tp_worker_fn, 4) @worker_fn_wrapper def multiple_tp_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") - torch.cuda.set_device(torch.distributed.get_rank()) ensure_model_parallel_initialized(2, 2) - pynccl_utils.init_process_group( - group=get_tensor_model_parallel_cpu_group()) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - with with_pynccl_for_all_reduce(): + with graph_capture_mode(): # two tp groups can communicate independently if torch.distributed.get_rank() in [0, 1]: tensor = tensor_model_parallel_all_reduce(tensor) @@ -125,19 +128,21 @@ def test_pynccl_multiple_tp_with_vllm(): def worker_fn_with_cudagraph(): with torch.no_grad(): graph = torch.cuda.CUDAGraph() - comm = NCCLCommunicator() + pynccl_comm = PyNcclCommunicator() # run something in the default stream to initialize torch engine - a = torch.ones((4, 4), device=f'cuda:{comm.rank}') + a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') torch.cuda.synchronize() - with torch.cuda.graph(graph, stream=comm.stream): + with torch.cuda.graph( + graph, stream=pynccl_comm.stream), pynccl_comm.change_state( + enable=True): # operation during the graph capture is recorded but not executed # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa - comm.all_reduce(a) - comm.stream.synchronize() - assert a.mean().cpu().item() == comm.world_size**0 + pynccl_comm.all_reduce(a) + pynccl_comm.stream.synchronize() + assert a.mean().cpu().item() == pynccl_comm.world_size**0 graph.replay() - comm.stream.synchronize() - assert a.mean().cpu().item() == comm.world_size**1 + pynccl_comm.stream.synchronize() + assert a.mean().cpu().item() == pynccl_comm.world_size**1 @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -147,7 +152,8 @@ def test_pynccl_with_cudagraph(): def test_ncclGetUniqueId(): - unique_id = ncclGetUniqueId() + lib = NCCLLibrary() + unique_id = lib.ncclGetUniqueId() # `list(unique_id.internal)` is something like this: # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0, # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 269b0823fec05..13e2e372cef33 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -20,11 +20,15 @@ class MockModelConfig: class MockEngine: async def get_model_config(self): - return MockModelConfig + return MockModelConfig() async def _async_serving_chat_init(): - serving_completion = OpenAIServingChat(MockEngine(), + engine = MockEngine() + model_config = await engine.get_model_config() + + serving_completion = OpenAIServingChat(engine, + model_config, served_model_names=[MODEL_NAME], response_role="assistant", chat_template=CHAT_TEMPLATE) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 84539205e0ae3..28496f187d466 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -236,14 +236,14 @@ def test_paged_attention( dequantized_key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - ops.convert_fp8(key_cache, dequantized_key_cache) + ops.convert_fp8(dequantized_key_cache, key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - ops.convert_fp8(value_cache, dequantized_value_cache) + ops.convert_fp8(dequantized_value_cache, value_cache) value_cache = dequantized_value_cache ref_output = torch.empty_like(query) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index ca215bb75837a..9f0cb60dc16e2 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -5,8 +5,6 @@ import torch from vllm import _custom_ops as ops -from vllm._C import cache_ops -from vllm.utils import is_hip COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -25,6 +23,8 @@ CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] + +# We assume fp8 is always enabled for testing. KV_CACHE_DTYPE = ["auto", "fp8"] @@ -63,12 +63,13 @@ def test_copy_blocks( src_blocks = random.sample(range(num_blocks), num_mappings) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) - block_mapping = {} + block_mapping = [] for i in range(num_mappings): src = src_blocks[i] dst1 = dst_blocks[2 * i] dst2 = dst_blocks[2 * i + 1] - block_mapping[src] = [dst1, dst2] + block_mapping.append((src, dst1)) + block_mapping.append((src, dst2)) # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, @@ -81,15 +82,17 @@ def test_copy_blocks( cloned_value_caches = [value_cache.clone() for value_cache in value_caches] # Call the copy blocks kernel. - ops.copy_blocks(key_caches, value_caches, block_mapping) + block_mapping_tensor = torch.tensor(block_mapping, + dtype=torch.int64, + device=device).view(-1, 2) + ops.copy_blocks(key_caches, value_caches, block_mapping_tensor) # Run the reference implementation. - for src, dsts in block_mapping.items(): - for dst in dsts: - for cloned_key_cache in cloned_key_caches: - cloned_key_cache[dst].copy_(cloned_key_cache[src]) - for cloned_value_cache in cloned_value_caches: - cloned_value_cache[dst].copy_(cloned_value_cache[src]) + for src, dst in block_mapping: + for cloned_key_cache in cloned_key_caches: + cloned_key_cache[dst].copy_(cloned_key_cache[src]) + for cloned_value_cache in cloned_value_caches: + cloned_value_cache[dst].copy_(cloned_value_cache[src]) # Compare the results. for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): @@ -121,8 +124,6 @@ def test_reshape_and_cache( device: str, kv_cache_dtype: str, ) -> None: - if not is_hip() and kv_cache_dtype == "fp8": - pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -146,9 +147,9 @@ def test_reshape_and_cache( # Clone the KV caches. if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(key_cache, cloned_key_cache) + ops.convert_fp8(cloned_key_cache, key_cache) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(value_cache, cloned_value_cache) + ops.convert_fp8(cloned_value_cache, value_cache) else: cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() @@ -162,9 +163,9 @@ def test_reshape_and_cache( if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(key_cache, result_key_cache) + ops.convert_fp8(result_key_cache, key_cache) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(value_cache, result_value_cache) + ops.convert_fp8(result_value_cache, value_cache) # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -219,11 +220,12 @@ def test_reshape_and_cache_flash( random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) + torch.set_default_device(device) # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device) qkv = torch.randn(num_tokens, 3, @@ -242,6 +244,7 @@ def test_reshape_and_cache_flash( head_size, kv_cache_dtype, dtype, + device=device, ) key_cache, value_cache = key_caches[0], value_caches[0] @@ -250,8 +253,8 @@ def test_reshape_and_cache_flash( cloned_value_cache = value_cache.clone() # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype) + ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) # Run the reference implementation. block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') @@ -294,8 +297,6 @@ def test_swap_blocks( ) -> None: if kv_cache_dtype == "fp8" and "cpu" in direction: pytest.skip() - if not is_hip() and kv_cache_dtype == "fp8": - pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -312,7 +313,10 @@ def test_swap_blocks( else: dst_blocks = random.sample(range(num_blocks), num_mappings) - block_mapping = dict(zip(src_blocks, dst_blocks)) + block_mapping = list(zip(src_blocks, dst_blocks)) + block_mapping_tensor = torch.tensor(block_mapping, + dtype=torch.int64, + device="cpu").view(-1, 2) # Create the KV caches on the first device. src_key_caches, src_value_caches = kv_cache_factory( @@ -328,17 +332,18 @@ def test_swap_blocks( src_value_caches_clone = src_value_caches[0].clone() # Call the swap_blocks kernel. - ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) - ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping) + ops.swap_blocks(src_key_caches[0], dist_key_caches[0], + block_mapping_tensor) + ops.swap_blocks(src_value_caches[0], dist_value_caches[0], + block_mapping_tensor) - for src, dst in block_mapping.items(): + for src, dst in block_mapping: assert torch.allclose(src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu()) assert torch.allclose(src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()) -@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3") @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -347,7 +352,7 @@ def test_swap_blocks( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_fp8_conversion( +def test_fp8_e4m3_conversion( num_heads: int, head_size: int, block_size: int, @@ -367,9 +372,9 @@ def test_fp8_conversion( cache.uniform_(low, high) cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) - ops.convert_fp8(cache, cache_fp8) + ops.convert_fp8(cache_fp8, cache) converted_cache = torch.empty_like(cache) - ops.convert_fp8(cache_fp8, converted_cache) + ops.convert_fp8(converted_cache, cache_fp8) assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 5a5987e2242fa..99fda8364dc0e 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -1,3 +1,4 @@ +import math import random import time @@ -6,11 +7,12 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask +from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.ops.prefix_prefill import context_attention_fwd NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 8, 64] -HEAD_SIZES = [128, 96] +HEAD_SIZES = [128, 96, 24] DTYPES = [torch.float16] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) @@ -207,3 +209,242 @@ def test_contexted_kv_attention( print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") output_ref = output_ref.reshape(output.shape) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) + + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_contexted_kv_attention_alibi( + num_heads: int, + num_queries_per_kv: int, + head_size: int, + dtype: torch.dtype, + device: str, +) -> None: + random.seed(0) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + torch.set_default_device(device) + + # Need this, otherwise when we capture the graph the process + # for GPU 1 would run on both GPU0 and GPU1 and things would hang + # + # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 + torch.cuda.set_device(device) + + def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + alibi_slopes = _get_alibi_slopes(num_heads).to(device) + + MAX_SEQ_LEN = 1024 + MAX_CTX_LEN = 1024 + BS = 10 + cache_size = 640 + block_size = 32 + max_block_per_request = 64 + query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] + seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] + num_kv_heads = num_heads // num_queries_per_kv + + num_tokens = sum(query_lens) + query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + query.uniform_(-1e-3, 1e-3) + output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + + kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) + kv.uniform_(-1e-3, 1e-3) + key, value = kv.unbind(dim=1) + + k_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + v_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + values = torch.arange(0, cache_size, dtype=torch.long) + values = values[torch.randperm(cache_size)] + block_table = values[:BS * max_block_per_request].view( + BS, max_block_per_request) + b_seq_len = torch.tensor(seq_lens, dtype=torch.long) + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + dtype=torch.long), + dim=0) + max_input_len = MAX_SEQ_LEN + # copy kv to cache + b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], + dtype=torch.long), + dim=0) + for i in range(BS): + for j in range(query_lens[i]): + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + + b_ctx_len[i] + j]) + cur_ctx = 0 + block_id = 0 + while cur_ctx < b_ctx_len[i]: + start_loc = b_seq_start_loc[i] + cur_ctx + if cur_ctx + block_size > b_ctx_len[i]: + end_loc = b_seq_start_loc[i] + b_ctx_len[i] + else: + end_loc = start_loc + block_size + start_slot = block_table[i, block_id] * block_size + end_slot = start_slot + end_loc - start_loc + k_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc]) + cur_ctx += block_size + block_id += 1 + # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] + # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] + k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, + 8).permute(0, 2, 3, 1, 4).contiguous() + # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] + # to V_cache[num_blocks, num_kv_heads, head_size, block_size] + v_cache = v_cache.view(-1, block_size, num_kv_heads, + head_size).permute(0, 2, 3, 1).contiguous() + + # Warm up the Triton kernel by calling it once before actually measuring + # generation time + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + start_time = time.time() + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + end_time = time.time() + print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + scale = float(1.0 / (head_size**0.5)) + + # NOTE(DefTruth): In order to reuse _make_alibi_bias function, + # we have to pad query tensor before MQA/GQA expanding. + if query.shape[0] != key.shape[0]: + query_pad = torch.empty(sum(seq_lens), + num_heads, + head_size, + dtype=dtype) + query_pad.uniform_(-1e-3, 1e-3) + seq_start = 0 + query_start = 0 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + query_pad[seq_start:seq_end, ...] = torch.cat([ + torch.zeros( + seq_len - query_len, num_heads, head_size, dtype=dtype), + query[query_start:query_end, ...] + ], + dim=0) + seq_start += seq_len + query_start += query_len + query = query_pad + + if num_kv_heads != num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # + # see also: vllm/model_executor/layers/attention.py + query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, + query.shape[-1]) + key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, + num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], num_kv_heads, + num_queries_per_kv, value.shape[-1]) + + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) + output_ref = torch.empty_like(output) + seq_start = 0 + query_start = 0 + start_time = time.time() + # Attention with alibi slopes. + # FIXME(DefTruth): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + # modified from: vllm/attention/backends/xformers.py#L343 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + out = xops.memory_efficient_attention_forward(query[:, + seq_start:seq_end], + key[:, + seq_start:seq_end], + value[:, + seq_start:seq_end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale) + out = out.view_as(query[:, seq_start:seq_end]).view( + seq_len, num_heads, head_size) + output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:, + ...]) + seq_start += seq_len + query_start += query_len + torch.cuda.synchronize() + end_time = time.time() + print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index 3dde498bcd639..c02204f16ac68 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -8,7 +8,7 @@ MODELS = [ "meta-llama/Llama-2-7b-hf", - # "mistralai/Mistral-7B-v0.1", # Broken + # "mistralai/Mistral-7B-v0.1", # Tested by test_mistral.py # "Deci/DeciLM-7b", # Broken # "tiiuae/falcon-7b", # Broken "EleutherAI/gpt-j-6b", diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 7aeff3a913098..33d28da85d9e7 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -4,6 +4,8 @@ """ import pytest +from tests.models.utils import check_logprobs_close + MODELS = [ "mistralai/Mistral-7B-Instruct-v0.1", ] @@ -11,30 +13,31 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.skip( - "Two problems: 1. Failing correctness tests. 2. RuntimeError: expected " - "scalar type BFloat16 but found Half (only in CI).") +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) def test_models( hf_runner, vllm_runner, - example_long_prompts, + example_prompts, model: str, dtype: str, max_tokens: int, + num_logprobs: int, ) -> None: + # TODO(sang): Sliding window should be tested separately. hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens) + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) del hf_model vllm_model = vllm_runner(model, dtype=dtype) - vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens) + vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts, + max_tokens, + num_logprobs) del vllm_model - - for i in range(len(example_long_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 13b5b80cccfdc..00a2379502e6d 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -42,9 +42,11 @@ def mock_causal_accepted_tensor( @pytest.mark.parametrize( "which_tokens_accepted", ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_correct_output_format(which_tokens_accepted: str, seed: int, +def test_correct_output_format(which_tokens_accepted: str, + disable_bonus_tokens: bool, seed: int, device: str): """Verify the output has correct format given predetermined accepted matrix. """ @@ -82,7 +84,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, size=(batch_size, 1), dtype=torch.int64) - rejection_sampler = RejectionSampler() + rejection_sampler = RejectionSampler( + disable_bonus_tokens=disable_bonus_tokens) rejection_sampler.init_gpu_tensors(rank=0) output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access accepted, @@ -91,9 +94,11 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, bonus_token_ids, ) - # Bonus tokens are currently disabled. Verify they're set to -1. + expected_bonus_token_ids = bonus_token_ids.clone() + # If bonus tokens disabled. Verify they are set to -1. # See https://github.com/vllm-project/vllm/issues/4212 - expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1 + if disable_bonus_tokens: + expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1 if which_tokens_accepted == "all_tokens_accepted": # Expect all tokens to be equal to draft tokens. diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index e4fea165a4d46..ddc66aa28a094 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -11,8 +11,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import Counter -from vllm.worker.model_runner import ModelRunner +from vllm.utils import Counter, is_pin_memory_available class MockLogitsSampler(Sampler): @@ -26,20 +25,14 @@ def forward(self, *args, **kwargs): def _prepare_test( - batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]: + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]: input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, VOCAB_SIZE), 1e-2, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) - return input_tensor, fake_logits, sampler, model_runner + return input_tensor, fake_logits, sampler VOCAB_SIZE = 32000 @@ -53,7 +46,6 @@ def _do_sample( batch_size: int, input_tensor: torch.Tensor, sampler: MockLogitsSampler, - model_runner: ModelRunner, sampling_params: SamplingParams, device: str, ): @@ -75,7 +67,7 @@ def _do_sample( seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @@ -85,19 +77,16 @@ def test_sampler_all_greedy(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams(temperature=0) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) expected = torch.argmax(fake_logits, dim=-1) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == expected[i].item() - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -105,8 +94,7 @@ def test_sampler_all_random(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -115,15 +103,13 @@ def test_sampler_all_random(seed: int, device: str): temperature=1.0, n=random.randint(1, 10), ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -131,7 +117,7 @@ def test_sampler_all_random_seed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -141,15 +127,13 @@ def test_sampler_all_random_seed(seed: int, device: str): n=random.randint(1, 10), seed=random.randint(0, 10000), ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -157,7 +141,7 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=1.0, @@ -165,15 +149,13 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): seed=random.randint(0, 10000), ) first_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params, device) + sampling_params, device) second_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params, device) + sampling_params, device) assert first_sampler_output == second_sampler_output - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -181,20 +163,18 @@ def test_sampler_all_beam(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=0, best_of=2, use_beam_search=True, ) - _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params, - device) + _do_sample(batch_size, fake_logits, sampler, sampling_params, device) # no assertion here as I am not sure how to determine whether # the outputs are expected - in other words, this just tests # whether there are no exceptions in the sampler # when handling an all-beam search case. - del model_runner @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -448,13 +428,13 @@ def run_test_case(*, ("Invalid test case, expected_penalization does not match computed" "batch size") - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens=seq_lens if seq_lens else None, query_lens=seq_lens if seq_lens else None, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) # the logits tensor is modified in-place by the sampler _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -480,8 +460,6 @@ def run_test_case(*, fake_logits[logits_idx, :] == -float('inf')) == 0, "No tokens should have been penalized" - del model_runner - for test_case in test_cases: run_test_case(**test_case) @@ -492,8 +470,7 @@ def test_sampler_mixed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler = _prepare_test(batch_size) seq_group_metadata_list = [] expected_tokens: List[Optional[List[int]]] = [] @@ -534,13 +511,13 @@ def test_sampler_mixed(seed: int, device: str): )) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - def test_sampling(model_runner: ModelRunner): + def test_sampling(): sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) sampler_output = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -570,7 +547,7 @@ def test_sampling(model_runner: ModelRunner): assert nth_output.output_token in expected_tokens[i] # Test batch - test_sampling(model_runner) + test_sampling() # Shuffle the batch and resample target_index = list(range(batch_size)) @@ -583,9 +560,7 @@ def test_sampling(model_runner: ModelRunner): # This time, results of seeded random samples will be compared with # the corresponding sample in the pre-shuffled batch - test_sampling(model_runner) - - del model_runner + test_sampling() @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -605,12 +580,6 @@ def test_sampler_top_k_top_p(seed: int, device: str): device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) generation_model = GenerationMixin() generation_config = GenerationConfig(top_k=top_k, @@ -641,7 +610,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) sample_probs = None @@ -657,5 +626,3 @@ def mock_sample(probs, *args, **kwargs): hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) assert torch.allclose(hf_probs, sample_probs, atol=1e-5) assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) - - del model_runner diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index b1ab8a07ca636..eda7293ea7cee 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -55,7 +55,7 @@ def __init__( ) -> None: if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True - self.engine_args = AsyncEngineArgs( + engine_args = AsyncEngineArgs( model=model, tokenizer=tokenizer, tokenizer_mode=tokenizer_mode, @@ -76,6 +76,8 @@ def __init__( **kwargs, ) self.request_counter = Counter() + self.llm_engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.LLM_CLASS) def generate( self, @@ -88,9 +90,6 @@ def generate( multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: - llm_engine = AsyncLLMEngine.from_engine_args( - self.engine_args, usage_context=UsageContext.LLM_CLASS) - if prompts is None: raise ValueError("prompts must be provided.") if isinstance(prompts, str): @@ -111,8 +110,8 @@ def generate( async def get_output(prompt, sampling_param) -> str: request_id = random_uuid() - results_generator = llm_engine.generate(prompt, sampling_param, - request_id) + results_generator = self.llm_engine.generate( + prompt, sampling_param, request_id) final_output = None async for request_output in results_generator: final_output = request_output @@ -185,12 +184,25 @@ def generator_outer(): return generator_outer +def maybe_assert_ngram_worker(llm): + # Verify the proposer worker is ngram if ngram is specified. + if (not isinstance(llm, AsyncLLM) + and llm.llm_engine.speculative_config is not None + and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0): + from vllm.spec_decode.ngram_worker import NGramWorker + assert isinstance( + llm.llm_engine.model_executor.driver_worker.proposer_worker, + NGramWorker) + + def get_output_from_llm_generator( llm_generator, prompts, sampling_params) -> Tuple[List[str], List[List[int]]]: tokens = [] token_ids = [] for llm in llm_generator(): + maybe_assert_ngram_worker(llm) + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) token_ids = [output.outputs[0].token_ids for output in outputs] tokens = [output.outputs[0].text for output in outputs] diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index f15fcc4746d20..d2da039e84c07 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -536,6 +536,40 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator, force_output_len=True) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-160m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_disable_by_batch_size": 2, + }, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("output_len", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_disable_speculation(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality when all sequences disable speculation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -577,3 +611,40 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, batch_size, max_output_len=output_len, force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Required for spec decode. + "use_v2_block_manager": True, + + # Verify equality when cuda graphs allowed. + "enforce_eager": False, + "model": "JackFram/llama-68m", + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Identical models. + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("output_len", [32]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator, + batch_size, output_len): + """Verify spec decode equality when cuda graphs are enabled. + """ + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ) diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 44ef400c91d34..c2004ff061a1e 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -57,7 +57,7 @@ @pytest.mark.parametrize("output_len", [ 256, ]) -@pytest.mark.parametrize("batch_size", [1, 64]) +@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) def test_ngram_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, batch_size: int, diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py new file mode 100644 index 0000000000000..948a74b22f0ae --- /dev/null +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -0,0 +1,77 @@ +from unittest.mock import MagicMock + +import pytest +import torch + +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.metrics import AsyncMetricsCollector +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker +from vllm.spec_decode.top1_proposer import Top1Proposer + +from .utils import create_batch, mock_worker + + +@pytest.mark.parametrize('queue_size', [2, 4]) +@pytest.mark.parametrize('batch_size', [1, 2, 3, 6]) +@pytest.mark.parametrize('k', [1, 2, 5, 7, 10]) +@torch.inference_mode() +def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): + """Verify that speculative tokens are disabled when the batch size + exceeds the threshold. + """ + disable_by_batch_size = 3 + + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + rejection_sampler = MagicMock(spec=RejectionSampler) + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + worker = SpecDecodeWorker(proposer_worker=draft_worker, + scorer_worker=target_worker, + rejection_sampler=rejection_sampler, + metrics_collector=metrics_collector, + disable_by_batch_size=disable_by_batch_size) + + exception_secret = 'artificial stop' + draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) + + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + running_queue_size=queue_size) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) + + # When the batch size is larger than the threshold, + # we expect no speculative tokens (0). + expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0 + assert seq_group_metadata_list[ + 0].num_speculative_tokens == expected_num_spec_tokens + + draft_worker.sampler_output.side_effect = ValueError(exception_secret) + + proposer = Top1Proposer( + worker=draft_worker, + device='cpu', # not used + vocab_size=100, # not used + # Must be long enough to avoid being skipped due to length. + max_proposal_len=1024, + ) + + if queue_size < disable_by_batch_size: + # Should raise exception when executing the mocked draft model. + with pytest.raises(ValueError, match=exception_secret): + proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) + else: + # Should not execute the draft model because spec decode is disabled + # for all requests. Accordingly, the proposal length should be 0. + proposals = proposer.get_proposals( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) + assert proposals.proposal_lens.tolist() == [0] * batch_size diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 179e8d25a341b..4ee980505a3ab 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -9,7 +9,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import ModelRunner +from vllm.utils import is_pin_memory_available class MockLogitsProcessor(LogitsProcessor): @@ -30,21 +30,15 @@ def forward(self, *args, **kwargs): def _prepare_test( - batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]: + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=input_tensor.dtype) logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) - return input_tensor, fake_logits, logits_processor, model_runner + return input_tensor, fake_logits, logits_processor RANDOM_SEEDS = list(range(128)) @@ -59,8 +53,7 @@ def test_logits_processors(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, logits_processor, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, logits_processor = _prepare_test(batch_size) # This sample logits processor gives infinite score to the i-th token, # where i is the length of the input sequence. @@ -87,8 +80,8 @@ def pick_ith(token_ids, logits): seq_group_metadata_list, seq_lens, query_lens=seq_lens, - device=model_runner.device, - pin_memory=model_runner.pin_memory) + device=device, + pin_memory=is_pin_memory_available()) logits_processor_output = logits_processor( embedding=None, hidden_states=input_tensor, @@ -99,5 +92,3 @@ def pick_ith(token_ids, logits): fake_logits *= logits_processor.scale assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1], 1e-4) - - del model_runner diff --git a/tests/test_sequence.py b/tests/test_sequence.py index b16bdc141e57c..53061278d5be4 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,36 +1,8 @@ -import time -from typing import Optional - import pytest -from vllm import SamplingParams -from vllm.lora.request import LoRARequest -from vllm.sequence import (SamplerOutput, Sequence, SequenceData, - SequenceGroup, SequenceGroupOutput, SequenceOutput) - - -def create_dummy_prompt( - request_id: str, - prompt_length: int, - block_size: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - use_beam_search: bool = False, - best_of: int = 1, -) -> SequenceGroup: - if not block_size: - block_size = prompt_length - - # Create dummy prompt sequence with tokens 0...block_size-1 - # and prompt "0 ... block_size". - prompt_tokens = list(range(prompt_length)) - prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) - seq_group = SequenceGroup( - request_id, [prompt], - SamplingParams(use_beam_search=use_beam_search, best_of=best_of), - time.time(), lora_request) - - return seq_group +from tests.core.utils import create_dummy_prompt +from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput, + SequenceOutput) @pytest.fixture @@ -102,7 +74,7 @@ def test_sequence_data_prefill(): def test_sequence_group_stage(): - seq_group = create_dummy_prompt("1", 12) + _, seq_group = create_dummy_prompt("1", 12) assert seq_group.is_prefill() is True seq_group.update_num_computed_tokens(6) assert seq_group.is_prefill() is True diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index e7975d0ef48b9..3e3d2e3f5c53d 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,27 +1,38 @@ import pytest import torch -from vllm.config import ModelConfig, SchedulerConfig from vllm.distributed.parallel_state import init_distributed_environment +from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size +def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: + engine_args = EngineArgs(model, *args, **kwargs) + engine_config = engine_args.create_engine_config() + model_runner = ModelRunner( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + load_config=engine_config.load_config, + lora_config=engine_config.lora_config, + is_driver_worker=True, + ) + return model_runner + + @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_prompt(batch_size): - scheduler_config = SchedulerConfig(100000, - 100000, - 100000, - enable_chunked_prefill=False) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=scheduler_config, - device_config=None, - load_config=None, - lora_config=None) - model_runner.set_block_size(16) + model_runner = _create_model_runner( + "facebook/opt-125m", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + ) seq_lens = [] seq_group_metadata_list = [] @@ -123,27 +134,15 @@ def test_prepare_prompt(batch_size): @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_decode_cuda_graph(batch_size): - model_config = ModelConfig( + model_runner = _create_model_runner( "facebook/opt-125m", - "facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, seed=0, dtype="float16", - revision=None, enforce_eager=False, + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, ) - scheduler_config = SchedulerConfig(100000, - 100000, - 100000, - enable_chunked_prefill=False) - model_runner = ModelRunner(model_config=model_config, - parallel_config=None, - scheduler_config=scheduler_config, - device_config=None, - load_config=None, - lora_config=None) - model_runner.set_block_size(16) seq_lens = [] seq_group_metadata_list = [] @@ -214,23 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size): def test_empty_seq_group(): """Verify prepare prompt and decode returns empty output.""" - model_config = ModelConfig( - "facebook/opt-125m", + model_runner = _create_model_runner( "facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, seed=0, dtype="float16", - revision=None, enforce_eager=False, ) - model_runner = ModelRunner(model_config=model_config, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) - model_runner.set_block_size(16) seq_group_metadata_list = [] input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( model_runner._prepare_decode(seq_group_metadata_list)) @@ -260,29 +248,15 @@ def distributed_init(): @pytest.mark.parametrize("batch_size", list(range(2, 128))) @pytest.mark.parametrize("enforce_eager", [True, False]) def test_hybrid_batches(batch_size, enforce_eager, distributed_init): - - model_config = ModelConfig( - "facebook/opt-125m", + model_runner = _create_model_runner( "facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, seed=0, dtype="float16", - revision=None, enforce_eager=enforce_eager, + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=True, ) - scheduler_config = SchedulerConfig(100000, - 100000, - 100000, - enable_chunked_prefill=True) - model_runner = ModelRunner(model_config=model_config, - parallel_config=None, - scheduler_config=scheduler_config, - device_config=None, - load_config=None, - lora_config=None, - is_driver_worker=True) - model_runner.set_block_size(16) # Add prefill requests. seq_lens = [] diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 07bcd343a96a6..d941ffdb5588a 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -54,36 +54,36 @@ def test_swap() -> None: a.cuda(), b.cuda(), rtol=0.0, atol=0.0) # Test swap out. - blocks_to_swap_out = {3: 72, 56: 35, 84: 34} + blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)] execute_model_req = ExecuteModelRequest( seq_group_metadata_list=[], - blocks_to_swap_in={}, + blocks_to_swap_in=[], blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy={}, + blocks_to_copy=[], ) worker.execute_model(execute_model_req=execute_model_req) for i in range(num_layers): gpu_key_cache, gpu_value_cache = gpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i] - for src, dst in blocks_to_swap_out.items(): + for src, dst in blocks_to_swap_out: assert allclose(gpu_key_cache[src], cpu_key_cache[dst]) assert allclose(gpu_value_cache[src], cpu_value_cache[dst]) # Test swap in. - execute_model_req.blocks_to_swap_out = {} - execute_model_req.blocks_to_swap_in = { - 19: 45, - 67: 23, - 12: 78, - 40: 99, - 1: 71 - } + execute_model_req.blocks_to_swap_out = [] + execute_model_req.blocks_to_swap_in = [ + (19, 45), + (67, 23), + (12, 78), + (40, 99), + (1, 71), + ] worker.execute_model(execute_model_req=execute_model_req) for i in range(num_layers): gpu_key_cache, gpu_value_cache = gpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i] - for src, dst in execute_model_req.blocks_to_swap_in.items(): + for src, dst in execute_model_req.blocks_to_swap_in: assert allclose(gpu_key_cache[dst], cpu_key_cache[src]) assert allclose(gpu_value_cache[dst], cpu_value_cache[src]) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 5b56437487477..42dedfdf76c4f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple import torch @@ -189,8 +189,34 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def scaled_fp8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, + batch_dim_padding: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensor for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + batch_dim_padding: If specified, pad the first dimension + of the output to at least this value. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + if batch_dim_padding: + shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) + output = torch.empty(shape, + device=input.device, + dtype=torch.float8_e4m3fn) + else: + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: scale = torch.zeros(1, device=input.device, dtype=torch.float32) vllm_ops.dynamic_scaled_fp8_quant(output, input, scale) @@ -240,12 +266,15 @@ def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, def swap_blocks(src: torch.Tensor, dst: torch.Tensor, - block_mapping: Dict[int, int]) -> None: + block_mapping: torch.Tensor) -> None: vllm_cache_ops.swap_blocks(src, dst, block_mapping) -def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None: - vllm_cache_ops.convert_fp8(output, input) +def convert_fp8(output: torch.Tensor, + input: torch.Tensor, + scale: float = 1.0, + kv_dtype: str = "fp8") -> None: + vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype) #TODO: cuda_utils, custom_ar diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 61c9c81d8a7b8..64ccb309a0480 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -9,6 +9,11 @@ class AttentionBackend(ABC): """Abstract class for attention backends.""" + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + @staticmethod @abstractmethod def get_impl_cls() -> Type["AttentionImpl"]: @@ -34,7 +39,7 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: raise NotImplementedError @@ -42,7 +47,7 @@ def swap_blocks( @abstractmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index fc7501ed5e91f..4bad226512b69 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -5,10 +5,10 @@ flashinfer for all the attention operations. """ from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch -from flash_attn import flash_attn_varlen_func +from vllm_flash_attn import flash_attn_varlen_func from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, @@ -19,6 +19,10 @@ class FlashAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "flash-attn" + @staticmethod def get_impl_cls() -> Type["FlashAttentionImpl"]: return FlashAttentionImpl @@ -41,14 +45,14 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 8ab4b1f12ee36..36e162671f944 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,16 +1,10 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple, Type -try: - import flashinfer - from flash_attn import flash_attn_varlen_func - from flashinfer import BatchDecodeWithPagedKVCacheWrapper -except ImportError: - flashinfer = None - flash_attn_varlen_func = None - BatchDecodeWithPagedKVCacheWrapper = None - +import flashinfer import torch +from flashinfer import BatchDecodeWithPagedKVCacheWrapper +from vllm_flash_attn import flash_attn_varlen_func from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -20,6 +14,10 @@ class FlashInferBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "flashinfer" + @staticmethod def get_impl_cls() -> Type["FlashInferImpl"]: return FlashInferImpl @@ -41,14 +39,14 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: raise NotImplementedError @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index c411b3971b8f1..8fc1af1aa1e1c 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,6 +1,6 @@ """Attention layer ROCm GPUs.""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch @@ -17,6 +17,10 @@ class ROCmFlashAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "rocm-flash-attn" + @staticmethod def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: return ROCmFlashAttentionImpl @@ -39,14 +43,14 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index f75a279086a26..c29218dfd0cfc 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -1,7 +1,7 @@ """ Attention layer with torch scaled_dot_product_attention and PagedAttention.""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch from torch.nn.functional import scaled_dot_product_attention @@ -15,6 +15,10 @@ class TorchSDPABackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "torch-sdpa" + @staticmethod def get_impl_cls() -> Type["TorchSDPABackendImpl"]: return TorchSDPABackendImpl @@ -37,14 +41,14 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 60f6d43f2eaa4..2a9150dea5875 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -20,6 +20,10 @@ class XFormersBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "xformers" + @staticmethod def get_impl_cls() -> Type["XFormersImpl"]: return XFormersImpl @@ -49,7 +53,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 00a0f10c0950b..3c010b67b3120 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch @@ -196,7 +196,7 @@ def forward_prefix( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] @@ -209,7 +209,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 79878b26c5294..997b25e887e30 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -472,7 +472,8 @@ def _fwd_kernel_alibi( stride_v_cache_bl, num_queries_per_kv: int, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 BLOCK_N: tl.constexpr, ): # attn_bias[] @@ -493,21 +494,24 @@ def _fwd_kernel_alibi( # initialize offsets offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd) - q = tl.load( - Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) alibi_slope = tl.load(Alibi_slopes + cur_head) alibi_start_q = tl.arange( @@ -532,8 +536,9 @@ def _fwd_kernel_alibi( offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -567,7 +572,8 @@ def _fwd_kernel_alibi( acc = acc * acc_scale[:, None] # update acc v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), other=0.0) p = p.to(v.dtype) @@ -600,8 +606,9 @@ def _fwd_kernel_alibi( # -- compute qk ---- k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < - cur_batch_seq_len - cur_batch_ctx_len, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len), other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) @@ -637,8 +644,9 @@ def _fwd_kernel_alibi( # update acc v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len), other=0.0) p = p.to(v.dtype) @@ -656,7 +664,8 @@ def _fwd_kernel_alibi( out_ptrs = Out + off_o tl.store(out_ptrs, acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) return @torch.inference_mode() @@ -690,7 +699,6 @@ def context_attention_fwd(q, num_warps = 8 if Lk <= 64 else 8 if alibi_slopes is not None: - assert Lk == Lk_padded _fwd_kernel_alibi[grid]( q, k, @@ -735,6 +743,7 @@ def context_attention_fwd(q, num_queries_per_kv=num_queries_per_kv, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 34da0f6c6cdfc..f4446bac6b8d2 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -76,11 +76,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: return _Backend.XFORMERS try: - import flash_attn # noqa: F401 + import vllm_flash_attn # noqa: F401 except ImportError: logger.info( - "Cannot use FlashAttention-2 backend because the flash_attn " - "package is not found. Please install it for better performance.") + "Cannot use FlashAttention-2 backend because the vllm_flash_attn " + "package is not found. `pip install vllm-flash-attn` for better " + "performance.") return _Backend.XFORMERS backend_by_env_var = envs.VLLM_ATTENTION_BACKEND diff --git a/vllm/config.py b/vllm/config.py index 19847192d8afa..600c53ecd82c7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -699,6 +699,7 @@ def maybe_create_spec_config( speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, use_v2_block_manager: bool, + speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], ) -> Optional["SpeculativeConfig"]: @@ -727,6 +728,9 @@ def maybe_create_spec_config( use_v2_block_manager (bool): Whether vLLM is configured to use the v2 block manager or not. Used for raising an error since the v2 block manager is required with spec decode. + speculative_disable_by_batch_size (Optional[int]): Disable + speculative decoding for new incoming requests when the number + of enqueue requests is larger than this value, if provided. ngram_prompt_lookup_max (Optional[int]): Max size of ngram token window, if provided. ngram_prompt_lookup_min (Optional[int]): Min size of ngram token @@ -737,7 +741,7 @@ def maybe_create_spec_config( the necessary conditions are met, else None. """ - if (speculative_model is None and num_speculative_tokens is None): + if speculative_model is None and num_speculative_tokens is None: return None if speculative_model is not None and num_speculative_tokens is None: @@ -746,6 +750,12 @@ def maybe_create_spec_config( "num_speculative_tokens to be provided, but found " f"{speculative_model=} and {num_speculative_tokens=}.") + if (speculative_disable_by_batch_size is not None + and speculative_disable_by_batch_size < 2): + raise ValueError("Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{speculative_disable_by_batch_size=}") + assert (speculative_model is not None and num_speculative_tokens is not None) @@ -814,6 +824,7 @@ def maybe_create_spec_config( draft_model_config, draft_parallel_config, num_speculative_tokens, + speculative_disable_by_batch_size, ngram_prompt_lookup_max, ngram_prompt_lookup_min, ) @@ -883,8 +894,9 @@ def __init__( draft_model_config: ModelConfig, draft_parallel_config: ParallelConfig, num_speculative_tokens: int, - ngram_prompt_lookup_max: int, - ngram_prompt_lookup_min: int, + speculative_disable_by_batch_size: Optional[int], + ngram_prompt_lookup_max: Optional[int], + ngram_prompt_lookup_min: Optional[int], ): """Create a SpeculativeConfig object. @@ -893,12 +905,19 @@ def __init__( draft_parallel_config: ParallelConfig for the draft model. num_speculative_tokens: The number of tokens to sample from the draft model before scoring with the target model. + speculative_disable_by_batch_size: Disable speculative + decoding for new incoming requests when the number of + enqueue requests is larger than this value. + ngram_prompt_lookup_max: Max size of ngram token window. + ngram_prompt_lookup_min: Min size of ngram token window. """ self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens - self.ngram_prompt_lookup_max = ngram_prompt_lookup_max - self.ngram_prompt_lookup_min = ngram_prompt_lookup_min + self.speculative_disable_by_batch_size = \ + speculative_disable_by_batch_size + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 self._verify_args() @@ -1051,6 +1070,7 @@ def _get_and_verify_dtype( if config_dtype == torch.float32: # Following the common practice, we use float16 for float32 # models. + logger.info("Casting torch.float32 to torch.float16.") torch_dtype = torch.float16 else: torch_dtype = config_dtype @@ -1075,9 +1095,11 @@ def _get_and_verify_dtype( if torch_dtype != config_dtype: if torch_dtype == torch.float32: # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) pass elif config_dtype == torch.float32: # Downcasting from float32 to float16 or bfloat16 is allowed. + logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) pass else: # Casting between float16 and bfloat16 is allowed with a warning. diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index 3f97a1210b096..4d7a12165cb01 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -1,5 +1,4 @@ -from collections import defaultdict -from typing import Dict, Iterable, List, Optional, Protocol +from typing import Dict, Iterable, List, Optional, Protocol, Tuple from vllm.core.block.interfaces import Block, BlockAllocator @@ -111,7 +110,7 @@ def __init__( refcounter: RefCounterProtocol, allocator: BlockAllocator, ): - self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list) + self._copy_on_writes: List[Tuple[BlockId, BlockId]] = [] self._refcounter = refcounter self._allocator = allocator @@ -152,25 +151,25 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: # Track src/dst copy. assert src_block_id is not None assert block_id is not None - self._copy_on_writes[src_block_id].append(block_id) + self._copy_on_writes.append((src_block_id, block_id)) return block_id - def clear_cows(self) -> Dict[BlockId, List[BlockId]]: + def clear_cows(self) -> List[Tuple[BlockId, BlockId]]: """Clears the copy-on-write tracking information and returns the current state. - This method returns a dictionary mapping source block indices to lists - of destination block indices for the current copy-on-write operations. + This method returns a list mapping source block indices to + destination block indices for the current copy-on-write operations. It then clears the internal tracking information. Returns: - Dict[BlockId, List[BlockId]]: A dictionary mapping source - block indices to lists of destination block indices for the + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices for the current copy-on-write operations. """ - cows = dict(self._copy_on_writes) - self._copy_on_writes.clear() + cows = self._copy_on_writes + self._copy_on_writes = [] return cows diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 5b25e1bcdada0..0577ca76ea971 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -1,4 +1,4 @@ -from typing import Dict, FrozenSet, List, Optional +from typing import Dict, FrozenSet, List, Optional, Tuple from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, DeviceAwareBlockAllocator) @@ -185,13 +185,13 @@ def get_num_free_blocks(self, device: Device) -> int: def get_num_total_blocks(self, device: Device) -> int: return self._allocators[device].get_num_total_blocks() - def clear_copy_on_writes(self) -> Dict[int, List[int]]: + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: """Clears the copy-on-write (CoW) state and returns the mapping of source to destination block IDs. Returns: - Dict[int, List[int]]: A dictionary mapping source block IDs to lists - of destination block IDs. + List[Tuple[int, int]]: A list mapping source block IDs to + destination block IDs. """ # CoW only supported on GPU device = Device.GPU diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 634c4016ca19c..140fbbb0949cc 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, FrozenSet, List, Optional, Protocol +from typing import FrozenSet, List, Optional, Protocol, Tuple from vllm.utils import Device @@ -122,7 +122,7 @@ def all_block_ids(self) -> FrozenSet[int]: pass @abstractmethod - def clear_copy_on_writes(self) -> Dict[int, List[int]]: + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: pass @abstractmethod @@ -187,7 +187,7 @@ def all_block_ids(self) -> FrozenSet[int]: pass @abstractmethod - def clear_copy_on_writes(self) -> Dict[int, List[int]]: + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: pass @abstractmethod diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index a1b901bf78efc..ae01930878254 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -1,4 +1,4 @@ -from typing import Dict, FrozenSet, Iterable, List, Optional, Set +from typing import FrozenSet, Iterable, List, Optional, Set, Tuple from vllm.core.block.common import (CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) @@ -175,12 +175,12 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: """ return self._cow_tracker.cow_block_if_not_appendable(block) - def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]: + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: """Returns the copy-on-write source->destination mapping and clears it. Returns: - Dict[BlockId, List[BlockId]]: A dictionary mapping source - block indices to lists of destination block indices. + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. """ return self._cow_tracker.clear_cows() diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 4a37e8f87c379..882f301c1f697 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,7 +1,7 @@ """Token blocks.""" from itertools import takewhile from os.path import commonprefix -from typing import Dict, FrozenSet, Iterable, List, Optional +from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple from vllm.core.block.common import (CopyOnWriteTracker, get_all_blocks_recursively) @@ -337,12 +337,12 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: """ return self._cow_tracker.cow_block_if_not_appendable(block) - def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]: + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: """Returns the copy-on-write source->destination mapping and clears it. Returns: - Dict[BlockId, List[BlockId]]: A dictionary mapping source - block indices to lists of destination block indices. + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. """ return self._cow_tracker.clear_cows() diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 268c5c135d887..52a170d79e4e7 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -5,7 +5,7 @@ from os.path import commonprefix from typing import Dict, List, Optional from typing import Sequence as GenericSequence -from typing import Set +from typing import Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor @@ -386,7 +386,7 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int = 0, - ) -> Dict[int, List[int]]: + ) -> List[Tuple[int, int]]: """Allocate a physical slot for a new token.""" logical_blocks = seq.logical_token_blocks block_table = self.block_tables[seq.seq_id] @@ -405,7 +405,7 @@ def append_slots( # Allocate a new physical block. new_block = self._allocate_last_physical_block(seq) block_table.append(new_block) - return {} + return [] # We want to append the token to the last physical block. last_block = block_table[-1] @@ -418,7 +418,7 @@ def append_slots( maybe_new_block = self._maybe_promote_last_block( seq, last_block) block_table[-1] = maybe_new_block - return {} + return [] else: # The last block is shared with other sequences. # Copy on Write: Allocate a new block and copy the tokens. @@ -426,7 +426,7 @@ def append_slots( block_table[-1] = new_block self.gpu_allocator.free(last_block) - return {last_block.block_number: [new_block.block_number]} + return [(last_block.block_number, new_block.block_number)] def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: # NOTE: fork does not allocate a new physical block. @@ -473,11 +473,12 @@ def can_swap_in(self, def swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> Dict[int, int]: + num_lookahead_slots: int = 0) -> List[Tuple[int, int]]: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" # CPU block -> GPU block. + # dict is efficient in lookup `if cpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): new_block_table: BlockTable = [] @@ -500,14 +501,16 @@ def swap_in(self, cpu_block.block_number: gpu_block.block_number for cpu_block, gpu_block in mapping.items() } - return block_number_mapping + # convert to list of tuples once here + return list(block_number_mapping.items()) def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) return len(blocks) <= self.cpu_allocator.get_num_free_blocks() - def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: # GPU block -> CPU block. + # dict is efficient in lookup `if gpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): new_block_table: BlockTable = [] @@ -530,7 +533,8 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: gpu_block.block_number: cpu_block.block_number for gpu_block, cpu_block in mapping.items() } - return block_number_mapping + # convert to list of tuples once here + return list(block_number_mapping.items()) def _free_block_table(self, block_table: BlockTable) -> None: # when using a sliding window, each seq will only use up diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index ce90ce2f17278..f0bc96564050a 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -1,6 +1,7 @@ """A block manager that manages token blocks.""" from typing import Dict, List, Optional from typing import Sequence as GenericSequence +from typing import Tuple from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator @@ -166,7 +167,7 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int, - ) -> Dict[int, List[int]]: + ) -> List[Tuple[int, int]]: block_table = self.block_tables[seq.seq_id] @@ -242,13 +243,13 @@ def can_swap_in(self, seq_group: SequenceGroup, return AllocStatus.LATER def swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> Dict[int, int]: + num_lookahead_slots: int) -> List[Tuple[int, int]]: raise NotImplementedError def can_swap_out(self, seq_group: SequenceGroup) -> bool: return False - def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: raise NotImplementedError def get_num_free_gpu_blocks(self) -> int: diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 09ccaddb62615..b2a5e41990f39 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -1,7 +1,8 @@ import enum from abc import ABC, abstractmethod -from typing import Dict, List +from typing import List from typing import Sequence as GenericSequence +from typing import Tuple from vllm.sequence import Sequence, SequenceGroup @@ -54,7 +55,7 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int, - ) -> Dict[int, List[int]]: + ) -> List[Tuple[int, int]]: pass @abstractmethod @@ -68,7 +69,7 @@ def can_swap_in(self, seq_group: SequenceGroup, @abstractmethod def swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> Dict[int, int]: + num_lookahead_slots: int) -> List[Tuple[int, int]]: pass @abstractmethod @@ -76,7 +77,7 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: pass @abstractmethod - def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index a9e0b05b8db67..35e3db18f1c43 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -13,7 +13,6 @@ from vllm.lora.request import LoRARequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) -from vllm.utils import merge_dicts logger = init_logger(__name__) @@ -118,12 +117,12 @@ class SchedulerOutputs: num_prefill_groups: int # Total number of batched tokens. num_batched_tokens: int - # Blocks to swap in. Dict of CPU -> GPU block number. - blocks_to_swap_in: Dict[int, int] - # Blocks to swap out. Dict of GPU -> CPU block number. - blocks_to_swap_out: Dict[int, int] - # Blocks to copy. Source to a list of dest blocks. - blocks_to_copy: Dict[int, List[int]] + # Blocks to swap in. List of CPU -> GPU block number. + blocks_to_swap_in: List[Tuple[int, int]] + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: List[Tuple[int, int]] + # Blocks to copy. Source to dest block. + blocks_to_copy: List[Tuple[int, int]] # Sequence groups that are going to be ignored. ignored_seq_groups: List[SequenceGroup] # The number of slots for lookahead decoding. @@ -175,9 +174,9 @@ class SchedulerRunningOutputs: # Sequences that are swapped out. swapped_out: List[SequenceGroup] # The blocks to swap out. - blocks_to_swap_out: Dict[int, int] + blocks_to_swap_out: List[Tuple[int, int]] # The blocks to copy. - blocks_to_copy: Dict[int, List[int]] + blocks_to_copy: List[Tuple[int, int]] # The number of slots for lookahead decoding. num_lookahead_slots: int @@ -188,8 +187,8 @@ def create_empty(cls) -> "SchedulerRunningOutputs": prefill_seq_groups=[], preempted=[], swapped_out=[], - blocks_to_swap_out={}, - blocks_to_copy={}, + blocks_to_swap_out=[], + blocks_to_copy=[], num_lookahead_slots=0, ) @@ -207,9 +206,9 @@ class SchedulerSwappedInOutputs: # phase. I.e., it means the prefill has been chunked. prefill_seq_groups: List[SequenceGroup] # The blocks to swap in. - blocks_to_swap_in: Dict[int, int] + blocks_to_swap_in: List[Tuple[int, int]] # The blocks to copy. - blocks_to_copy: Dict[int, List[int]] + blocks_to_copy: List[Tuple[int, int]] # The number of slots for lookahead decoding. num_lookahead_slots: int # Infeasible sequence groups. @@ -220,8 +219,8 @@ def create_empty(cls) -> "SchedulerSwappedInOutputs": return SchedulerSwappedInOutputs( decode_seq_groups=[], prefill_seq_groups=[], - blocks_to_swap_in={}, - blocks_to_copy={}, + blocks_to_swap_in=[], + blocks_to_copy=[], num_lookahead_slots=0, infeasible_seq_groups=[], ) @@ -393,8 +392,8 @@ def _schedule_running( scheduling and SchedulerRunningOutputs. """ # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_out: Dict[int, int] = {} - blocks_to_copy: Dict[int, List[int]] = {} + blocks_to_swap_out: List[Tuple[int, int]] = [] + blocks_to_copy: List[Tuple[int, int]] = [] decode_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = [] @@ -510,8 +509,8 @@ def _schedule_swapped( SchedulerSwappedInOutputs. """ # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_in: Dict[int, int] = {} - blocks_to_copy: Dict[int, List[int]] = {} + blocks_to_swap_in: List[Tuple[int, int]] = [] + blocks_to_copy: List[Tuple[int, int]] = [] decode_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = [] now = time.time() @@ -794,8 +793,8 @@ def _schedule_default(self) -> SchedulerOutputs: num_batched_tokens=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, - swapped_in.blocks_to_copy), + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups + swapped_in.infeasible_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, @@ -882,8 +881,8 @@ def _schedule_chunked_prefill(self): num_batched_tokens=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, - swapped_in.blocks_to_copy), + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, running_queue_size=len(self.running), @@ -1011,32 +1010,29 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: def _append_slots( self, seq_group: SequenceGroup, - blocks_to_copy: Dict[int, List[int]], + blocks_to_copy: List[Tuple[int, int]], ) -> None: """Appends new slots to the sequences in the given sequence group. Args: seq_group (SequenceGroup): The sequence group containing the sequences to append slots to. - blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source - block indices to lists of destination block indices. This - dictionary is updated with the new source and destination block - indices for the appended slots. + blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two + ints, the first int is the source block index, and the second + int is the destination block index. This list is updated with + the new source and destination block indices for the appended + slots. """ num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): cows = self.block_manager.append_slots(seq, num_lookahead_slots) - - for src, dests in cows.items(): - if src not in blocks_to_copy: - blocks_to_copy[src] = [] - blocks_to_copy[src].extend(dests) + blocks_to_copy.extend(cows) def _preempt( self, seq_group: SequenceGroup, - blocks_to_swap_out: Dict[int, int], + blocks_to_swap_out: List[Tuple[int, int]], preemption_mode: Optional[PreemptionMode] = None, ) -> PreemptionMode: # If preemption mode is not specified, we determine the mode as follows: @@ -1077,24 +1073,24 @@ def _preempt_by_recompute( def _preempt_by_swap( self, seq_group: SequenceGroup, - blocks_to_swap_out: Dict[int, int], + blocks_to_swap_out: List[Tuple[int, int]], ) -> None: self._swap_out(seq_group, blocks_to_swap_out) def _swap_in( self, seq_group: SequenceGroup, - blocks_to_swap_in: Dict[int, int], + blocks_to_swap_in: List[Tuple[int, int]], ) -> None: mapping = self.block_manager.swap_in(seq_group) - blocks_to_swap_in.update(mapping) + blocks_to_swap_in.extend(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): seq.status = SequenceStatus.RUNNING def _swap_out( self, seq_group: SequenceGroup, - blocks_to_swap_out: Dict[int, int], + blocks_to_swap_out: List[Tuple[int, int]], ) -> None: if not self.block_manager.can_swap_out(seq_group): # FIXME(woosuk): Abort the sequence group instead of aborting the @@ -1103,7 +1099,7 @@ def _swap_out( "Aborted due to the lack of CPU swap space. Please increase " "the swap space to avoid this error.") mapping = self.block_manager.swap_out(seq_group) - blocks_to_swap_out.update(mapping) + blocks_to_swap_out.extend(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq.status = SequenceStatus.SWAPPED diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index b539a7beedbfe..32ab5694e5390 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,4 +1,5 @@ from collections import namedtuple +from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -8,7 +9,26 @@ get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - is_pynccl_enabled_for_all_reduce) + get_tp_pynccl_communicator) + + +@contextmanager +def graph_capture_mode(): + # In graph capture, we have to be very careful about the collective + # operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the tensor size + # is too large, it will fallback to the next available option. + pynccl_comm = get_tp_pynccl_communicator() + assert pynccl_comm is not None + with pynccl_comm.change_state(enable=True, + stream=torch.cuda.current_stream()): + yield def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: @@ -23,7 +43,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: TLDR: always assume this function modifies its input, but use the return value as the output. """ - from vllm.distributed.device_communicators import pynccl_utils from vllm.distributed.device_communicators.custom_all_reduce import ( custom_all_reduce) @@ -33,8 +52,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: out = custom_all_reduce(input_) if out is not None: return out - if is_pynccl_enabled_for_all_reduce(): - pynccl_utils.all_reduce(input_) + pynccl_comm = get_tp_pynccl_communicator() + if (pynccl_comm is not None and not pynccl_comm.disabled): + pynccl_comm.all_reduce(input_) else: torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) @@ -137,7 +157,7 @@ def broadcast_object_list(obj_list: List[Any], return obj_list -TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"]) +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) def _split_tensor_dict( @@ -152,15 +172,13 @@ def _split_tensor_dict( tensor_list = [] for key, value in tensor_dict.items(): if isinstance(value, torch.Tensor): - # Note(youkaichao): currently this only supports broadcasting - # tensors on cuda. In the future, we can add device as a field in - # TensorMetadata to support broadcasting tensors on different - # devices. - assert value.is_cuda, ( - f"Tensor {key}: {value} is not on cuda. Currently we only " - f"support broadcasting tensors on cuda.") - metadata_list.append((key, TensorMetadata(value.dtype, - value.size()))) + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = "cpu" if value.is_cpu else "cuda" + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size()))) tensor_list.append(value) else: metadata_list.append((key, value)) @@ -203,11 +221,22 @@ def broadcast_tensor_dict( group=metadata_group) async_handles = [] for tensor in tensor_list: - async_handles.append( - torch.distributed.broadcast(tensor, - src=src, - group=group, - async_op=True)) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True) + async_handles.append(handle) for async_handle in async_handles: async_handle.wait() @@ -223,12 +252,24 @@ def broadcast_tensor_dict( if isinstance(value, TensorMetadata): tensor = torch.empty(value.size, dtype=value.dtype, - device="cuda") - async_handle = torch.distributed.broadcast(tensor, - src=src, - async_op=True, - group=group) - async_handles.append(async_handle) + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True) + async_handles.append(handle) tensor_dict[key] = tensor else: tensor_dict[key] = value diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index cc5f8166877ce..5d26254fb832a 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -52,6 +52,10 @@ def init_custom_ar() -> None: "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" " is set.") return + + # we only use a subset of GPUs here + # so we only need to check the nvlink connectivity of these GPUs + num_dev = world_size # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 758994352e3de..168d4cc2df8a6 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -1,26 +1,4 @@ -# This file is a pure Python wrapper for the NCCL library. -# The main purpose is to use NCCL combined with CUDA graph. -# Before writing this script, we tried the following approach: -# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself -# often gets stuck when initializing the NCCL communicator. -# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` -# contains many other potential cuda APIs, that are not allowed during -# capturing the CUDA graph. For further details, please check -# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . -# -# Another rejected idea is to write a C/C++ binding for NCCL. It is usually -# doable, but we often encounter issues related with nccl versions, and need -# to switch between different versions of NCCL. See -# https://github.com/NVIDIA/nccl/issues/1234 for more details. -# A C/C++ binding is not flexible enough to handle this. It requires -# recompilation of the code every time we want to switch between different -# versions. This current implementation, with a **pure** Python wrapper, is -# more flexible. We can easily switch between different versions of NCCL by -# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` -# variable in the code. - -import ctypes -import platform +from contextlib import contextmanager from typing import Optional, Union # ===================== import region ===================== @@ -28,217 +6,70 @@ import torch.distributed as dist from torch.distributed import ProcessGroup, ReduceOp +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, + ncclRedOpTypeEnum, ncclUniqueId) from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank from vllm.logger import init_logger -from vllm.utils import find_nccl_library, nccl_integrity_check logger = init_logger(__name__) -so_file = find_nccl_library() - -try: - # load the library in another process. - # if it core dumps, it will not crash the current process - nccl_integrity_check(so_file) - nccl = ctypes.CDLL(so_file) -except Exception as e: - logger.error( - "Failed to load NCCL library from %s ." - "It is expected if you are not running on NVIDIA/AMD GPUs." - "Otherwise, the nccl library might not exist, be corrupted " - "or it does not support the current platform %s." - "One solution is to download libnccl2 version 2.18 from " - "https://developer.download.nvidia.com/compute/cuda/repos/ " - "and extract the libnccl.so.2 file. If you already have the " - "library, please set the environment variable VLLM_NCCL_SO_PATH" - " to point to the correct nccl library path.", so_file, - platform.platform()) - raise e - -# === export types and functions from nccl to Python === -# for the original nccl definition, please check -# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in - -ncclResult_t = ctypes.c_int - -_c_ncclGetErrorString = nccl.ncclGetErrorString -_c_ncclGetErrorString.restype = ctypes.c_char_p -_c_ncclGetErrorString.argtypes = [ncclResult_t] - - -def NCCL_CHECK(result: ncclResult_t) -> None: - if result != 0: - error_str = _c_ncclGetErrorString(result) - error_str = error_str.decode("utf-8") - raise RuntimeError(f"NCCL error: {error_str}") - - -# equivalent to c declaration: -# ncclResult_t ncclGetVersion(int *version); -_c_ncclGetVersion = nccl.ncclGetVersion -_c_ncclGetVersion.restype = ctypes.c_int -_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)] - - -def ncclGetVersion() -> str: - version = ctypes.c_int() - NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version))) - # something like 21903 --> "2.19.3" - version_str = str(version.value) - major = version_str[0].lstrip("0") - minor = version_str[1:3].lstrip("0") - patch = version_str[3:].lstrip("0") - return f"{major}.{minor}.{patch}" - - -class NcclUniqueId(ctypes.Structure): - _fields_ = [("internal", ctypes.c_byte * 128)] - - -# equivalent to c declaration: -# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); -_c_ncclGetUniqueId = nccl.ncclGetUniqueId -_c_ncclGetUniqueId.restype = ctypes.c_int -_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)] - - -def ncclGetUniqueId() -> NcclUniqueId: - unique_id = NcclUniqueId() - NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id))) - return unique_id - - -# equivalent to c declaration: -# ncclResult_t ncclCommInitRank( -# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); -# note that ncclComm_t is a pointer type, so the first argument -# is a pointer to a pointer -_c_ncclCommInitRank = nccl.ncclCommInitRank -_c_ncclCommInitRank.restype = ctypes.c_int -_c_ncclCommInitRank.argtypes = [ - ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int -] - -ncclDataType_t = ctypes.c_int - - -class ncclDataTypeEnum: - ncclInt8 = 0 - ncclChar = 0 - ncclUint8 = 1 - ncclInt32 = 2 - ncclInt = 2 - ncclUint32 = 3 - ncclInt64 = 4 - ncclUint64 = 5 - ncclFloat16 = 6 - ncclHalf = 6 - ncclFloat32 = 7 - ncclFloat = 7 - ncclFloat64 = 8 - ncclDouble = 8 - ncclBfloat16 = 9 - ncclNumTypes = 10 - @classmethod - def from_torch(cls, dtype: torch.dtype) -> int: - if dtype == torch.int8: - return cls.ncclInt8 - if dtype == torch.uint8: - return cls.ncclUint8 - if dtype == torch.int32: - return cls.ncclInt32 - if dtype == torch.int64: - return cls.ncclInt64 - if dtype == torch.float16: - return cls.ncclFloat16 - if dtype == torch.float32: - return cls.ncclFloat32 - if dtype == torch.float64: - return cls.ncclFloat64 - if dtype == torch.bfloat16: - return cls.ncclBfloat16 - raise ValueError(f"Unsupported dtype: {dtype}") - - -ncclRedOp_t = ctypes.c_int - - -class ncclRedOpTypeEnum: - ncclSum = 0 - ncclProd = 1 - ncclMax = 2 - ncclMin = 3 - ncclAvg = 4 - ncclNumOps = 5 - - @classmethod - def from_torch(cls, op: ReduceOp) -> int: - if op == ReduceOp.SUM: - return cls.ncclSum - if op == ReduceOp.PRODUCT: - return cls.ncclProd - if op == ReduceOp.MAX: - return cls.ncclMax - if op == ReduceOp.MIN: - return cls.ncclMin - if op == ReduceOp.AVG: - return cls.ncclAvg - raise ValueError(f"Unsupported op: {op}") - - -# equivalent to c declaration: -# ncclResult_t ncclAllReduce( -# const void* sendbuff, void* recvbuff, size_t count, -# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, -# udaStream_t stream); -# note that cudaStream_t is a pointer type, so the last argument is a pointer -_c_ncclAllReduce = nccl.ncclAllReduce -_c_ncclAllReduce.restype = ctypes.c_int -_c_ncclAllReduce.argtypes = [ - ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t, - ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p -] - -# be cautious! this is a collective call, it will block until all -# processes in the communicator have called this function. -# because Python object destruction can happen in random order, -# it is better not to call it at all. -# equivalent to c declaration: -# ncclResult_t ncclCommDestroy(ncclComm_t comm); -_c_ncclCommDestroy = nccl.ncclCommDestroy -_c_ncclCommDestroy.restype = ctypes.c_int -_c_ncclCommDestroy.argtypes = [ctypes.c_void_p] - - -class NCCLCommunicator: +class PyNcclCommunicator: def __init__( self, group: Optional[ProcessGroup] = None, device: Optional[Union[int, str, torch.device]] = None, + library_path: Optional[str] = None, ): """ Args: group: the process group to work on. If None, it will use the default process group. - device: the device to bind the NCCLCommunicator to. If None, + device: the device to bind the PyNcclCommunicator to. If None, it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. It is the caller's responsibility to make sure each communicator is bind to a unique device. """ assert dist.is_initialized() group = get_cpu_world_group() if group is None else group assert dist.get_backend(group) != dist.Backend.NCCL, ( - "NCCLCommunicator should be attached to a non-NCCL group.") + "PyNcclCommunicator should be attached to a non-NCCL group.") self.group = group # note: this rank is the rank in the group self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + self.stream = None + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. in a non-GPU environment + self.available = False + self.disabled = True + self.stream = None + return + + self.available = True + self.disabled = False + + logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) + if self.rank == 0: - self.unique_id = ncclGetUniqueId() + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() else: - self.unique_id = NcclUniqueId() + # construct an empty unique id + self.unique_id = ncclUniqueId() tensor = torch.ByteTensor(list(self.unique_id.internal)) ranks = dist.get_process_group_ranks(group) # arg `src` in `broadcast` is the global rank @@ -246,7 +77,6 @@ def __init__( byte_list = tensor.tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte - self.comm = ctypes.c_void_p() if device is None: local_rank = get_local_rank() device = torch.device(f"cuda:{local_rank}") @@ -261,15 +91,25 @@ def __init__( # `torch.cuda.device` is a context manager that changes the # current cuda device to the specified one with torch.cuda.device(device): - NCCL_CHECK( - _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, - self.unique_id, self.rank)) + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank) self.stream = torch.cuda.Stream() + # A small all_reduce for warmup. + self.all_reduce(torch.zeros(1, device=device)) + self.stream.synchronize() + + # by default it is disabled, e.g. in profiling models and prefill phase. + # to use it, use under `with obj.change_state(enable=True)`, usually + # when we are using CUDA graph. + self.disabled = True + def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None): + if self.disabled: + return # nccl communicator created on a specific device # will only work on tensors on the same device # otherwise it will cause "illegal memory access" @@ -278,10 +118,32 @@ def all_reduce(self, f"but the input tensor is on {tensor.device}") if stream is None: stream = self.stream - NCCL_CHECK( - _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), - ctypes.c_void_p(tensor.data_ptr()), - tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - ctypes.c_void_p(stream.cuda_stream))) + self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()), + buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + + @contextmanager + def change_state(self, + enable: Optional[bool] = None, + stream: Optional[torch.cuda.Stream] = None): + """ + A context manager to change the state of the communicator. + """ + if enable is None: + # guess a default value when not specified + enable = self.available + + if stream is None: + stream = self.stream + + old_disable = self.disabled + old_stream = self.stream + + self.stream = stream + self.disabled = not enable + yield + + self.disabled = old_disable + self.stream = old_stream diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py deleted file mode 100644 index 44e4f39217a41..0000000000000 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ /dev/null @@ -1,66 +0,0 @@ -import contextlib -from typing import Optional - -import torch -from torch.distributed import ProcessGroup, ReduceOp - -from vllm.logger import init_logger - -logger = init_logger(__name__) - -try: - from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, - ncclGetVersion) -except Exception as e: - # in non-NVIDIA environments, we can't import the nccl module - # e.g. when running on machines with AMD GPUs - logger.info("Failed to import NCCL library: %s", e) - logger.info("It is expected if you are not running on NVIDIA GPUs.") - pass - -comm: Optional["NCCLCommunicator"] = None - - -def is_initialized() -> bool: - """Returns whether the NCCL backend is initialized.""" - return comm is not None - - -@contextlib.contextmanager -def set_pynccl_stream(stream: torch.cuda.Stream): - """Set the cuda stream for communication""" - try: - assert comm is not None - comm.stream = stream - yield - finally: - pass - - -def init_process_group(group: Optional[ProcessGroup] = None) -> None: - assert not is_initialized() - global comm - logger.info("vLLM is using nccl==%s", ncclGetVersion()) - comm = NCCLCommunicator(group=group) - - -def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: - """All-reduces the input tensor across the process group.""" - assert input_.is_cuda, f"{input_} should be a cuda tensor" - assert comm is not None - comm.all_reduce(input_, op) - - -def destroy_process_group() -> None: - global comm - comm = None - - -def get_world_size() -> int: - """Returns the world size.""" - assert comm is not None - return comm.world_size - - -def get_nccl_backend() -> Optional["NCCLCommunicator"]: - return comm diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py new file mode 100644 index 0000000000000..43d85674b23d0 --- /dev/null +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -0,0 +1,258 @@ +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +from vllm.logger import init_logger +from vllm.utils import find_nccl_library, nccl_integrity_check + +logger = init_logger(__name__) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + # load the library in another process. + # if it core dumps, it will not crash the current process + nccl_integrity_check(so_file) + except Exception as e: + logger.error( + "Failed to load NCCL library from %s ." + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s." + "One solution is to download libnccl2 version 2.18 from " + "https://developer.download.nvidia.com/compute/cuda/repos/ " + "and extract the libnccl.so.2 file. If you already have the " + "library, please set the environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index be5bb4e857caf..5075da11bb1b8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -3,10 +3,10 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Tensor and pipeline parallel groups.""" -import contextlib -from typing import Optional +from typing import List, Optional import torch +from torch.distributed import ProcessGroup import vllm.envs as envs from vllm.logger import init_logger @@ -14,10 +14,11 @@ logger = init_logger(__name__) # Tensor model parallel group that the current rank belongs to. -_TP_DEVICE_GROUP = None -_TP_CPU_GROUP = None +_TP_DEVICE_GROUP: Optional[ProcessGroup] = None +_TP_CPU_GROUP: Optional[ProcessGroup] = None +_TP_PYNCCL_COMMUNICATOR = None # Pipeline model parallel group that the current rank belongs to. -_PIPELINE_MODEL_PARALLEL_GROUP = None +_PP_DEVICE_GROUP: Optional[ProcessGroup] = None # when people blindly call `torch.distributed.all_reduce` etc, # it will use this group. It is initialized with the `backend` @@ -41,11 +42,16 @@ # A list of global ranks for each pipeline group to ease calculation of the # source rank when broadcasting from the first or last pipeline stage. -_PIPELINE_GLOBAL_RANKS = None +_PP_GLOBAL_RANKS: Optional[List[int]] = None _LOCAL_RANK = -1 +def get_tp_pynccl_communicator(): + global _TP_PYNCCL_COMMUNICATOR + return _TP_PYNCCL_COMMUNICATOR + + def get_local_rank(): global _LOCAL_RANK return _LOCAL_RANK @@ -80,10 +86,20 @@ def init_distributed_environment( # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 - if local_rank == -1 and distributed_init_method == "env://": - local_rank = envs.LOCAL_RANK + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank global _LOCAL_RANK _LOCAL_RANK = local_rank + # A small all_reduce for warmup. + data = torch.zeros(1) + if torch.cuda.is_available(): + data = data.to(device=f"cuda:{local_rank}") + torch.distributed.all_reduce(data) def initialize_model_parallel( @@ -133,29 +149,36 @@ def initialize_model_parallel( rank = torch.distributed.get_rank() # Build the tensor model-parallel groups. - global _TP_DEVICE_GROUP, _TP_CPU_GROUP + global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR assert _TP_DEVICE_GROUP is None, ( "tensor model parallel group is already initialized") for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, - (i + 1) * tensor_model_parallel_size) + ranks = list( + range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size)) group = torch.distributed.new_group(ranks, backend=backend) cpu_group = torch.distributed.new_group(ranks, backend="gloo") if rank in ranks: _TP_DEVICE_GROUP = group _TP_CPU_GROUP = cpu_group + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( + group=_TP_CPU_GROUP, + device=_LOCAL_RANK, + ) + # Build the pipeline model-parallel groups. - global _PIPELINE_MODEL_PARALLEL_GROUP - global _PIPELINE_GLOBAL_RANKS - assert _PIPELINE_MODEL_PARALLEL_GROUP is None, ( + global _PP_DEVICE_GROUP + global _PP_GLOBAL_RANKS + assert _PP_DEVICE_GROUP is None, ( "pipeline model parallel group is already initialized") for i in range(num_pipeline_model_parallel_groups): - ranks = range(i, world_size, num_pipeline_model_parallel_groups) + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group = torch.distributed.new_group(ranks, backend=backend) if rank in ranks: - _PIPELINE_MODEL_PARALLEL_GROUP = group - _PIPELINE_GLOBAL_RANKS = ranks + _PP_DEVICE_GROUP = group + _PP_GLOBAL_RANKS = ranks def ensure_model_parallel_initialized( @@ -188,8 +211,7 @@ def ensure_model_parallel_initialized( def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return (_TP_DEVICE_GROUP is not None - and _PIPELINE_MODEL_PARALLEL_GROUP is not None) + return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None) def get_cpu_world_group(): @@ -214,9 +236,9 @@ def get_tensor_model_parallel_cpu_group(): def get_pipeline_model_parallel_group(): """Get the pipeline model parallel group the caller rank belongs to.""" - assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, ( + assert _PP_DEVICE_GROUP is not None, ( "pipeline model parallel group is not initialized") - return _PIPELINE_MODEL_PARALLEL_GROUP + return _PP_DEVICE_GROUP def get_tensor_model_parallel_world_size(): @@ -253,36 +275,36 @@ def get_tensor_model_parallel_src_rank(): def get_pipeline_model_parallel_first_rank(): """Return the global rank of the first process in the pipeline for the current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") - return _PIPELINE_GLOBAL_RANKS[0] + return _PP_GLOBAL_RANKS[0] def get_pipeline_model_parallel_last_rank(): """Return the global rank of the last process in the pipeline for the current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") last_rank_local = get_pipeline_model_parallel_world_size() - 1 - return _PIPELINE_GLOBAL_RANKS[last_rank_local] + return _PP_GLOBAL_RANKS[last_rank_local] def get_pipeline_model_parallel_next_rank(): """Return the global rank that follows the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] def get_pipeline_model_parallel_prev_rank(): """Return the global rank that precedes the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] def destroy_model_parallel(): @@ -295,45 +317,12 @@ def destroy_model_parallel(): if _TP_CPU_GROUP: torch.distributed.destroy_process_group(_TP_CPU_GROUP) _TP_CPU_GROUP = None - global _PIPELINE_MODEL_PARALLEL_GROUP - if _PIPELINE_MODEL_PARALLEL_GROUP: - torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP) - _PIPELINE_MODEL_PARALLEL_GROUP = None - global _PIPELINE_GLOBAL_RANKS - _PIPELINE_GLOBAL_RANKS = None - from vllm.distributed.device_communicators import pynccl_utils - - # Destroy the pynccl states if any. - pynccl_utils.destroy_process_group() - - -# Whether to use pynccl for nccl all reduce. -# We use pynccl for all reduce when using CUDA graph, because torch.distributed -# is not well supported by CUDA graph. -_ENABLE_PYNCCL_FOR_ALL_REDUCE = False - - -@contextlib.contextmanager -def with_pynccl_for_all_reduce(): - from vllm.distributed.device_communicators import pynccl_utils - """use pynccl instead of torch.distributed for all reduce""" - tp_size = get_tensor_model_parallel_world_size() - if tp_size == 1: - # No-op. - # NOTE(woosuk): We don't initialize pynccl when tp_size is 1. - yield - else: - global _ENABLE_PYNCCL_FOR_ALL_REDUCE - old = _ENABLE_PYNCCL_FOR_ALL_REDUCE - _ENABLE_PYNCCL_FOR_ALL_REDUCE = True - - stream = torch.cuda.current_stream() - with pynccl_utils.set_pynccl_stream(stream): - yield - _ENABLE_PYNCCL_FOR_ALL_REDUCE = old - - -def is_pynccl_enabled_for_all_reduce(): - """check if pynccl is enabled for all reduce""" - global _ENABLE_PYNCCL_FOR_ALL_REDUCE - return _ENABLE_PYNCCL_FOR_ALL_REDUCE + global _TP_PYNCCL_COMMUNICATOR + _TP_PYNCCL_COMMUNICATOR = None + + global _PP_DEVICE_GROUP + if _PP_DEVICE_GROUP: + torch.distributed.destroy_process_group(_PP_DEVICE_GROUP) + _PP_DEVICE_GROUP = None + global _PP_GLOBAL_RANKS + _PP_GLOBAL_RANKS = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bb8245eb307f7..5c2acbef13129 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -83,6 +83,7 @@ class EngineArgs: speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None speculative_max_model_len: Optional[int] = None + speculative_disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None @@ -467,6 +468,13 @@ def add_cli_args( 'draft model. Sequences over this length will skip ' 'speculation.') + parser.add_argument( + '--speculative-disable-by-batch-size', + type=int, + default=EngineArgs.speculative_disable_by_batch_size, + help='Disable speculative decoding for new incoming requests ' + 'if the number of enqueue requests is larger than this value.') + parser.add_argument( '--ngram-prompt-lookup-max', type=int, @@ -508,7 +516,7 @@ def add_cli_args( return parser @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': + def from_cli_args(cls, args: argparse.Namespace): # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. @@ -547,6 +555,8 @@ def create_engine_config(self, ) -> EngineConfig: target_dtype=self.dtype, speculative_model=self.speculative_model, num_speculative_tokens=self.num_speculative_tokens, + speculative_disable_by_batch_size=self. + speculative_disable_by_batch_size, speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, use_v2_block_manager=self.use_v2_block_manager, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 9f72a0d11974f..37a2dc77a3b50 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,8 +1,8 @@ import asyncio import time from functools import partial -from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List, - Optional, Set, Tuple, Type, Union) +from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, + Set, Tuple, Type, Union) from transformers import PreTrainedTokenizer @@ -327,7 +327,7 @@ def __init__(self, # We need to keep a reference to unshielded # task as well to prevent it from being garbage # collected - self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None + self._background_loop_unshielded: Optional[asyncio.Task] = None self.start_engine_loop = start_engine_loop self._errored_with: Optional[BaseException] = None diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 3ed660e183360..71620139fba39 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -238,17 +238,25 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() - pbar = tqdm(total=num_requests, - desc="Processed prompts", - dynamic_ncols=True) + pbar = tqdm( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=f"Generation Speed: {0:.2f} toks/s", + ) # Run the engine. outputs: List[RequestOutput] = [] + total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() for output in step_outputs: if output.finished: outputs.append(output) if use_tqdm: + total_toks += (sum( + len(stp.token_ids) for stp in output.outputs)) + spd = total_toks / pbar.format_dict["elapsed"] + pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" pbar.update(1) if use_tqdm: pbar.close() @@ -256,4 +264,4 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # This is necessary because some requests may be finished earlier than # its previous requests. outputs = sorted(outputs, key=lambda x: int(x.request_id)) - return outputs \ No newline at end of file + return outputs diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f9e294af47253..362f28d05c3bb 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -4,7 +4,7 @@ import re from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Any, Set +from typing import Optional, Set import fastapi import uvicorn @@ -34,7 +34,7 @@ openai_serving_completion: OpenAIServingCompletion logger = init_logger(__name__) -_running_tasks: Set[asyncio.Task[Any]] = set() +_running_tasks: Set[asyncio.Task] = set() @asynccontextmanager @@ -164,15 +164,32 @@ async def authentication(request: Request, call_next): served_model_names = args.served_model_name else: served_model_names = [args.model] + engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - openai_serving_chat = OpenAIServingChat(engine, served_model_names, + + event_loop: Optional[asyncio.AbstractEventLoop] + try: + event_loop = asyncio.get_running_loop() + except RuntimeError: + event_loop = None + + if event_loop is not None and event_loop.is_running(): + # If the current is instanced by Ray Serve, + # there is already a running event loop + model_config = event_loop.run_until_complete(engine.get_model_config()) + else: + # When using single vLLM without engine_use_ray + model_config = asyncio.run(engine.get_model_config()) + + openai_serving_chat = OpenAIServingChat(engine, model_config, + served_model_names, args.response_role, args.lora_modules, args.chat_template) openai_serving_completion = OpenAIServingCompletion( - engine, served_model_names, args.lora_modules) + engine, model_config, served_model_names, args.lora_modules) app.root_path = args.root_path uvicorn.run(app, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c8f4a6b315db0..1b469fc59b076 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,4 +1,3 @@ -import asyncio import codecs import time from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List, @@ -8,6 +7,7 @@ from openai.types.chat import (ChatCompletionContentPartParam, ChatCompletionRole) +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, ChatCompletionResponse, @@ -35,17 +35,47 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, + model_config: ModelConfig, served_model_names: List[str], response_role: str, lora_modules: Optional[List[LoRAModulePath]] = None, chat_template: Optional[str] = None): super().__init__(engine=engine, + model_config=model_config, served_model_names=served_model_names, - lora_modules=lora_modules, - await_post_init=self._load_chat_template( - chat_template=chat_template)) + lora_modules=lora_modules) self.response_role = response_role + self._load_chat_template(chat_template) + + def _load_chat_template(self, chat_template: Optional[str]): + tokenizer = self.tokenizer + + if chat_template is not None: + try: + with open(chat_template, "r") as f: + tokenizer.chat_template = f.read() + except OSError as e: + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = (f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}") + raise ValueError(msg) from e + + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + tokenizer.chat_template = codecs.decode( + chat_template, "unicode_escape") + + logger.info("Using supplied chat template:\n%s", + tokenizer.chat_template) + elif tokenizer.chat_template is not None: + logger.info("Using default chat template:\n%s", + tokenizer.chat_template) + else: + logger.warning( + "No chat template provided. Chat API will not work.") def _parse_chat_message_content( self, @@ -357,36 +387,4 @@ async def chat_completion_full_generator( usage=usage, ) - return response - - async def _load_chat_template(self, chat_template: Optional[str]): - while self.tokenizer is None: - # Give the parent class time to load the tokenizer - await asyncio.sleep(0.1) - tokenizer = self.tokenizer - - if chat_template is not None: - try: - with open(chat_template, "r") as f: - tokenizer.chat_template = f.read() - except OSError as e: - JINJA_CHARS = "{}\n" - if not any(c in chat_template for c in JINJA_CHARS): - msg = (f"The supplied chat template ({chat_template}) " - f"looks like a file path, but it failed to be " - f"opened. Reason: {e}") - raise ValueError(msg) from e - - # If opening a file fails, set chat template to be args to - # ensure we decode so our escape are interpreted correctly - tokenizer.chat_template = codecs.decode( - chat_template, "unicode_escape") - - logger.info("Using supplied chat template:\n%s", - tokenizer.chat_template) - elif tokenizer.chat_template is not None: - logger.info("Using default chat template:\n%s", - tokenizer.chat_template) - else: - logger.warning( - "No chat template provided. Chat API will not work.") + return response \ No newline at end of file diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 6a7f29c4c96f2..158d8ed7fbbf5 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -4,6 +4,7 @@ from fastapi import Request +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (CompletionRequest, CompletionResponse, @@ -52,11 +53,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]: class OpenAIServingCompletion(OpenAIServing): - def __init__(self, - engine: AsyncLLMEngine, + def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]] = None): + lora_modules: Optional[List[LoRAModulePath]]): super().__init__(engine=engine, + model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 21baea2e5e7f6..f10718c5f3d80 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,13 +1,12 @@ -import asyncio import json from dataclasses import dataclass from http import HTTPStatus -from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from pydantic import Field -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from typing_extensions import Annotated +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, ErrorResponse, @@ -29,13 +28,24 @@ class LoRAModulePath: class OpenAIServing: - def __init__(self, - engine: AsyncLLMEngine, + def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]], - await_post_init: Optional[Awaitable[Any]] = None): + lora_modules: Optional[List[LoRAModulePath]]): + super().__init__() + self.engine = engine + self.max_model_len = model_config.max_model_len + + # A separate tokenizer to map token IDs to strings. + self.tokenizer = get_tokenizer( + model_config.tokenizer, + tokenizer_mode=model_config.tokenizer_mode, + tokenizer_revision=model_config.tokenizer_revision, + trust_remote_code=model_config.trust_remote_code, + truncation_side="left") + self.served_model_names = served_model_names + if lora_modules is None: self.lora_requests = [] else: @@ -47,38 +57,6 @@ def __init__(self, ) for i, lora in enumerate(lora_modules, start=1) ] - self.max_model_len = 0 - # Lazy initialized - self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - - try: - event_loop = asyncio.get_running_loop() - except RuntimeError: - event_loop = None - - if event_loop is not None and event_loop.is_running(): - # If the current is instanced by Ray Serve, - # there is already a running event loop - event_loop.create_task(self._post_init(await_post_init)) - else: - # When using single vLLM without engine_use_ray - asyncio.run(self._post_init(await_post_init)) - - async def _post_init(self, await_post_init): - engine_model_config = await self.engine.get_model_config() - self.max_model_len = engine_model_config.max_model_len - - # A separate tokenizer to map token IDs to strings. - self.tokenizer = get_tokenizer( - engine_model_config.tokenizer, - tokenizer_mode=engine_model_config.tokenizer_mode, - tokenizer_revision=engine_model_config.tokenizer_revision, - trust_remote_code=engine_model_config.trust_remote_code, - truncation_side="left") - - if await_post_init is not None: - await await_post_init - async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 1af3bcf380843..fa3480fa64837 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -82,6 +82,10 @@ def _init_spec_worker(self): draft_worker_kwargs.update( model_config=self.speculative_config.draft_model_config, parallel_config=self.speculative_config.draft_parallel_config, + ngram_prompt_lookup_max=self.speculative_config. + ngram_prompt_lookup_max, + ngram_prompt_lookup_min=self.speculative_config. + ngram_prompt_lookup_min, # TODO allow draft-model specific load config. #load_config=self.load_config, ) @@ -89,6 +93,8 @@ def _init_spec_worker(self): spec_decode_worker = SpecDecodeWorker.create_worker( scorer_worker=target_worker, draft_worker_kwargs=draft_worker_kwargs, + disable_by_batch_size=self.speculative_config. + speculative_disable_by_batch_size, ) assert self.parallel_config.world_size == 1, ( diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index 1720566840bb1..ffdc32b7339af 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -1,5 +1,5 @@ # pylint: disable=unused-argument -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Union import torch import torch.nn as nn @@ -51,10 +51,9 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: lora_a = lora_a[:, start_idx:start_idx + shard_size] return lora_a - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer, x, bias) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, @@ -88,7 +87,7 @@ def can_replace_layer(cls, source_layer: nn.Module, ) -def _mcp_apply_weights(x, bias, layer): +def _mcp_apply(x, bias, layer): """ MergedColumnParallelLinearWithShardedLoRA and QKVParallelLinearWithShardedLora share the same @@ -100,8 +99,7 @@ def _mcp_apply_weights(x, bias, layer): """ # expecting 2 for column parallel and 3 for qkv n = len(layer.lora_a_stacked) - output = layer.base_layer.linear_method.apply_weights( - layer.base_layer, x, bias) + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape @@ -136,18 +134,23 @@ class MergedColumnParallelLinearWithShardedLoRA( Based on S-LoRA, slicing happens along the rank dim. """ - def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + if lora_a[0] is None or lora_a[1] is None: + return lora_a output_shard_size = self.lora_a_stacked[0].shape[2] output_start_idx = self.tp_rank * output_shard_size lora_a = [ - lora_a[i][:, output_start_idx:output_start_idx + output_shard_size] - for i in range(2) + lora_a[0][:, + output_start_idx:output_start_idx + output_shard_size], + lora_a[1][:, output_start_idx:output_start_idx + output_shard_size] ] return lora_a - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - return _mcp_apply_weights(x, bias, self) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply(x, bias, self) @classmethod @_fully_sharded_can_replace @@ -172,19 +175,23 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): Based on S-LoRA, slicing happens along the rank dim. """ - def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None: + return lora_a shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] start_idx = [self.tp_rank * shard_size[i] for i in range(3)] lora_a = [ - lora_a[i][:, start_idx[i]:start_idx[i] + - shard_size[i]] if lora_a[i] is not None else None - for i in range(3) + lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]], + lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]], + lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]] ] return lora_a - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - return _mcp_apply_weights(x, bias, self) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply(x, bias, self) @classmethod @_fully_sharded_can_replace @@ -218,9 +225,8 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: lora_b = lora_b[:, start_idx:end_idx] return lora_b - def apply_weights(self, x: torch.Tensor) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer, x) + def apply(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index b3609666b2ec7..90f63c34fb2d3 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1,7 +1,7 @@ # pylint: disable=unused-argument import math from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -145,11 +145,15 @@ def __post_init__(self): class BaseLayerWithLoRA(nn.Module): - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + def slice_lora_a( + self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: """Slice lora a if splitting for tensor parallelism.""" ... - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + def slice_lora_b( + self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: """Slice lora b if splitting with tensor parallelism.""" ... @@ -539,10 +543,16 @@ def reset_lora(self, index: int): self.lora_b_stacked[0][index] = 0 self.lora_b_stacked[1][index] = 0 - def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: return lora_a - def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]: + def slice_lora_b( + self, lora_b: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + if lora_b[0] is None or lora_b[1] is None: + return lora_b shard_size = self.output_dim start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size @@ -767,10 +777,15 @@ def reset_lora(self, index: int): self.lora_a_stacked[2][index] = 0 self.lora_b_stacked[2][index] = 0 - def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: return lora_a - def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]: + def slice_lora_b( + self, lora_b: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + lora_b_q, lora_b_k, lora_b_v = None, None, None if lora_b[0] is not None: lora_b_q = lora_b[0][:, self.q_proj_shard_size * self.q_shard_id:self.q_proj_shard_size * @@ -992,7 +1007,6 @@ def forward(self, input_): @property def weight(self): - return self.base_layer.weight if hasattr( self.base_layer, "weight") else self.base_layer.qweight diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 50d7e9133e0e8..cd45040bcca5d 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -119,6 +119,16 @@ def __init__( self.rank = rank self.loras: Dict[str, LoRALayerWeights] = loras + def clone(self, lora_model_id: int) -> "LoRAModel": + """Return a copy of the object with different ids. + + Will share the underlying tensors.""" + return self.__class__( + lora_model_id, + rank=self.rank, + loras=self.loras.copy(), + ) + @property def extra_vocab_size(self) -> int: return max(lora.extra_vocab_size diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index ec3c10c591a18..377f561cceaf2 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod, abstractproperty -from typing import Any, Dict, List, Set, Type +from contextlib import contextmanager +from typing import Any, Dict, List, Literal, Set, Type, Union import torch @@ -25,6 +26,17 @@ def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, self.device = device self.lora_config = lora_config + # If False, do not cache. If None, cache is empty. + self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False + + @contextmanager + def dummy_lora_cache(self): + """Use this context manager to reuse the dummy lora model + to avoid creating it repeatedly.""" + self._cached_dummy_lora = None + yield + self._cached_dummy_lora = False + @abstractproperty def is_enabled(self) -> bool: ... @@ -174,9 +186,15 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: if lora_request.lora_int_id in self.list_loras(): return False - return self._lora_manager.add_lora( - self._lora_manager.create_dummy_lora(lora_request.lora_int_id, - rank, self.embedding_modules)) + if isinstance(self._cached_dummy_lora, LoRAModel): + dummy_lora = self._cached_dummy_lora.clone( + lora_request.lora_int_id) + else: + dummy_lora = self._lora_manager.create_dummy_lora( + lora_request.lora_int_id, rank, self.embedding_modules) + if self._cached_dummy_lora is None: + self._cached_dummy_lora = dummy_lora + return self._lora_manager.add_lora(dummy_lora) def add_lora(self, lora_request: LoRARequest) -> bool: if lora_request.lora_int_id in self.list_loras(): diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 496d69c89c62b..2926c7d1c8a76 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,7 +1,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_moe, get_config_file_name) + fused_experts, fused_moe, fused_topk, get_config_file_name) __all__ = [ "fused_moe", + "fused_topk", + "fused_experts", "get_config_file_name", ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ac1dc60b4650d..eb513f36ce8cb 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -308,62 +308,19 @@ def get_moe_configs(E: int, N: int, return None -def fused_moe( +def fused_topk( hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, - inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, num_expert_group: int = 0, topk_group: int = 0, -) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. +): assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] + M, _ = hidden_states.shape - E, N, _ = w1.shape + if is_hip(): # The MoE kernels are not yet supported on ROCm. routing_weights = torch.softmax(gating_output, @@ -406,6 +363,33 @@ def fused_moe( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids + + +def fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None): + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + M, _ = hidden_states.shape + E, N, _ = w1.shape if override_config: config = override_config @@ -490,3 +474,63 @@ def fused_moe( out=hidden_states) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + return fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + override_config=override_config, + use_fp8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 1c652e347d4ad..5798bc359dcf2 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,6 +4,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig) from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -19,6 +21,7 @@ "squeezellm": SqueezeLLMConfig, "gptq_marlin": GPTQMarlinConfig, "marlin": MarlinConfig, + "deepspeedfp": DeepSpeedFPConfig } diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py new file mode 100644 index 0000000000000..31cdffbcf0ab9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -0,0 +1,194 @@ +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + + +class DeepSpeedFPConfig(QuantizationConfig): + """Config for DeepSpeed FP quantizer. It supports fp6 and fp8. + + Args: + weight_bits: the target quantization bits, 6 or 8. + group_size: group size for quantizaiton, default to 128. + """ + + def __init__( + self, + weight_bits: int = 8, + group_size: int = 512, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.valid_types = [torch.bfloat16, torch.float16] + + if self.weight_bits not in (6, 8): + raise ValueError( + "Currently, only 6-bit or 8-bit weight quantization are " + f"supported for DeepSpeed FP quantizaiton, but got " + f"{self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), " + f"group_size={self.group_size}") + + @classmethod + def get_name(cls) -> str: + return "DeepSpeedFP" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits=weight_bits, group_size=group_size) + + def get_linear_method(self) -> "DeepSpeedFPLinearMethod": + return DeepSpeedFPLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", + "quantize_config.json", + ] + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]: + if isinstance(layer, LinearBase): + return DeepSpeedFPLinearMethod(self) + return None + + +class DeepSpeedFPLinearMethod(LinearMethodBase): + """Linear method for DeepSpeedFP quantizer. + + Args: + quant_config: the DeepSpeedFP quantization config. + """ + + def __init__(self, quant_config: DeepSpeedFPConfig): + self.quant_config = quant_config + self.weight = None + + def create_weights(self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader=None, + **extra_weight_attrs): + del output_size + del input_size + output_size_per_partition = sum(output_partition_sizes) + weight = DeepSpeedFPParameter( + torch.Size((output_size_per_partition, input_size_per_partition)), + params_dtype=params_dtype, + quant_config=self.quant_config, + ) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + }) + layer.register_parameter("weight", weight) + + def quant_weight_loader(param, loaded_weight, *args, **kwargs): + # Calls the original weight loader (if any), quantizes the result, + # and then loads the quantized parameter. + if weight_loader is not None: + orig_param_data = param.data + param.data = param.ds_dequantize() + weight_loader(param, loaded_weight, *args, **kwargs) + param.data, loaded_weight = orig_param_data, param.data + param.ds_quantize_(loaded_weight.cuda()) + + extra_weight_attrs["weight_loader"] = quant_weight_loader + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + y = weight.ds_dequantize() + return F.linear(x, y, bias) + + +class DeepSpeedFPParameter(nn.Parameter): + """ + DeepSpeedFP quantized parameter class that implements fp8/fp6 + quantization deepspeed. Weights are stored in quantized form on + GPUs, and can be dequantized on-the-fly when needed by the model. + """ + + def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, + quant_config: DeepSpeedFPConfig): + try: + import deepspeed + if deepspeed.__version__ < "0.14.2": + raise ImportError("deepspeed version is wrong. Please " + "install deepspeed>=0.14.2.") + from deepspeed.ops.fp_quantizer import FP_Quantize + except ImportError as err: + raise ImportError("Please install deepspeed>=0.14.2 via " + "`pip install deepspeed>=0.14.2` to use " + "deepspeedfp quantizer.") from err + data = torch.empty(( + orig_shape.numel() // quant_config.group_size, + quant_config.group_size * quant_config.weight_bits // 8 + 4, + ), + dtype=torch.int8) + self = torch.Tensor._make_subclass(cls, data, data.requires_grad) + self.orig_shape = orig_shape + self.quant_config = quant_config + self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size) + self.fp_quantizer.orig_shape = orig_shape + self.fp_quantizer.orig_dtype = params_dtype + return self + + def ds_quantize_(self, tensor: torch.Tensor): + assert tensor.device.type == "cuda" and tensor.dtype != torch.int8 + return self.data.copy_( + self.fp_quantizer.quantize( + tensor.data, + q_bits=self.quant_config.weight_bits, + )) + + def ds_dequantize(self, fp_out=None) -> torch.Tensor: + """ + Return a tensor containing the dequantized weights of this parameter. + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.dequantize( + self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits) + + def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor: + """ + Return a tensor where only the weights at `indices` are dequantized + (to save HBM -> SRAM bandwidth). + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.selective_dequantize( + self.data, + indices, + fp_out=fp_out, + q_bits=self.quant_config.weight_bits) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index b57e1dde81a5f..ff996741c1d00 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -231,9 +231,14 @@ def apply(self, # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.act_scale is None and x_scale computed from x. # If static, layer.act_scale is scalar and x_scale set to act_scale. - qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale) - - # Fused GEMM_DQ + qinput, x_scale = ops.scaled_fp8_quant(x, + layer.act_scale, + batch_dim_padding=17) + + # Fused GEMM_DQ -- note we padded the input above because + # torch._scaled_mm is more performant for matrices with + # batch dimension > 16. Note that this could change + # in the future. output, _ = torch._scaled_mm( qinput, layer.weight, @@ -243,7 +248,7 @@ def apply(self, bias=bias, ) - return output + return torch.narrow(output, 0, 0, x.shape[0]) def all_close_1d(x: torch.Tensor) -> bool: diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 5edbbf2c70a49..b5f1e55d0e839 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -12,15 +12,21 @@ class RejectionSampler(nn.Module): https://arxiv.org/pdf/2302.01318.pdf. """ - def __init__(self, strict_mode: bool = False): + def __init__(self, + disable_bonus_tokens: bool = True, + strict_mode: bool = False): """Create a rejection sampler. Args: + disable_bonus_tokens: Whether or not to disable the bonus token. + Require when bonus tokens will cause corrupt KV cache for + proposal methods that require KV cache. strict_mode: Whether or not to perform shape/device/dtype checks during sampling. This catches correctness issues but adds nontrivial latency. """ super().__init__() + self._disable_bonus_tokens = disable_bonus_tokens self._strict_mode = strict_mode # NOTE: A "bonus token" is accepted iff all proposal tokens are @@ -312,7 +318,8 @@ def _create_output( # proposal methods that require KV cache. We can fix it by "prefilling" # the bonus token in the proposer. The following issue tracks the fix. # https://github.com/vllm-project/vllm/issues/4212 - output_with_bonus_tokens[:, -1] = -1 + if self._disable_bonus_tokens: + output_with_bonus_tokens[:, -1] = -1 # Fill the recovered token ids. output.mul_(~after_false_mask).add_( diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index fd671c1f47a75..0b2f90402d383 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -109,7 +109,7 @@ def _forward( key_pass = key[..., self.rotary_dim:] self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( - positions.device) + positions.device, dtype=query.dtype) cos_sin = self.cos_sin_cache[torch.add(positions, offsets) if offsets is not None else positions] cos, sin = cos_sin.chunk(2, dim=-1) @@ -143,7 +143,8 @@ def forward( key: torch.Tensor, offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - self.cos_sin_cache = self.cos_sin_cache.to(positions.device) + self.cos_sin_cache = self.cos_sin_cache.to(positions.device, + dtype=query.dtype) # ops.rotary_embedding()/batched_rotary_embedding() # are in-place operations that update the query and key tensors. if offsets is not None: diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 1f19d2053d996..e52e350d2726f 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -782,13 +782,14 @@ def _get_logprobs( top_logprobs, top_token_ids = torch.topk(logprobs, largest_num_logprobs, dim=-1) - top_logprobs = top_logprobs.cpu() - top_token_ids = top_token_ids.cpu() else: top_logprobs, top_token_ids = None, None - selected_logprobs = selected_logprobs.cpu() - ranks = ranks.cpu() + selected_logprobs = selected_logprobs.to('cpu') + ranks = ranks.to('cpu') + if top_logprobs is not None and top_token_ids is not None: + top_logprobs = top_logprobs.to('cpu') + top_token_ids = top_token_ids.to('cpu') # Find prompt/sample logprobs. prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] @@ -828,37 +829,48 @@ def _get_prompt_logprob_if_needed( # Find prompt logprobs prompt_logprobs: Optional[PromptLogprobs] = None - if (is_prompt and sampling_params.prompt_logprobs is not None): + if is_prompt and sampling_params.prompt_logprobs is not None: prompt_logprobs = [] num_logprobs = sampling_params.prompt_logprobs next_prompt_tokens = _get_next_prompt_tokens(seq_group) - for token_id in next_prompt_tokens: + # Pre-select indexes and create a list. It is faster than calling .item + # repetitively. + selected_logprob_items = selected_logprobs[ + selected_logprobs_idx:selected_logprobs_idx + + len(next_prompt_tokens)].tolist() + rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + + len(next_prompt_tokens)].tolist() + + for idx, token_id in enumerate(next_prompt_tokens): # Calculate the prompt logprob of the real prompt tokens. - # Use tuple here for performance (to use to_list()). # {token_id: (logprob, rank_from_vocab)} prompt_logprobs_dict: Dict[int, Tuple[float, int]] = { - token_id: (selected_logprobs[selected_logprobs_idx].item(), - ranks[selected_logprobs_idx].item()) + token_id: (selected_logprob_items[idx], rank_items[idx]) } # Add top K prompt logprobs along with its rank. if num_logprobs > 0: - prompt_logprobs_dict.update( - zip( - top_token_ids[top_logprob_idx, :num_logprobs].tolist(), - zip( - top_logprobs[ - top_logprob_idx, :num_logprobs].tolist(), - # This is ranks. Since top_logprob is sorted, - # we can just use a range here. - range(1, num_logprobs + 1)))) + top_ids = top_token_ids[ + top_logprob_idx, :num_logprobs].tolist() + top_probs = top_logprobs[ + top_logprob_idx, :num_logprobs].tolist() + # Top K is already sorted by rank, so we can use 1 ~ + # num_logprobs + 1 for rank. + top_ranks = range(1, num_logprobs + 1) + prompt_logprobs_dict.update({ + top_id: (top_prob, rank) + for top_id, top_prob, rank in zip(top_ids, top_probs, + top_ranks) + }) prompt_logprobs.append({ token_id: Logprob(*logprob_and_rank) for token_id, logprob_and_rank in prompt_logprobs_dict.items() }) # + 1 to go to the next prompt token. top_logprob_idx += 1 - selected_logprobs_idx += 1 + + # + len(next_prompt_tokens) to go to the next prompt. + selected_logprobs_idx += len(next_prompt_tokens) return prompt_logprobs, top_logprob_idx, selected_logprobs_idx @@ -874,47 +886,54 @@ def _get_sampled_logprob_if_needed( ): """Compute the sample logprob if needed.""" seq_ids = seq_group.seq_ids - num_logprobs = seq_group.sampling_params.logprobs - if num_logprobs is None: - num_logprobs = 0 + num_logprobs = seq_group.sampling_params.logprobs or 0 sampled_logprobs: SampleLogprobs = [] next_token_ids, parent_seq_ids = sample_result if seq_group.do_sample: assert len(next_token_ids) > 0 - for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids): - # Calculate the sample logprob of the real sampled tokens. - # Use tuple here for performance (to use to_list()). - # token_id: (logprob, rank_from_vocab) - sampled_logprobs_dict: Dict[int, Tuple[float, int]] = { - next_token_id: - (selected_logprobs[selected_logprobs_idx].item(), - ranks[selected_logprobs_idx].item()) + # Pre-select items from tensor. tolist() is faster than repetitive + # `.item()` calls. + selected_logprob_items = selected_logprobs[ + selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + for idx, (next_token_id, + parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)): + # Get the logprob of a sampled token. + sampled_logprobs_dict = { + next_token_id: (selected_logprob_items[idx], rank_items[idx]) } - # +1 to go to the next sampled token. Note that - # selected_logprobs can contain duplicates unlike top_logprobs - # when beam search is enabled. - selected_logprobs_idx += 1 - - # Second, add top K logprobs along with its rank. - if num_logprobs >= 0: - sampled_logprobs_dict.update( - zip( - top_token_ids[top_logprob_idx + - parent_id, :num_logprobs].tolist(), - zip( - top_logprobs[top_logprob_idx + - parent_id, :num_logprobs].tolist(), - # This is rank. Since top_logprob is sorted, we - # can just use a range here. - range(1, num_logprobs + 1)))) + # Get top K logprobs. + if num_logprobs > 0: + top_ids = top_token_ids[top_logprob_idx + + parent_id, :num_logprobs].tolist() + top_probs = top_logprobs[top_logprob_idx + + parent_id, :num_logprobs].tolist() + # Top K is already sorted by rank, so we can use 1 ~ + # num_logprobs + 1 for rank. + top_ranks = range(1, num_logprobs + 1) + sampled_logprobs_dict.update({ + top_id: (top_prob, rank) + for top_id, top_prob, rank in zip(top_ids, top_probs, + top_ranks) + }) + sampled_logprobs.append({ token_id: Logprob(*logprob_and_rank) for token_id, logprob_and_rank in sampled_logprobs_dict.items() }) - # There are len(seq_ids) number of sampled tokens for the current - # sequence group in top_logprobs. Jump to the next seq_group. + + # NOTE: This part of code is not intuitive. `selected_logprobs` include + # logprobs for the current step, which has len(next_token_ids) tokens + # per sequence group. `logprobs` includes logprobs from the previous + # steps, which has len(seq_ids) tokens per sequence group. + + # Iterate to the next sequence group in a batch. + selected_logprobs_idx += len(next_token_ids) + # Iterate to the next sequence group in a batch. top_logprob_idx += len(seq_ids) return sampled_logprobs, top_logprob_idx, selected_logprobs_idx diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index af433b86e604d..219a2a392e129 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -19,7 +19,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -tensorizer_load_fail = None +tensorizer_error_msg = None try: from tensorizer import (DecryptionParams, EncryptionParams, @@ -28,7 +28,7 @@ from tensorizer.utils import (convert_bytes, get_mem_usage, no_init_or_tensor) except ImportError as e: - tensorizer_load_fail = e + tensorizer_error_msg = str(e) __all__ = [ 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', @@ -254,11 +254,11 @@ class TensorizerAgent: def __init__(self, tensorizer_config: TensorizerConfig, quant_config: QuantizationConfig, **extra_kwargs): - if tensorizer_load_fail is not None: + if tensorizer_error_msg is not None: raise ImportError( "Tensorizer is not installed. Please install tensorizer " - "to use this feature with `pip install vllm[tensorizer]`." - ) from tensorizer_load_fail + "to use this feature with `pip install vllm[tensorizer]`. " + "Error message: {}".format(tensorizer_error_msg)) self.tensorizer_config = tensorizer_config self.tensorizer_args = ( diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index c5c42f1515ed8..9070c01a37a28 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -57,6 +57,7 @@ "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), + "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), } diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py new file mode 100644 index 0000000000000..796cef7c4a735 --- /dev/null +++ b/vllm/model_executor/models/arctic.py @@ -0,0 +1,521 @@ +"""Inference-only Snowflake Arctic model.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig, DeepSpeedFPParameter) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import SamplerOutput +from vllm.transformers_utils.configs.arctic import ArcticConfig + +logger = init_logger(__name__) + + +class ArcticMLP(nn.Module): + + def __init__(self, + config: ArcticConfig, + layer_id: int, + expert_id: int = -1, + is_residual_mlp: bool = False, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True): + super(ArcticMLP, self).__init__() + self.hidden_size = config.hidden_size + self.expert_id = expert_id + self.layer_id = layer_id + + self.ffn_dim = config.intermediate_size if not is_residual_mlp \ + else self.hidden_size + + self.w13 = MergedColumnParallelLinear(self.hidden_size, + [self.ffn_dim] * 2, + bias=False, + quant_config=quant_config) + self.w2 = RowParallelLinear(self.ffn_dim, + self.hidden_size, + bias=False, + reduce_results=reduce_results, + quant_config=quant_config) + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, hidden_states): + gate_up, _ = self.w13(hidden_states) + hidden_states = self.act_fn(gate_up) + hidden_states, _ = self.w2(hidden_states) + return hidden_states + + +class ArcticMoE(nn.Module): + """ + Model-parallel implementation of Arctic MoE Layer. + """ + + def __init__(self, + config: ArcticConfig, + layer_id: int, + tp_size: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True): + super(ArcticMoE, self).__init__() + + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.hidden_size = config.hidden_size + self.num_experts = config.num_local_experts + self.layer_id = layer_id + self.top_k = config.num_experts_per_tok + self.intermediate_size = config.intermediate_size // self.tp_size + + self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 + self.is_quant = isinstance(quant_config, DeepSpeedFPConfig) + self.reduce_results = reduce_results + # Some other parameters + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + if not self.is_moe_layer: + self.mlp = ArcticMLP(config, + layer_id=layer_id, + quant_config=quant_config, + reduce_results=reduce_results) + else: + self.gate = ReplicatedLinear(self.hidden_size, + self.num_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=quant_config) + if self.is_quant: + self.ws = DeepSpeedFPParameter( + torch.Size((self.num_experts, 2 * self.intermediate_size, + self.hidden_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + self.w2s = DeepSpeedFPParameter( + torch.Size((self.num_experts, self.hidden_size, + self.intermediate_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + else: + self.ws = nn.Parameter( + torch.empty(self.num_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.ds_dequantize() if self.is_quant else param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + if self.is_quant: + param.ds_quantize_(param_data) + + def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + do_normalize = self.top_k > 1 + topk_weights, topk_ids = fused_topk(hidden_states, + router_logits, + self.top_k, + renormalize=do_normalize) + # topk_ids: (num_tokens, k) + if self.is_quant: + if 2 * num_tokens <= self.num_experts: + # If much fewer tokens than experts, use selective dequantize. + ws_dequantized = self.ws.ds_selective_dequantize( + topk_ids.flatten()) + w2s_dequantized = self.w2s.ds_selective_dequantize( + topk_ids.flatten()) + # We gathered the experts to the tokens so update the mapping. + topk_ids = torch.arange( + 0, + topk_ids.numel(), + device=topk_ids.device, + ).reshape(topk_ids.shape) + else: + ws_dequantized = self.ws.ds_dequantize() + w2s_dequantized = self.w2s.ds_dequantize() + + final_hidden_states = fused_experts( + hidden_states, + ws_dequantized if self.is_quant else self.ws, + w2s_dequantized if self.is_quant else self.w2s, + topk_weights, + topk_ids, + inplace=True) + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + def forward(self, hidden_states: torch.Tensor): + if self.is_moe_layer: + final_hidden_states = self.local_moe_fused(hidden_states) + else: + final_hidden_states = self.mlp(hidden_states) + return final_hidden_states + + +class ArcticAttention(nn.Module): + + def __init__( + self, + config: ArcticConfig, + layer_idx: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + reduce_results=True, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=int(self.rope_theta), + is_neox_style=True, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class ArcticDecoderLayer(nn.Module): + + def __init__( + self, + config: ArcticConfig, + layer_idx: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 + self.use_residual = config.use_residual and is_moe_layer + self.self_attn = ArcticAttention(config, + layer_idx, + quant_config=quant_config) + self.block_sparse_moe = ArcticMoE( + config, + layer_id=layer_idx, + quant_config=quant_config, + reduce_results=(not self.use_residual)) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + if self.use_residual: + self.residual_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.residual_mlp = ArcticMLP(config, + layer_id=layer_idx, + is_residual_mlp=True, + reduce_results=False) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual_input = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual_input + hidden_states + + residual_attn = hidden_states + if self.use_residual: + hidden_states = self.residual_layernorm(hidden_states) + hidden_states = self.residual_mlp(hidden_states) + residual_mlp = hidden_states + hidden_states = self.post_attention_layernorm(residual_input) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_mlp + hidden_states + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + hidden_states = residual_attn + hidden_states + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_attn + hidden_states + return hidden_states + + +class ArcticModel(nn.Module): + + def __init__( + self, + config: ArcticConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=self.vocab_size) + self.layers = nn.ModuleList([ + ArcticDecoderLayer(config, layer_idx, quant_config=quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self._attn_implementation = config._attn_implementation + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states, kv_caches[i], + attn_metadata) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class ArcticForCausalLM(nn.Module): + + def __init__(self, + config: ArcticConfig, + quant_config: Optional[QuantizationConfig] = None, + **kwargs) -> None: + super().__init__() + self.config = config + self.model = ArcticModel(config, quant_config) + self.vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.vocab_size, + config.hidden_size, + ) + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.unpadded_vocab_size = config.vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + mlp_params_mapping = [] + expert_params_mapping = [] + num_layers = self.config.num_hidden_layers + + for layer in range(num_layers): + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w3.weight", 1)) + if layer % 2 == 0: + # MLP layers + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1)) + else: + # MoE layers + for expert_id in range(self.config.num_local_experts): + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w1.weight", expert_id)) + expert_params_mapping.append( + ("w2s", f"experts.{expert_id}.w2.weight", expert_id)) + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w3.weight", expert_id)) + + params_dict = dict(self.named_parameters()) + + logger.info( + "It will take ~10 minutes loading from the 16-bit weights. " + "Alternatively, use the prequantized 8-bit weights of arctic " + "and set load-format to `sharded_state` will accelerate loading.") + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id in mlp_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id \ + in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/sequence.py b/vllm/sequence.py index f2939eff7959b..3cebb85b49d27 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -2,7 +2,7 @@ import copy import enum from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from vllm.block import LogicalTokenBlock from vllm.lora.request import LoRARequest @@ -612,6 +612,12 @@ def __init__( self._token_chunk_size = token_chunk_size self.do_sample = do_sample + # The number of speculative tokens adopted in this request. + # None means specuative decoding is not used. + # Zero means speculative decoding is disabled for some reasons. + # TODO: We should maintain this states out of the sequence group. + self.num_speculative_tokens = None + if self._token_chunk_size is None: if is_prompt: self._token_chunk_size = list(seq_data.values())[0].get_len() @@ -741,12 +747,12 @@ class ExecuteModelRequest: """The model execution request.""" # The sequence group metadata list. seq_group_metadata_list: List[SequenceGroupMetadata] - # Blocks to swap in. Dict of CPU -> GPU block number. - blocks_to_swap_in: Dict[int, int] = field(default_factory=dict) - # Blocks to swap out. Dict of GPU -> CPU block number. - blocks_to_swap_out: Dict[int, int] = field(default_factory=dict) - # Blocks to copy. Source to a list of dest blocks. - blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict) + # Blocks to swap in. List of CPU -> GPU block number. + blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list) + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) + # Blocks to copy. Source to dest block. + blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) # The number of slots for lookahead decoding. num_lookahead_slots: int = 0 # The number of requests in the running queue. diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 5044cc1ef85fd..20098ebaeea32 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -1,4 +1,5 @@ import copy +import weakref from typing import List, Tuple import torch @@ -32,7 +33,7 @@ def init_device(self): super().init_device() self._proposer = Top1Proposer( - self, + weakref.proxy(self), self.device, self.vocab_size, max_proposal_len=self.max_model_len, diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index fed8be42054a5..6cd50fcc1a041 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -1,3 +1,4 @@ +import weakref from typing import List, Optional, Tuple import torch @@ -37,7 +38,7 @@ def init_device(self): # Current only support Top1Proposer self._proposer = Top1Proposer( - self, + weakref.proxy(self), device=self.device, vocab_size=self.vocab_size, ) @@ -138,7 +139,7 @@ def sampler_output( SamplerOutput( outputs=None, sampled_token_probs=token_probs[i], - logprobs=token_logprobs, + logprobs=token_logprobs[i], sampled_token_ids=token_ids[i], )) return outputs, False diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index c2b119fbd5036..a4e759095b294 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch @@ -54,30 +54,33 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): def create_worker( cls, scorer_worker: WorkerBase, - draft_worker_kwargs, + draft_worker_kwargs: Dict[str, Any], + disable_by_batch_size: Optional[int], ) -> "SpecDecodeWorker": - if "ngram_prompt_lookup_max" in draft_worker_kwargs: - ngram_prompt_lookup_max = ( - draft_worker_kwargs.pop("ngram_prompt_lookup_max")) - ngram_prompt_lookup_min = ( - draft_worker_kwargs.pop("ngram_prompt_lookup_min")) - else: - ngram_prompt_lookup_max = 0 + ngram_prompt_lookup_max = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_max")) + ngram_prompt_lookup_min = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + disable_bonus_tokens = True if ngram_prompt_lookup_max > 0: + disable_bonus_tokens = False proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, ngram_prompt_lookup_max) else: proposer_worker = MultiStepWorker(**draft_worker_kwargs) + logger.info("Configuring SpecDecodeWorker with proposer=%s", + type(proposer_worker)) + return SpecDecodeWorker( proposer_worker, scorer_worker, - # TODO(cade) disable strict mode for speedup. - rejection_sampler=RejectionSampler(strict_mode=True), - ) + disable_by_batch_size=disable_by_batch_size, + rejection_sampler=RejectionSampler( + disable_bonus_tokens=disable_bonus_tokens, )) def __init__( self, @@ -85,6 +88,7 @@ def __init__( scorer_worker: WorkerBase, rejection_sampler: RejectionSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, + disable_by_batch_size: Optional[int] = None, ): """ Create a SpecDecodeWorker. @@ -97,11 +101,14 @@ def __init__( Worker. rejection_sampler: A Torch module used to perform modified rejection sampling for speculative decoding. + disable_by_batch_size: If the batch size is larger than this, + disable speculative decoding for new incoming requests. metrics_collector: Helper class for collecting metrics; can be set for testing purposes. """ self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker + self.disable_by_batch_size = disable_by_batch_size or float("inf") self.rejection_sampler = rejection_sampler self._metrics = AsyncMetricsCollector( @@ -199,27 +206,41 @@ def execute_model( "speculative decoding " "requires non-None seq_group_metadata_list") + # When the batch size is too large, disable speculative decoding + # to stop trading off throughput for latency. + disable_all = (execute_model_req.running_queue_size >= + self.disable_by_batch_size) + if disable_all: + for seq_group_metadata in execute_model_req.seq_group_metadata_list: + # Once num_speculative_tokens is set to 0, the spec decode + # of this request will be disabled forever. + # TODO(comaniac): We currently store spec decoding specific + # state in the global data structure, but we should maintain + # this state within spec decode worker. + seq_group_metadata.num_speculative_tokens = 0 + # If no spec tokens, call the proposer and scorer workers normally. - # Used for prefill. + # This happens for prefill, or when the spec decode is disabled + # for this batch. if execute_model_req.num_lookahead_slots == 0 or len( execute_model_req.seq_group_metadata_list) == 0: - return self._run_no_spec(execute_model_req) + return self._run_no_spec(execute_model_req, + skip_proposer=disable_all) return self._run_speculative_decoding_step(execute_model_req) @nvtx_range("spec_decode_worker._run_no_spec") - def _run_no_spec( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - """Run a prefill step, without any speculation. The input is sent to the - proposer and scorer model so that the KV cache is consistent between the - two. + def _run_no_spec(self, execute_model_req: ExecuteModelRequest, + skip_proposer: bool) -> List[SamplerOutput]: + """Run a prefill step, without any speculation. The input is sent to + the proposer and scorer model so that the KV cache is consistent + between the two. When skip_proposer is True, the proposer model is + not called, meaning that the kv-cache in proposer for requests is not + updated, so they cannot enable spec decode in the rest decoding. """ - #logger.info("run proposer worker no spec") - - self.proposer_worker.execute_model(execute_model_req) + if not skip_proposer: + self.proposer_worker.execute_model(execute_model_req) - #logger.info("run target worker no spec") sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 sampler_output = sampler_output[0] @@ -244,22 +265,18 @@ def _run_speculative_decoding_step( sequence. """ - #logger.info("get spec proposals") # Generate proposals using draft worker. proposals = self.proposer_worker.get_spec_proposals(execute_model_req) - #logger.info("score proposals") proposal_scores = self.scorer.score_proposals( execute_model_req, proposals, ) - #logger.info("verify proposals") accepted_token_ids, target_logprobs = self._verify_tokens( execute_model_req.seq_group_metadata_list, proposal_scores, proposals, execute_model_req.num_lookahead_slots) - #logger.info("create output list") return self._create_output_sampler_list( execute_model_req.seq_group_metadata_list, accepted_token_ids, diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index eb622a0e2e7f4..ee9462b68dae8 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -56,7 +56,7 @@ def get_proposals( proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices, - ) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len) + ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len) if nonzero_proposal_len_seqs: # Speculate tokens using the draft worker for the speculative @@ -97,17 +97,27 @@ def get_proposals( return proposals - def _split_by_max_model_len( + def _split_by_proposal_len( self, seq_group_metadata_list: List[SequenceGroupMetadata], proposal_len: int, ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: - """Determine which sequences would exceed the max model length.""" + """Split sequences by two groups: + 1. Sequences with non-zero proposal length. + 2. Sequences with zero proposal length (due to disabled speculation + or exceed the maximum model length). + """ proposal_lens: List[int] = [] nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] nonzero_proposal_len_indices: List[int] = [] for i, seq_group_metadata in enumerate(seq_group_metadata_list): + # The speculative decoding for this request has been disabled + # (e.g. due to high traffic). + if seq_group_metadata.num_speculative_tokens == 0: + proposal_lens.append(0) + continue + seq_data = next(iter(seq_group_metadata.seq_data.values())) seq_len = seq_data.get_len() @@ -115,13 +125,14 @@ def _split_by_max_model_len( # are supported. # If max_proposal_len is defined, then we shall no exccess this # quota for nonzero_proposal + new_k = 0 if (self.max_proposal_len is None or seq_len + proposal_len < self.max_proposal_len): - proposal_lens.append(proposal_len) + new_k = proposal_len nonzero_proposal_len_seqs.append(seq_group_metadata) nonzero_proposal_len_indices.append(i) - else: - proposal_lens.append(0) + proposal_lens.append(new_k) + seq_group_metadata.num_speculative_tokens = new_k return ( proposal_lens, diff --git a/vllm/transformers_utils/configs/arctic.py b/vllm/transformers_utils/configs/arctic.py new file mode 100644 index 0000000000000..7780bf5e78d6d --- /dev/null +++ b/vllm/transformers_utils/configs/arctic.py @@ -0,0 +1,204 @@ +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# Copied from +# https://huggingface.co/Snowflake/snowflake-arctic-instruct/blob/main/configuration_arctic.py +""" Arctic model configuration""" + +from dataclasses import asdict, dataclass +from typing import Any, Dict + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +ARCTIC_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "arctic": "https://huggingface.co/Snowflake/snowflake-arctic-instruct/tree/main/config.json", +} + + +@dataclass +class ArcticLoraConfig: + lora_r: int = 64 + lora_alpha: float = 16 + shard_base_weights: bool = False + + +@dataclass +class ArcticQuantizationConfig: + q_bits: int = 8 + rounding: str = "nearest" + mantissa_bits: int = 3 + group_size: int = 128 + + +class ArcticConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ArcticModel`]. It is used to instantiate an + Arctic model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the #TODO(rsamdani): add what model has the default config.. + + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Arctic model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ArcticModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Arctic's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to root per-token, can be also interpreted as the `top-p` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + ```python + >>> from transformers import ArcticModel, ArcticConfig + + >>> # Initializing a Arctic 7B style configuration TODO(rsamdani): verify which model does the default configuration correspond to. + >>> configuration = ArcticConfig() + + >>> # Initializing a model from the Arctic 7B style configuration + >>> model = ArcticModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "arctic" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=1, + num_local_experts=8, + router_aux_loss_coef=0.001, + moe_layer_frequency=2, + parallel_attn_mlp_res=False, + moe_train_capacity_factor=1, + moe_eval_capacity_factor=1, + enable_expert_tensor_parallelism=False, + moe_min_capacity=0, + moe_token_dropping=True, + quantization=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.router_aux_loss_coef = router_aux_loss_coef + self.moe_layer_frequency = moe_layer_frequency + self.moe_train_capacity_factor = moe_train_capacity_factor + self.moe_eval_capacity_factor = moe_eval_capacity_factor + self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism + self.moe_min_capacity = moe_min_capacity + self.moe_token_dropping = moe_token_dropping + self.parallel_attn_mlp_res = parallel_attn_mlp_res + if isinstance(quantization, dict): + self.quantization = ArcticQuantizationConfig(**quantization) + else: + self.quantization = quantization + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "ArcticConfig": + result = super().from_dict(config_dict, **kwargs) + config = result[0] if isinstance(result, tuple) else result + if isinstance(config.quantization, dict): + config.quantization = ArcticQuantizationConfig(**config.quantization) + return result + + def to_dict(self) -> Dict[str, Any]: + ret = super().to_dict() + if isinstance(ret["quantization"], ArcticQuantizationConfig): + ret["quantization"] = asdict(ret["quantization"]) + return ret diff --git a/vllm/utils.py b/vllm/utils.py index 6479a8dab320a..f0e71f5e99b64 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -329,7 +329,7 @@ def _generate_random_fp8( from vllm import _custom_ops as ops tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp.uniform_(low, high) - ops.convert_fp8(tensor_tmp, tensor) + ops.convert_fp8(tensor, tensor_tmp) del tensor_tmp diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index c34ee0648626b..1fb63a3e47921 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -1,5 +1,5 @@ """CacheEngine class for managing the KV cache.""" -from typing import Dict, List +from typing import List import torch @@ -67,17 +67,17 @@ def _allocate_kv_cache( device=device)) return kv_cache - def swap_in(self, src_to_dst: Dict[int, int]) -> None: + def swap_in(self, src_to_dst: torch.Tensor) -> None: for i in range(self.num_layers): self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], src_to_dst) - def swap_out(self, src_to_dst: Dict[int, int]) -> None: + def swap_out(self, src_to_dst: torch.Tensor) -> None: for i in range(self.num_layers): self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], src_to_dst) - def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + def copy(self, src_to_dsts: torch.Tensor) -> None: self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) @staticmethod diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 193b021b7a11e..6c8b1685dadcf 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -4,8 +4,9 @@ from torch import nn from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -26,6 +27,7 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], @@ -39,27 +41,22 @@ def __init__( self.scheduler_config = scheduler_config # Currently, CPU worker doesn't support chunked prefill. assert self.scheduler_config.chunked_prefill_enabled is False + self.device_config = device_config + self.cache_config = cache_config self.lora_config = lora_config self.vision_language_config = vision_language_config self.load_config = load_config self.is_driver_worker = is_driver_worker - # model_config can be None in tests/samplers/test_sampler.py. - # FIXME(woosuk): This is a hack to make the tests work. Refactor this. - self.sliding_window = (model_config.get_sliding_window() - if model_config is not None else None) - self.device_config = (device_config - if device_config is not None else DeviceConfig()) self.device = self.device_config.device self.kv_cache_dtype = kv_cache_dtype - - self.attn_backend = get_attn_backend( - self.model_config.dtype if model_config is not None else None) + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.attn_backend = get_attn_backend(self.model_config.dtype) # Lazy initialization. self.model: nn.Module # Set after init_Model - self.block_size: int # Set after initial profiling. def load_model(self) -> None: self.model = get_model( diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 4420d4cc9e12f..5e4ae564cb57e 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -151,6 +151,7 @@ def __init__( parallel_config, scheduler_config, device_config, + cache_config, load_config=self.load_config, lora_config=self.lora_config, vision_language_config=self.vision_language_config, @@ -248,9 +249,9 @@ def _init_cache_engine(self) -> None: def cache_copy( self, - blocks_to_copy: Dict[int, List[int]], + blocks_to_copy: torch.Tensor, ) -> None: - if blocks_to_copy: + if blocks_to_copy.numel() > 0: self.cache_engine.copy(blocks_to_copy) @torch.inference_mode() @@ -269,6 +270,9 @@ def execute_model( num_seq_groups: int = len(seq_group_metadata_list) assert execute_model_req is not None blocks_to_copy = execute_model_req.blocks_to_copy + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device="cpu", + dtype=torch.int64).view(-1, 2) assert len(execute_model_req.blocks_to_swap_in) == 0 assert len(execute_model_req.blocks_to_swap_out) == 0 data: Dict[str, Any] = { diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ab248596490f6..3fc76c6142165 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,4 +1,3 @@ -import contextlib import time from enum import IntEnum from typing import Dict, List, NamedTuple, Optional, Set, Tuple @@ -9,12 +8,12 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) -from vllm.attention.backends.flashinfer import FlashInferBackend -from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce -from vllm.distributed.device_communicators import (custom_all_reduce, - pynccl_utils) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.distributed.communication_op import graph_capture_mode +from vllm.distributed.device_communicators import custom_all_reduce from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -107,6 +106,7 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", @@ -116,48 +116,40 @@ def __init__( self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config self.lora_config = lora_config self.load_config = load_config self.is_driver_worker = is_driver_worker + self.vision_language_config = vision_language_config - # model_config can be None in tests/samplers/test_sampler.py. - # FIXME(woosuk): This is a hack to make the tests work. Refactor this. - self.sliding_window = (model_config.get_sliding_window() - if model_config is not None else None) - self.device_config = (device_config - if device_config is not None else DeviceConfig()) self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() - # Set after load_model. - self.lora_manager: LRUCacheWorkerLoRAManager = None - + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. - - self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture - if self.model_config is not None else 0) - - self.pin_memory = is_pin_memory_available() - self.kv_cache_dtype = kv_cache_dtype - self.vision_language_config = vision_language_config - - self.attn_backend = get_attn_backend( - self.model_config.dtype if model_config is not None else None) - - # Lazy initialization - self.model: torch.nn.Module # Set after load_model - self.block_size: int # Set after initial profiling. # When using CUDA graph, the input block tables must be padded to # max_seq_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table # in numpy and only copy the actual input content at every iteration. # The shape of the cached block table will be # (max batch size to capture, max context len to capture / block size). - self.graph_block_tables: torch.Tensor # Set after initial profiling. + self.graph_block_tables = np.zeros( + (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), + dtype=np.int32) + self.attn_backend = get_attn_backend(self.model_config.dtype) + # Lazy initialization + self.model: torch.nn.Module # Set after load_model # Set if the backend is flashinfer. self.flashinfer_workspace_buffer: torch.Tensor + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None def load_model(self) -> None: with CudaMemoryProfiler() as m: @@ -212,13 +204,6 @@ def load_model(self) -> None: "but the KV cache data type is not FP8. " "KV cache scaling factors will not be used.") - def set_block_size(self, block_size: int) -> None: - self.block_size = block_size - - self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), - dtype=np.int32) - def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size @@ -395,7 +380,7 @@ def _prepare_prompt( dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) - if self.attn_backend is FlashInferBackend: + if self.attn_backend.get_name() == "flashinfer": attn_metadata = self.attn_backend.make_metadata( is_prompt=True, use_cuda_graph=False, @@ -556,7 +541,7 @@ def _prepare_decode( device=self.device, ) - if self.attn_backend is FlashInferBackend: + if self.attn_backend.get_name() == "flashinfer": if not hasattr(self, "flashinfer_workspace_buffer"): # Allocate 16MB workspace buffer # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html @@ -836,20 +821,22 @@ def profile_run(self) -> None: dummy_lora_requests = [] dummy_lora_requests_per_seq = [] if self.lora_config: - for idx in range(self.lora_config.max_loras): - lora_id = idx + 1 - dummy_lora_request = LoRARequest( - lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_local_path="/not/a/real/path", - ) - self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=LORA_WARMUP_RANK) - dummy_lora_requests.append(dummy_lora_request) - dummy_lora_requests_per_seq = [ - dummy_lora_requests[idx % len(dummy_lora_requests)] - for idx in range(max_num_seqs) - ] + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. @@ -929,10 +916,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: Since it is used for decoding-only, it assumes there's only 1 token per sequence in the batch. """ - # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never - # deleted before the CUDA graphs. - self.pynccl_backend = pynccl_utils.get_nccl_backend() - assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " "unexpected consequences if the model is not static. To " @@ -1058,7 +1041,7 @@ def capture( # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with _maybe_pynccl(): + with graph_capture_mode(): self.model( input_ids, positions, @@ -1073,7 +1056,7 @@ def capture( # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117 - with _maybe_pynccl(): + with graph_capture_mode(): hidden_states = self.model( input_ids, positions, @@ -1125,16 +1108,6 @@ def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) -@contextlib.contextmanager -def _maybe_pynccl(): - if pynccl_utils.is_initialized( - ) and not custom_all_reduce.is_initialized(): - with with_pynccl_for_all_reduce(): - yield - else: - yield - - def _get_graph_batch_size(batch_size: int) -> int: """Returns the padded batch size given actual batch size. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 4add36e94f723..0ca9c2b64cf30 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -11,9 +11,7 @@ VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, - get_tensor_model_parallel_cpu_group, init_distributed_environment) -from vllm.distributed.device_communicators import pynccl_utils from vllm.distributed.device_communicators.custom_all_reduce import ( init_custom_ar) from vllm.lora.request import LoRARequest @@ -75,6 +73,7 @@ def __init__( parallel_config, scheduler_config, device_config, + cache_config, load_config=load_config, lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, @@ -184,7 +183,6 @@ def _init_cache_engine(self): self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache - self.model_runner.set_block_size(self.cache_engine.block_size) def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: @@ -195,17 +193,16 @@ def _warm_up_model(self) -> None: def cache_swap( self, - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + blocks_to_swap_in: torch.Tensor, + blocks_to_swap_out: torch.Tensor, + blocks_to_copy: torch.Tensor, ) -> None: # Issue cache operations. - # TODO(woosuk): Profile swapping overhead and optimize if needed. - if blocks_to_swap_in: + if blocks_to_swap_in.numel() > 0: self.cache_engine.swap_in(blocks_to_swap_in) - if blocks_to_swap_out: + if blocks_to_swap_out.numel() > 0: self.cache_engine.swap_out(blocks_to_swap_out) - if blocks_to_copy: + if blocks_to_copy.numel() > 0: self.cache_engine.copy(blocks_to_copy) @torch.inference_mode() @@ -219,13 +216,29 @@ def execute_model( else: seq_group_metadata_list = execute_model_req.seq_group_metadata_list + blocks_to_swap_in: torch.Tensor + blocks_to_swap_out: torch.Tensor + blocks_to_copy: torch.Tensor if self.is_driver_worker: assert seq_group_metadata_list is not None assert execute_model_req is not None num_seq_groups = len(seq_group_metadata_list) - blocks_to_swap_in = execute_model_req.blocks_to_swap_in - blocks_to_swap_out = execute_model_req.blocks_to_swap_out - blocks_to_copy = execute_model_req.blocks_to_copy + # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. + # they contain parameters to launch cudamemcpyasync. + blocks_to_swap_in = torch.tensor( + execute_model_req.blocks_to_swap_in, + device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_out = torch.tensor( + execute_model_req.blocks_to_swap_out, + device="cpu", + dtype=torch.int64).view(-1, 2) + # `blocks_to_copy` is a gpu tensor. The src and tgt of + # blocks to copy are in the same device, and `blocks_to_copy` + # can be used directly within cuda kernels. + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device=self.device, + dtype=torch.int64).view(-1, 2) data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, "blocks_to_swap_in": blocks_to_swap_in, @@ -291,29 +304,10 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - if pynccl_utils.is_initialized(): - pynccl_world_size = pynccl_utils.get_world_size() - if pynccl_world_size != parallel_config.world_size: - raise RuntimeError( - "pynccl is already initialized but the pynccl world " - "size does not match parallel_config.world_size " - f"({pynccl_world_size} vs. {parallel_config.world_size}).") - elif parallel_config.world_size > 1: - # NOTE(woosuk): We don't initialize pynccl process group when world size - # is 1. - # NOTE(kaichao): By default, pynccl is initialized for tp group. - pynccl_utils.init_process_group( - group=get_tensor_model_parallel_cpu_group()) - # Initialize a custom fast all-reduce implementation. if not parallel_config.disable_custom_all_reduce: init_custom_ar() - # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cuda()) - if pynccl_utils.is_initialized(): - pynccl_utils.all_reduce(torch.zeros(1).cuda()) - def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype.