From 0ad78d20a57d87a12cec9dcad2f6ff8dea1895c2 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Mon, 5 Aug 2024 10:12:46 +0000 Subject: [PATCH 01/19] style --- Dockerfile_amd | 2 +- server/Makefile-flash-att-v2 | 2 +- server/Makefile-vllm | 4 ++-- .../layers/attention/rocm.py | 16 +++++++++++----- .../custom_modeling/flash_llama_modeling.py | 4 ++++ 5 files changed, 19 insertions(+), 9 deletions(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index 51231638c95..514891a89ee 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -92,7 +92,7 @@ RUN chmod +x ~/mambaforge.sh && \ # Install flash-attention, torch dependencies RUN pip install numpy einops ninja --no-cache-dir -RUN conda install intel::mkl-static intel::mkl-include +RUN conda install mkl-static mkl-include RUN pip uninstall -y triton && \ git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \ cd triton/python && \ diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index dbddd0f41b0..035273292db 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,5 +1,5 @@ flash_att_v2_commit_cuda := v2.6.1 -flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 +flash_att_v2_commit_rocm := d83c4129a92e4258081f92dfafd34345b3b06130 build-flash-attention-v2-cuda: pip install -U packaging wheel diff --git a/server/Makefile-vllm b/server/Makefile-vllm index f1f805290e2..bf4a1498659 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,5 +1,5 @@ commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b -commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921 +commit_rocm := c06ccbf90a213688a2c6a85d2e7af3da7bc4b41b build-vllm-cuda: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ @@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda build-vllm-rocm: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ - git clone https://github.com/fxmarty/rocm-vllm.git vllm; \ + git clone https://github.com/mht-sharma/vllm.git vllm; \ fi cd vllm && git fetch && git checkout $(commit_rocm) && \ PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 69e641629a8..77ba4c9250f 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -14,7 +14,7 @@ ENGINE = "triton" if use_triton else "ck" try: - from vllm._C import cache_ops + import vllm._custom_ops as ops except Exception as e: raise ImportError( f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" @@ -33,9 +33,7 @@ def reshape_and_cache( key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value else: - cache_ops.reshape_and_cache( - key, value, key_cache, value_cache, slots, "auto", 1.0 - ) + ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) def paged_attention( @@ -78,7 +76,7 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. 
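A standalone sketch of the V1/V2 dispatch heuristic described in the comment above; the helper name and the shapes are illustrative only, not taken from this patch:

_PARTITION_SIZE = 512  # matches the partition size used by the paged attention kernels here

def picks_v1(max_s: int, num_seqs: int, num_heads: int) -> bool:
    # V1 is preferred when everything fits in one partition, or when there are
    # already enough (sequence, head) pairs to keep the GPU busy without the
    # extra reduction step of V2.
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    return max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)

print(picks_v1(512, 1, 32))    # True: one partition, V1 avoids the reduction overhead
print(picks_v1(1024, 4, 32))   # False: 2 partitions and only 128 (seq, head) pairs, so V2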
- from vllm._C import ops + import vllm._custom_ops as ops use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) if use_v1: @@ -180,6 +178,7 @@ def attention( softmax_scale, window_size_left=-1, causal=True, + softcap=0.0, ): if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") @@ -194,12 +193,19 @@ def attention( out, cu_seqlens, cu_seqlens, + None, + None, + None, + None, max_s, max_s, 0.0, softmax_scale, False, causal, + window_size_left, + 0, + softcap, False, None, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 9ea19a87dfd..56d88956f8f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -313,11 +313,15 @@ def __init__(self, prefix, config, weights, index): # TODO: This is a hotfix to be removed & properly refactored. self.quantize = config.quantize + self.hidden_size = config.hidden_size + def forward(self, hidden_states, adapter_data): if ( SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1 + and self.hidden_size + != 16384 # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed. and not self.quantize ): out = torch.empty( From 55e6059eb10b13eb1c02849b69d73237fa41dcca Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Mon, 5 Aug 2024 12:47:21 +0000 Subject: [PATCH 02/19] update torch --- Dockerfile_amd | 38 ++++++++++--------- .../models/flash_causal_lm.py | 3 +- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index 514891a89ee..399bc869550 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -98,24 +98,26 @@ RUN pip uninstall -y triton && \ cd triton/python && \ pip install . 
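The hunk below swaps the patched PyTorch source build for the nightly ROCm wheels. A quick sanity check of the resulting environment, as an illustrative sketch rather than part of the Dockerfile:

import torch

print(torch.__version__)          # ROCm nightlies carry a +rocm6.x local version suffix
print(torch.version.hip)          # set only on ROCm builds, None on CUDA builds
print(torch.cuda.is_available())  # ROCm GPUs are exposed through the torch.cuda API
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))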
-RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir - -ARG _GLIBCXX_USE_CXX11_ABI="1" -ARG CMAKE_PREFIX_PATH="/opt/conda" -ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" -ARG BUILD_CAFFE2="0" \ - BUILD_CAFFE2_OPS="0" \ - USE_CUDA="0" \ - USE_ROCM="1" \ - BUILD_TEST="0" \ - USE_FBGEMM="0" \ - USE_NNPACK="0" \ - USE_QNNPACK="0" \ - USE_XNNPACK="0" \ - USE_FLASH_ATTENTION="1" \ - USE_MEM_EFF_ATTENTION="0" - -RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install +# RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir + +# ARG _GLIBCXX_USE_CXX11_ABI="1" +# ARG CMAKE_PREFIX_PATH="/opt/conda" +# ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" +# ARG BUILD_CAFFE2="0" \ +# BUILD_CAFFE2_OPS="0" \ +# USE_CUDA="0" \ +# USE_ROCM="1" \ +# BUILD_TEST="0" \ +# USE_FBGEMM="0" \ +# USE_NNPACK="0" \ +# USE_QNNPACK="0" \ +# USE_XNNPACK="0" \ +# USE_FLASH_ATTENTION="1" \ +# USE_MEM_EFF_ATTENTION="0" + +# RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install + +RUN pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1 # Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm ENV HIP_FORCE_DEV_KERNARG=1 diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 36bb26621f8..9b8704478ab 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1150,8 +1150,7 @@ def warmup(self, batch: FlashCausalLMBatch): elif CUDA_GRAPHS is not None: tuning_sequences = CUDA_GRAPHS else: - # For seqlen = 1, we dispatch to LLMM1 kernel. - tuning_sequences = [2, 3, 4, 5, 6, 7] + tuning_sequences = [1, 2, 3, 4, 5, 6, 7] tunableop_filepath = os.path.join( HUGGINGFACE_HUB_CACHE, From 5788c942a53e98bfb3c63d04d9d111c1c53465ed Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 6 Aug 2024 10:29:46 +0000 Subject: [PATCH 03/19] ix issues --- Dockerfile_amd | 39 +++++++++---------- .../layers/attention/rocm.py | 2 +- .../models/flash_causal_lm.py | 2 +- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index 399bc869550..efc80234b09 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -98,26 +98,25 @@ RUN pip uninstall -y triton && \ cd triton/python && \ pip install . 
-# RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir - -# ARG _GLIBCXX_USE_CXX11_ABI="1" -# ARG CMAKE_PREFIX_PATH="/opt/conda" -# ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" -# ARG BUILD_CAFFE2="0" \ -# BUILD_CAFFE2_OPS="0" \ -# USE_CUDA="0" \ -# USE_ROCM="1" \ -# BUILD_TEST="0" \ -# USE_FBGEMM="0" \ -# USE_NNPACK="0" \ -# USE_QNNPACK="0" \ -# USE_XNNPACK="0" \ -# USE_FLASH_ATTENTION="1" \ -# USE_MEM_EFF_ATTENTION="0" - -# RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install - -RUN pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1 +RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir && \ + git checkout da320214e66b5af0f7db8fd18a64dbb519d17b27 + +ARG _GLIBCXX_USE_CXX11_ABI="1" +ARG CMAKE_PREFIX_PATH="/opt/conda" +ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" +ARG BUILD_CAFFE2="0" \ + BUILD_CAFFE2_OPS="0" \ + USE_CUDA="0" \ + USE_ROCM="1" \ + BUILD_TEST="0" \ + USE_FBGEMM="0" \ + USE_NNPACK="0" \ + USE_QNNPACK="0" \ + USE_XNNPACK="0" \ + USE_FLASH_ATTENTION="1" \ + USE_MEM_EFF_ATTENTION="0" + +RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install # Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm ENV HIP_FORCE_DEV_KERNARG=1 diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 77ba4c9250f..da8a4bcd259 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -208,7 +208,7 @@ def attention( softcap, False, None, - ) + )[0] elif ENGINE == "triton": from .flash_attn_triton import triton_attention diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 9b8704478ab..174bba65f93 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1159,7 +1159,7 @@ def warmup(self, batch: FlashCausalLMBatch): log_master( logger.info, - f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", + f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. 
To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", ) if os.path.isfile(tunableop_filepath): From d61f7e63fa1e940387ae52f6265d803a71f98d8d Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 6 Aug 2024 12:39:49 +0000 Subject: [PATCH 04/19] fix clone --- Dockerfile_amd | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index efc80234b09..129a58fcfae 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -92,14 +92,17 @@ RUN chmod +x ~/mambaforge.sh && \ # Install flash-attention, torch dependencies RUN pip install numpy einops ninja --no-cache-dir -RUN conda install mkl-static mkl-include +RUN conda install intel::mkl-static intel::mkl-include RUN pip uninstall -y triton && \ git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \ cd triton/python && \ pip install . -RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir && \ - git checkout da320214e66b5af0f7db8fd18a64dbb519d17b27 +RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \ + cd pytorch && git fetch --depth 1 origin da320214e66b5af0f7db8fd18a64dbb519d17b27 && \ + git checkout da320214e66b5af0f7db8fd18a64dbb519d17b27 && \ + pip install -r requirements.txt --no-cache-dir + ARG _GLIBCXX_USE_CXX11_ABI="1" ARG CMAKE_PREFIX_PATH="/opt/conda" From e5578555582e3cc44d871ca631d1b9e8e7d7eeb2 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 6 Aug 2024 12:46:35 +0000 Subject: [PATCH 05/19] revert mkl --- Dockerfile_amd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index 129a58fcfae..d8bb8c477f7 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -92,7 +92,7 @@ RUN chmod +x ~/mambaforge.sh && \ # Install flash-attention, torch dependencies RUN pip install numpy einops ninja --no-cache-dir -RUN conda install intel::mkl-static intel::mkl-include +RUN conda install mkl-static mkl-include RUN pip uninstall -y triton && \ git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \ cd triton/python && \ From ff0505e7f967b063de3babd70d94a278e7dfed49 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Wed, 4 Sep 2024 05:46:28 +0000 Subject: [PATCH 06/19] added custom PA --- Dockerfile_amd | 46 +++++++--- server/Makefile-flash-att-v2 | 2 +- server/Makefile-vllm | 2 +- .../layers/attention/cuda.py | 1 + .../layers/attention/ipex.py | 1 + .../layers/attention/rocm.py | 87 ++++++++++++++----- .../custom_modeling/flash_cohere_modeling.py | 1 + .../custom_modeling/flash_dbrx_modeling.py | 1 + .../flash_deepseek_v2_modeling.py | 2 + .../custom_modeling/flash_gemma2_modeling.py | 1 + .../custom_modeling/flash_gemma_modeling.py | 1 + .../custom_modeling/flash_gpt2_modeling.py | 1 + .../custom_modeling/flash_llama_modeling.py | 3 + .../custom_modeling/flash_mistral_modeling.py | 2 + .../custom_modeling/flash_mixtral_modeling.py | 1 + .../custom_modeling/flash_neox_modeling.py | 1 + .../custom_modeling/flash_phi_modeling.py | 1 + .../custom_modeling/flash_qwen2_modeling.py | 1 + .../custom_modeling/flash_rw_modeling.py | 2 + .../flash_santacoder_modeling.py | 1 + .../flash_starcoder2_modeling.py | 1 + .../models/flash_causal_lm.py | 4 +- 22 files changed, 128 insertions(+), 35 deletions(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index d8bb8c477f7..fd612af588a 100644 --- a/Dockerfile_amd +++ 
b/Dockerfile_amd @@ -39,7 +39,7 @@ COPY launcher launcher RUN cargo build --profile release-opt # Text Generation Inference base image for RoCm -FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base +FROM rocm/dev-ubuntu-22.04:6.2 AS base RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ build-essential \ @@ -48,23 +48,25 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins curl \ git \ make \ + libmsgpack-dev \ libssl-dev \ + llvm-dev \ g++ \ # Needed to build VLLM & flash. rocthrust-dev \ hipsparse-dev \ hipblas-dev \ - hipblaslt-dev \ + hipcub-dev \ rocblas-dev \ hiprand-dev \ + hipfft-dev \ rocrand-dev \ miopen-hip-dev \ - hipfft-dev \ - hipcub-dev \ hipsolver-dev \ rccl-dev \ cmake \ - python3-dev && \ + python3-dev \ + python3-venv && \ rm -rf /var/lib/apt/lists/* # Keep in sync with `server/pyproject.toml @@ -74,7 +76,30 @@ ARG ROCM_VERSION='6.0.2' ARG PYTHON_VERSION='3.10.10' # Automatically set by buildx ARG TARGETPLATFORM -ENV PATH /opt/conda/bin:$PATH +ENV PATH=/opt/conda/bin:$PATH + +ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" + +RUN curl -fsSL -v -o cmake-3.30.2-linux-x86_64.sh https://github.com/Kitware/CMake/releases/download/v3.30.2/cmake-3.30.2-linux-x86_64.sh \ + && chmod +x cmake-3.30.2-linux-x86_64.sh \ + && ./cmake-3.30.2-linux-x86_64.sh --skip-license --prefix=/usr/local \ + && rm cmake-3.30.2-linux-x86_64.sh + +RUN pip install joblib msgpack + +# Install HIPBLASLt +ARG HIPBLASLT_BRANCH="6f65c6e" +RUN git clone https://github.com/ROCm/hipBLASLt \ + && cd hipBLASLt \ + && git checkout ${HIPBLASLT_BRANCH} \ + && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \ + && cd build/release \ + && make package +RUN dpkg -i hipBLASLt/build/release/*.deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; + # && cd .. \ + # && rm -rf hipBLASLt # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. # Install mamba @@ -98,15 +123,15 @@ RUN pip uninstall -y triton && \ cd triton/python && \ pip install . 
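Before the pinned PyTorch build below, it can be useful to confirm that the hipBLASLt package built above actually installed a loadable library. An illustrative check, assuming the usual libhipblaslt soname (nothing in this snippet is defined by the patch):

import ctypes
import ctypes.util

name = ctypes.util.find_library("hipblaslt")
print(name or "libhipblaslt not found on the linker path")
if name:
    ctypes.CDLL(name)  # raises OSError if the library exists but cannot be loaded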
+ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27" RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \ - cd pytorch && git fetch --depth 1 origin da320214e66b5af0f7db8fd18a64dbb519d17b27 && \ - git checkout da320214e66b5af0f7db8fd18a64dbb519d17b27 && \ + cd pytorch && git fetch --depth 1 origin ${PYTORCH_COMMIT} && \ + git checkout ${PYTORCH_COMMIT} && \ + git submodule update --init --recursive && \ pip install -r requirements.txt --no-cache-dir - ARG _GLIBCXX_USE_CXX11_ABI="1" ARG CMAKE_PREFIX_PATH="/opt/conda" -ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" ARG BUILD_CAFFE2="0" \ BUILD_CAFFE2_OPS="0" \ USE_CUDA="0" \ @@ -221,4 +246,3 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh ENTRYPOINT ["/tgi-entrypoint.sh"] -CMD ["--json-output"] diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 035273292db..74293d9cc96 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,5 +1,5 @@ flash_att_v2_commit_cuda := v2.6.1 -flash_att_v2_commit_rocm := d83c4129a92e4258081f92dfafd34345b3b06130 +flash_att_v2_commit_rocm := 3cea2fb6ee54fb7e1aad9db6ac6c9331184b8647 # (Aug28) build-flash-attention-v2-cuda: pip install -U packaging wheel diff --git a/server/Makefile-vllm b/server/Makefile-vllm index bf4a1498659..18dcc4a0c53 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,5 +1,5 @@ commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b -commit_rocm := c06ccbf90a213688a2c6a85d2e7af3da7bc4b41b +commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247 build-vllm-cuda: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index dff742dc105..c623b7f9c99 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -42,6 +42,7 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + num_kv_heads: int, softcap: Optional[float] = None, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index e0956b26c44..f7aada34f76 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -58,6 +58,7 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + num_kv_heads: int, ): out = torch.empty_like(query) ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index da8a4bcd259..bd03301784b 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -8,11 +8,17 @@ major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 -_PARTITION_SIZE = 512 + +_PARTITION_SIZE_V1V2 = 512 +_PARTITION_SIZE_CUSTOM = 256 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} ENGINE = "triton" if use_triton else "ck" +custom_attn_available = os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0" +if custom_attn_available: + from vllm._custom_C import 
paged_attention_custom + try: import vllm._custom_ops as ops except Exception as e: @@ -45,6 +51,7 @@ def paged_attention( block_tables: torch.Tensor, input_lengths: Seqlen, max_s: int, + num_kv_heads: int, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Copyright 2023 The vLLM team. All rights @@ -66,6 +73,22 @@ def paged_attention( # value_cache => [num_blocks, num_heads, head_size, block_size] block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape + + gqa_ratio = num_heads // num_kv_heads + use_custom = ( + custom_attn_available + and (query.dtype == torch.half or query.dtype == torch.bfloat16) + and (head_size == 128 or head_size == 64) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_s <= 32768 + ) + + if not use_custom: + _PARTITION_SIZE = _PARTITION_SIZE_V1V2 + else: + _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM + max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE input_lengths = input_lengths.input_lengths @@ -78,7 +101,11 @@ def paged_attention( # to parallelize. import vllm._custom_ops as ops - use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) + use_v1 = ( + max_s <= 8192 + and (max_num_partitions == 1 or num_seqs * num_heads > 512) + and not use_custom + ) if use_v1: ops.paged_attention_v1( out, @@ -110,24 +137,44 @@ def paged_attention( ) max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) + if not use_custom: + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + paged_attention_custom( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + ) + return out diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index e02a31d9aac..460228548ef 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -313,6 +313,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index d3d1d1efc38..bc9e8f15cbe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -352,6 +352,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 0905d3c298e..d02d6cd385f 100644 --- 
a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -380,6 +380,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) # Remove padding. @@ -424,6 +425,7 @@ def __init__(self, prefix: str, config, weights, intermediate_size: int): def forward(self, hidden_states: torch.Tensor, reduce: bool = True): if ( SYSTEM == "rocm" + and hidden_states.dtype == torch.float16 and self.hidden_act == "silu" and hidden_states.shape[0] == 1 and not self.quantize diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index de86f51490e..76d8e8ba793 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -256,6 +256,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, softcap=self.softcap, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 178efadbe70..42ae24f31d8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -248,6 +248,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index a19cff8cccb..8c9dc6d6287 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -247,6 +247,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 56d88956f8f..9c4cb64ac6f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -235,6 +235,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj( @@ -318,6 +319,7 @@ def __init__(self, prefix, config, weights, index): def forward(self, hidden_states, adapter_data): if ( SYSTEM == "rocm" + and hidden_states.dtype == torch.float16 and self.hidden_act == "silu" and hidden_states.shape[0] == 1 and self.hidden_size @@ -557,6 +559,7 @@ def forward( adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.embed_tokens(input_ids) + hidden_states = self.model( inputs_embeds, position_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index dda53ff3dd8..e4ba02951d4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -235,6 +235,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj( @@ -300,6 +301,7 @@ def __init__(self, prefix: str, config, weights, layer_id): def forward(self, hidden_states, adapter_data): if ( SYSTEM == "rocm" + and hidden_states.dtype == torch.float16 and self.hidden_act == "silu" and hidden_states.shape[0] == 1 and not self.quantize diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 85431c6c9f7..c84c99c65c5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -292,6 +292,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b1b03ad755a..103c84a7b96 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -180,6 +180,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index a9e18348093..e067e1d89ad 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -209,6 +209,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 865cc85de43..8b3b8322599 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -153,6 +153,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 708641e7f48..db4c7e7eb5e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -223,6 +223,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -341,6 +342,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.dense( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index c267678226e..2d92f3ff902 100644 --- 
a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -308,6 +308,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index e562eb89662..534a3792c14 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -258,6 +258,7 @@ def forward( block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 174bba65f93..35388f4979e 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1007,12 +1007,12 @@ def init_kv_cache( else: self.kv_cache = [ ( - torch.empty( + torch.zeros( (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), dtype=dtype, device=device, ), - torch.empty( + torch.zeros( (num_blocks, num_heads, head_size, BLOCK_SIZE), dtype=dtype, device=device, From 88e2997b9c47b725dcb48d68b5c8ea448b45d591 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 6 Sep 2024 12:23:18 +0000 Subject: [PATCH 07/19] style --- Dockerfile_amd | 2 ++ docs/source/installation_amd.md | 6 ++++++ server/text_generation_server/layers/attention/rocm.py | 2 +- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index fd612af588a..90a1ddf5e1f 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -152,6 +152,7 @@ ENV HIP_FORCE_DEV_KERNARG=1 # On MI250 and MI300, performances for flash with Triton FA are slightly better than CK. # However, Triton requires a tunning for each prompt length, which is prohibitive. ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 +ENV ROCM_USE_CUSTOM_PAGED_ATTN=1 FROM base AS kernel-builder @@ -246,3 +247,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh ENTRYPOINT ["/tgi-entrypoint.sh"] +CMD ["--json-output"] diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md index 931a9e3adc4..8bf608306f1 100644 --- a/docs/source/installation_amd.md +++ b/docs/source/installation_amd.md @@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, FA Triton impelmentation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container. +## Custom PagedAttention + +For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`. + +The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. 
For other configurations, we use the PagedAttention v2 kernel. + ## Unsupported features The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future: diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index bd03301784b..58165dc75d5 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -15,7 +15,7 @@ use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} ENGINE = "triton" if use_triton else "ck" -custom_attn_available = os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0" +custom_attn_available = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" if custom_attn_available: from vllm._custom_C import paged_attention_custom From 3f2dc6150009b74c71d3cfeb6ea9d12bbc2f9f7c Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Mon, 9 Sep 2024 10:13:59 +0000 Subject: [PATCH 08/19] fix style --- Dockerfile_amd | 1 + docs/source/installation_amd.md | 4 ++++ server/text_generation_server/layers/linear.py | 6 +++++- server/text_generation_server/models/flash_causal_lm.py | 3 ++- 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index 90a1ddf5e1f..f01b160d541 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -153,6 +153,7 @@ ENV HIP_FORCE_DEV_KERNARG=1 # However, Triton requires a tunning for each prompt length, which is prohibitive. ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 ENV ROCM_USE_CUSTOM_PAGED_ATTN=1 +ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0 FROM base AS kernel-builder diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md index 8bf608306f1..070e268eb1e 100644 --- a/docs/source/installation_amd.md +++ b/docs/source/installation_amd.md @@ -25,6 +25,10 @@ Experimentally, on MI300X, we noticed a 6-8% latency improvement when using Tuna TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launcher TGI's docker container. +TunableOps tuning is disabled by default after the warmup phase. If you wish to keep tuning enabled for the entire run, set the environment variable `PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=1`. + +Note: With tuning enabled, every time a new input shape is encountered, tuning will be performed, which can slow down the first inference for that shape. + ## Flash attention implementation Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/layers/attention/flash_attn_triton.py). 
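The choice between the two engines is made once at import time in `server/text_generation_server/layers/attention/rocm.py`; condensed from the hunk shown earlier in this series:

import os

# Triton flash attention is selected only when the variable is explicitly truthy
# ("true" or "1"); the default remains the Composable Kernel ("ck") path.
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
print(ENGINE)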
diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index 12d7f83aaf1..78815d74468 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -61,7 +61,11 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: weight = self.weight bias = self.bias - if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1: + if ( + SYSTEM == "rocm" + and inp.numel() // inp.shape[-1] == 1 + and inp.dtype == torch.float16 + ): batched = False inp_shape = inp.shape diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 35388f4979e..7a3f57abac3 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1175,7 +1175,8 @@ def warmup(self, batch: FlashCausalLMBatch): log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") self.tunableop_warmup(seqlen) torch.cuda.tunable.write_file(tunableop_filepath) - torch.cuda.tunable.tuning_enable(False) + if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1": + torch.cuda.tunable.tuning_enable(False) else: log_master( logger.info, From f3bc03843003f7ad8f1aa8e85bd29b85264e2b2a Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Wed, 11 Sep 2024 06:52:30 +0000 Subject: [PATCH 09/19] style --- server/text_generation_server/models/flash_causal_lm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 7a3f57abac3..bb8560783c7 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1162,6 +1162,10 @@ def warmup(self, batch: FlashCausalLMBatch): f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", ) + torch.cuda.tunable.set_filename( + tunableop_filepath, insert_device_ordinal=False + ) + if os.path.isfile(tunableop_filepath): log_master( logger.info, From e2f48fae3d32c5b50d75edba271ebe0182dee979 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Wed, 11 Sep 2024 07:00:29 +0000 Subject: [PATCH 10/19] hide env vart --- docs/source/installation_amd.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md index 070e268eb1e..8bf608306f1 100644 --- a/docs/source/installation_amd.md +++ b/docs/source/installation_amd.md @@ -25,10 +25,6 @@ Experimentally, on MI300X, we noticed a 6-8% latency improvement when using Tuna TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launcher TGI's docker container. -TunableOps tuning is disabled by default after the warmup phase. If you wish to keep tuning enabled for the entire run, set the environment variable `PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=1`. - -Note: With tuning enabled, every time a new input shape is encountered, tuning will be performed, which can slow down the first inference for that shape. 
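The server-side behaviour that the removed paragraphs above described is unchanged and lives in the flash_causal_lm.py hunks earlier in this series. Roughly, the warmup drives TunableOp as in this sketch; `run_one_gemm` and the file name are placeholders, while the `torch.cuda.tunable` calls are the ones used in the patch:

import os
import torch

def run_one_gemm(seqlen: int) -> None:
    # Placeholder for self.tunableop_warmup(seqlen): any GEMM executed while
    # tuning is enabled gets tuned and its best kernel recorded.
    a = torch.randn(seqlen, 1024, device="cuda", dtype=torch.float16)
    b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    torch.matmul(a, b)

def warmup_tunableop(tuning_sequences, tunableop_filepath="tunableop_results.csv"):
    torch.cuda.tunable.enable()
    torch.cuda.tunable.set_filename(tunableop_filepath, insert_device_ordinal=False)
    for seqlen in tuning_sequences:
        run_one_gemm(seqlen)
    torch.cuda.tunable.write_file(tunableop_filepath)
    # Tuning keeps running after warmup only when explicitly requested.
    if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
        torch.cuda.tunable.tuning_enable(False)

if torch.cuda.is_available():
    warmup_tunableop([1, 2, 3, 4, 5, 6, 7])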
- ## Flash attention implementation Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/layers/attention/flash_attn_triton.py). From 0345816477b00be417c602105a950d7482f734c1 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Wed, 11 Sep 2024 10:52:10 +0000 Subject: [PATCH 11/19] fix mixtral model --- Dockerfile_amd | 1 + 1 file changed, 1 insertion(+) diff --git a/Dockerfile_amd b/Dockerfile_amd index f01b160d541..d8f16e7ebc0 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -154,6 +154,7 @@ ENV HIP_FORCE_DEV_KERNARG=1 ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 ENV ROCM_USE_CUSTOM_PAGED_ATTN=1 ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0 +ENV VLLM_MOE_PADDING=0 FROM base AS kernel-builder From 59fd0cbdff68c20d954aafdecd419b4152f34e5e Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Thu, 12 Sep 2024 13:16:13 +0000 Subject: [PATCH 12/19] add skinny kernel and merge fixes --- Dockerfile_amd | 10 ++-- .../layers/attention/cuda.py | 1 - .../layers/attention/ipex.py | 1 - .../layers/attention/rocm.py | 47 ++++++++++--------- .../text_generation_server/layers/linear.py | 44 ++++++++++++----- .../custom_modeling/flash_cohere_modeling.py | 6 +-- .../custom_modeling/flash_dbrx_modeling.py | 6 +-- .../flash_deepseek_v2_modeling.py | 6 +-- .../custom_modeling/flash_gemma2_modeling.py | 7 ++- .../custom_modeling/flash_gemma_modeling.py | 7 ++- .../custom_modeling/flash_gpt2_modeling.py | 7 ++- .../custom_modeling/flash_gptj_modeling.py | 5 +- .../custom_modeling/flash_llama_modeling.py | 6 +-- .../custom_modeling/flash_mistral_modeling.py | 6 +-- .../custom_modeling/flash_mixtral_modeling.py | 6 +-- .../custom_modeling/flash_neox_modeling.py | 7 ++- .../custom_modeling/flash_phi_modeling.py | 7 ++- .../custom_modeling/flash_qwen2_modeling.py | 7 ++- .../custom_modeling/flash_rw_modeling.py | 12 ++--- .../flash_santacoder_modeling.py | 7 ++- .../flash_starcoder2_modeling.py | 7 ++- .../models/flash_causal_lm.py | 3 +- .../text_generation_server/models/globals.py | 7 +++ 23 files changed, 121 insertions(+), 101 deletions(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index 1940b985192..2aa2a6bc44f 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -152,9 +152,6 @@ ENV HIP_FORCE_DEV_KERNARG=1 # On MI250 and MI300, performances for flash with Triton FA are slightly better than CK. # However, Triton requires a tunning for each prompt length, which is prohibitive. 
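`ROCM_USE_CUSTOM_PAGED_ATTN`, moved to the final image below, only makes the custom kernel eligible; decode still falls back to PagedAttention v1/v2 when a shape is unsupported. Condensed from the rocm.py hunk added earlier in this series (the wrapper function is illustrative):

import os
import torch

custom_attn_available = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"

def use_custom_paged_attention(dtype, head_size, block_size, gqa_ratio, max_s) -> bool:
    # Any failed condition sends the request down the stock v1/v2 kernels instead.
    return (
        custom_attn_available
        and dtype in (torch.half, torch.bfloat16)
        and head_size in (64, 128)
        and block_size in (16, 32)
        and 1 <= gqa_ratio <= 16
        and max_s <= 32768
    )

print(use_custom_paged_attention(torch.half, 128, 16, gqa_ratio=4, max_s=4096))  # True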
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 -ENV ROCM_USE_CUSTOM_PAGED_ATTN=1 -ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0 -ENV VLLM_MOE_PADDING=0 FROM base AS kernel-builder @@ -245,6 +242,13 @@ ENTRYPOINT ["./entrypoint.sh"] # Final image FROM base-copy +ENV ROCM_USE_CUSTOM_PAGED_ATTN=1 +ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0 +ENV VLLM_MOE_PADDING=0 +ENV ATTENTION=paged +ENV USE_PREFIX_CACHING=0 +ENV ROCM_USE_SKINNY_GEMM=1 + COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 592350f4c6e..4b588b5cf40 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -45,7 +45,6 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, - num_kv_heads: int, softcap: Optional[float] = None, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 83254598a3e..2d1427ae672 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -62,7 +62,6 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, - num_kv_heads: int, softcap: Optional[float] = None, ): out = torch.empty_like(query) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 3e003acbde0..0835cb97264 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -50,9 +50,8 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: Seqlen, + seqlen: Seqlen, max_s: int, - num_kv_heads: int, softcap: Optional[float] = None, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -76,6 +75,7 @@ def paged_attention( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape + num_kv_heads = key_cache.shape[1] gqa_ratio = num_heads // num_kv_heads use_custom = ( custom_attn_available @@ -92,7 +92,7 @@ def paged_attention( _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - input_lengths = input_lengths.input_lengths + input_lengths = seqlen.input_lengths out = torch.empty_like(query) @@ -220,10 +220,10 @@ def paged_attention( def attention( q, - k, - v, - cu_seqlens, - max_s, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, softmax_scale, window_size_left=-1, causal=True, @@ -237,17 +237,17 @@ def attention( # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. 
return flash_attn_2_cuda.varlen_fwd( q, - k, - v, + key_cache, + value_cache, out, - cu_seqlens, - cu_seqlens, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, None, None, None, None, - max_s, - max_s, + seqlen.max_q, + seqlen.max_k, 0.0, softmax_scale, False, @@ -264,26 +264,27 @@ def attention( def attention( q, - k, - v, - cu_seqlens, - max_s, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, softmax_scale, window_size_left=-1, causal=True, + softcap=0.0, ): out = torch.empty_like(q) # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. output, _ = triton_attention( q, - k, - v, + key_cache, + value_cache, out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + seqlen.max_q, + seqlen.max_k, causal, softmax_scale, ) diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index 78815d74468..69b6294bbb2 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -1,12 +1,19 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from torch.nn import functional as F +import os if SYSTEM == "rocm": - try: - from vllm import _custom_C - except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") + ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in ( + "true", + "1", + ) + + if ROCM_USE_SKINNY_GEMM: + try: + from vllm import _custom_C + except Exception as e: + raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") class FastLinear(torch.nn.Module): @@ -48,6 +55,14 @@ def __init__( else: self.bias = None + self.cu_count = torch.cuda.get_device_properties( + device="cuda" + ).multi_processor_count + self.use_skinny_gemm = ( + ROCM_USE_SKINNY_GEMM + and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName + ) + @classmethod def load(cls, config, prefix: str, weights, bias: bool): weight = weights.get_tensor(f"{prefix}.weight") @@ -62,9 +77,9 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: bias = self.bias if ( - SYSTEM == "rocm" - and inp.numel() // inp.shape[-1] == 1 + self.use_skinny_gemm and inp.dtype == torch.float16 + and inp.shape[-1] % 8 == 0 ): batched = False inp_shape = inp.shape @@ -73,13 +88,16 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: inp = inp.view(-1, inp_shape[-1]) batched = True - m, k = weight.shape[0], inp_shape[1] - out = torch.empty( - inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda" - ) - if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192): - _custom_C.LLMM1(weight, inp, out, 8) - elif k <= 8192 and k % 8 == 0 and m % 4 == 0: + m, n, k = weight.shape[0], inp_shape[0], inp_shape[1] + if m > 8 and n <= 4: + out = torch.empty( + inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device + ) + _custom_C.wvSpltK(weight, inp, out, n, self.cu_count) + elif m % 4 == 0 and n == 1 and k <= 8192: + out = torch.empty( + inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device + ) _custom_C.LLMM1(weight, inp, out, 4) else: out = F.linear(inp, weight) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 8f6cba350f3..b0e57d68653 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py 
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -297,8 +298,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key, - kv_cache[1] if SYSTEM != "ipex" else value, + kv_cache[0] if PAGED_KV else key, + kv_cache[1] if PAGED_KV else value, seqlen, block_tables, self.softmax_scale, @@ -314,7 +315,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 478b5b16217..8bce4e573ac 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -336,8 +337,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PAGED_KV else kv[:, 0], + kv_cache[1] if PAGED_KV else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -353,7 +354,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 0aa948e751f..561363816f4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed from text_generation_server.layers import ( @@ -363,8 +364,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key, - kv_cache[1] if SYSTEM != "ipex" else value, + kv_cache[0] if PAGED_KV else key, + kv_cache[1] if PAGED_KV else value, seqlen, block_tables, self.softmax_scale, @@ -380,7 +381,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) # Remove padding. diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 6bd4aac5c6f..1ad88801b39 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
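The same two-line change is repeated across the modelling files in this patch: a module-level `PAGED_KV` flag picks whether prefill attention reads K/V from the paged cache that `reshape_and_cache` just filled, or from the freshly computed tensors. In sketch form (the helper is illustrative; the import and the selection expression are the ones added by this patch):

from text_generation_server.models.globals import PAGED_KV

def prefill_kv(kv_cache, kv):
    # kv_cache: (key_cache, value_cache) paged tensors
    # kv: contiguous tensor of shape [num_tokens, 2, num_kv_heads, head_size]
    key = kv_cache[0] if PAGED_KV else kv[:, 0]
    value = kv_cache[1] if PAGED_KV else kv[:, 1]
    return key, value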
+from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -25,7 +26,6 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -237,8 +237,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PAGED_KV else kv[:, 0], + kv_cache[1] if PAGED_KV else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -257,7 +257,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, softcap=self.softcap, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index c253e6abe8b..a401798a687 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -25,7 +26,6 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -231,8 +231,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PAGED_KV else kv[:, 0], + kv_cache[1] if PAGED_KV else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -249,7 +249,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 90382583eb4..33f20b9a3bb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -18,13 +18,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -231,8 +231,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key, - kv_cache[1] if SYSTEM != "ipex" else value, + kv_cache[0] if PAGED_KV else key, + kv_cache[1] if PAGED_KV else value, seqlen, block_tables, self.softmax_scale, @@ -248,7 +248,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index ef071d46d20..f2197069217 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -192,8 +193,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key, - kv_cache[1] if SYSTEM != "ipex" else value, + kv_cache[0] if PAGED_KV else key, + kv_cache[1] if PAGED_KV else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 39218531f4b..6be89297059 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -28,6 +28,7 @@ from transformers.activations import ACT2FN from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import ( paged_attention, attention, @@ -220,8 +221,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PAGED_KV else kv[:, 0], + kv_cache[1] if PAGED_KV else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -237,7 +238,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index dacda101838..3b56bbab0e3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -218,8 +219,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], + kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], + kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, @@ -236,7 +237,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index b35688ec87b..3451158bf5a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -275,8 +276,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], + kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], + kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, @@ -293,7 +294,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 698f0343e00..2d3be430be3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -26,7 +27,6 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -172,8 +172,8 @@ def forward( # flash attention attn_output = attention( qkv[:, 0], - kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1], - kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2], + kv_cache[0] if PAGED_KV else qkv[:, 1], + kv_cache[1] if PAGED_KV else qkv[:, 2], seqlen, block_tables, self.softmax_scale, @@ -189,7 +189,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 38e8c884118..76e406a7427 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -1,3 +1,4 @@ +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -25,7 +26,6 @@ from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) -from text_generation_server.utils.import_utils import SYSTEM class PhiConfig(PretrainedConfig): @@ -194,8 +194,8 @@ def forward( if cu_seqlen_prefill is not None: attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PAGED_KV else kv[:, 0], + kv_cache[1] if PAGED_KV else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -211,7 +211,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index d43401badc5..0f0dbf5ec92 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -1,3 +1,4 @@ +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -21,7 +22,6 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) -from text_generation_server.utils.import_utils import SYSTEM def load_attention(config, prefix, weights): @@ -137,8 +137,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], + kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], + kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, @@ -155,7 +155,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 765cf39ebfd..ba516881029 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -1,11 +1,11 @@ from typing import List, Optional, Tuple +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed from torch import nn from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import PreTrainedModel -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( SpeculativeHead, TensorParallelColumnLinear, @@ -207,8 +207,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PAGED_KV else kv[:, 0], + kv_cache[1] if PAGED_KV else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -224,7 +224,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -326,8 +325,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(), - kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(), + kv_cache[0] if PAGED_KV else kv[:, :, 0].contiguous(), + kv_cache[1] if PAGED_KV else kv[:, :, 1].contiguous(), seqlen, block_tables, self.softmax_scale, @@ -343,7 +342,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.dense( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 80c280c8735..fa074606678 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -1,3 +1,4 @@ +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -22,7 +23,6 @@ from text_generation_server.layers.layernorm import ( FastLayerNorm, ) -from text_generation_server.utils.import_utils import SYSTEM def load_multi_mqa( @@ -293,8 +293,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0], - kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1], + kv_cache[0] if PAGED_KV else key_value[:, 0], + kv_cache[1] if PAGED_KV else key_value[:, 1], seqlen, block_tables, self.softmax_scale, @@ -310,7 +310,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 0c4ce05ae84..30d35632485 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -47,7 +48,6 @@ PositionRotaryEmbedding, ) from text_generation_server.utils.weights import UnquantizedWeight -from text_generation_server.utils.import_utils import SYSTEM class Starcoder2Config(PretrainedConfig): @@ -242,8 +242,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], + kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], + kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, @@ -260,7 +260,6 @@ def forward( block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 6f21e84642c..c4bf8a57e9e 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1379,6 +1379,7 @@ def tunableop_warmup(self, seqlen: int): cu_seqlen_prefill = torch.tensor( [0, seqlen], device=self.device, dtype=torch.int32 ) + max_s = seqlen seqlen = Seqlen( input_lengths=input_lengths, prefix_lengths=prefix_lens_tensor, @@ -1396,7 +1397,7 @@ def tunableop_warmup(self, seqlen: int): block_tables=None, seqlen=seqlen, slots=slots, - max_s=seqlen, + max_s=max_s, lm_head_indices=None, prefill_cache_indices=None, ) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 6c518c2caa5..f04c6df52c5 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -4,6 +4,7 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master +from text_generation_server.utils.import_utils import SYSTEM PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") @@ -52,6 +53,12 @@ # index in all cases. 
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None +PAGED_KV: bool +if SYSTEM in {"rocm", "ipex"}: + PAGED_KV = False +else: + PAGED_KV = True + def set_adapter_to_index(adapter_to_index: Dict[str, int]): global ADAPTER_TO_INDEX From 4fb947d2aaf178bc37c791287594e8d2a0df181d Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Thu, 19 Sep 2024 14:28:21 +0000 Subject: [PATCH 13/19] fixed style --- .../text_generation_server/layers/__init__.py | 3 + .../layers/moe/fused_moe_rocm.py | 193 ++++++++++++++++++ .../layers/moe/unquantized.py | 15 +- .../flash_deepseek_v2_modeling.py | 9 +- 4 files changed, 217 insertions(+), 3 deletions(-) create mode 100644 server/text_generation_server/layers/moe/fused_moe_rocm.py diff --git a/server/text_generation_server/layers/__init__.py b/server/text_generation_server/layers/__init__.py index 0000ca915fd..e8282b1640f 100644 --- a/server/text_generation_server/layers/__init__.py +++ b/server/text_generation_server/layers/__init__.py @@ -19,6 +19,8 @@ TensorParallelAdapterRowLinear, ) +from text_generation_server.layers.moe.fused_moe_rocm import grouped_topk + __all__ = [ "get_linear", "FastLinear", @@ -31,4 +33,5 @@ "TensorParallelAdapterRowLinear", "load_layer_norm", "load_conv2d", + "grouped_topk", ] diff --git a/server/text_generation_server/layers/moe/fused_moe_rocm.py b/server/text_generation_server/layers/moe/fused_moe_rocm.py new file mode 100644 index 00000000000..ab30ff536af --- /dev/null +++ b/server/text_generation_server/layers/moe/fused_moe_rocm.py @@ -0,0 +1,193 @@ +# coding=utf-8 +# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional, Tuple, Dict, Any + +import torch +import torch.distributed + + +# TODO: Remove the functions once moe_kernel are built for ROCM +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], +) -> Dict[str, int]: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + if M <= E: + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + return config + + +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +): + # Check constraints. 
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] + + import triton.language as tl + from vllm import _custom_ops as ops + from vllm.model_executor.layers.fused_moe.fused_moe import ( + get_moe_configs, + invoke_fused_moe_kernel, + moe_align_block_size, + ) + + M, _ = hidden_states.shape + E, N, _ = w1.shape + + if override_config: + config = override_config + else: + # First try to load optimal config from the file + configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config( + M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None + ) + + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], E + ) + compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + + invoke_fused_moe_kernel( + hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + if inplace: + return torch.sum( + intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states, + ) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py index 8f1d9b3fb98..d9d62c0ef47 100644 --- a/server/text_generation_server/layers/moe/unquantized.py +++ b/server/text_generation_server/layers/moe/unquantized.py @@ -6,7 +6,9 @@ from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import UnquantizedWeight, Weights -if SYSTEM != "ipex": +if SYSTEM == "rocm": + from vllm.model_executor.layers.fused_moe import fused_moe +elif SYSTEM != "ipex": from moe_kernels.fused_moe import fused_moe @@ -52,6 +54,17 @@ def __init__( ) def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + if SYSTEM == "rocm": + return fused_moe( + x, + self.gate_up_proj, + self.down_proj, + gating_output, + self.topk, + renormalize=self.renormalize, + inplace=True, + ) + return fused_moe( x, 
w1=self.gate_up_proj, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 08a8d258ab3..fccafb01e90 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -16,7 +16,13 @@ from typing import List, Optional, Tuple from text_generation_server.models.globals import PAGED_KV -from moe_kernels.fused_moe import grouped_topk +from text_generation_server.utils.import_utils import SYSTEM + +if SYSTEM == "rocm": + from text_generation_server.layers import grouped_topk +else: + from vllm.model_executor.layers.fused_moe import grouped_topk + import torch import torch.distributed from text_generation_server.layers import ( @@ -36,7 +42,6 @@ from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import Weights from torch import nn from transformers.activations import ACT2FN From 64e981fdcf08c1750b75593777aa50d65bfe6a6f Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 24 Sep 2024 10:53:19 +0000 Subject: [PATCH 14/19] fix issue for sliding window models --- server/text_generation_server/layers/attention/common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index 855f4dfc0f6..d6e512c0172 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import ATTENTION import torch from typing import Optional @@ -65,5 +66,7 @@ class Seqlen: max_k: int def clamp(self, max): + if SYSTEM == "rocm": + return self raise NotImplementedError("Not implemented seqlen for paged") return Seqlen(torch.clamp(self.input_lengths, max=max)) From 829144d15afe4ab79047c0d5a2fa1385e2b7380d Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 27 Sep 2024 10:28:37 +0000 Subject: [PATCH 15/19] addressed review comments --- Dockerfile_amd | 14 +- .../layers/attention/__init__.py | 14 +- .../layers/attention/cuda.py | 6 + .../layers/attention/ipex.py | 1 + .../layers/attention/rocm.py | 39 +++-- .../layers/moe/fused_moe_rocm.py | 143 +----------------- .../custom_modeling/flash_cohere_modeling.py | 6 +- .../custom_modeling/flash_dbrx_modeling.py | 6 +- .../flash_deepseek_v2_modeling.py | 6 +- .../custom_modeling/flash_gemma2_modeling.py | 6 +- .../custom_modeling/flash_gemma_modeling.py | 6 +- .../custom_modeling/flash_gpt2_modeling.py | 6 +- .../custom_modeling/flash_gptj_modeling.py | 6 +- .../custom_modeling/flash_llama_modeling.py | 6 +- .../custom_modeling/flash_mistral_modeling.py | 6 +- .../custom_modeling/flash_mixtral_modeling.py | 6 +- .../custom_modeling/flash_neox_modeling.py | 6 +- .../custom_modeling/flash_phi_modeling.py | 6 +- .../custom_modeling/flash_qwen2_modeling.py | 6 +- .../custom_modeling/flash_rw_modeling.py | 10 +- .../flash_santacoder_modeling.py | 6 +- .../flash_starcoder2_modeling.py | 6 +- .../text_generation_server/models/globals.py | 7 
- 23 files changed, 101 insertions(+), 223 deletions(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index dabcb77a84f..766881a8866 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -67,14 +67,11 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins hipsolver-dev \ rccl-dev \ cmake \ - python3.11-dev \ python3.11-venv && \ rm -rf /var/lib/apt/lists/* # Keep in sync with `server/pyproject.toml ARG MAMBA_VERSION=23.1.0-1 -ARG PYTORCH_VERSION='2.3.0' -ARG ROCM_VERSION='6.0.2' ARG PYTHON_VERSION='3.11.10' # Automatically set by buildx ARG TARGETPLATFORM @@ -82,11 +79,6 @@ ENV PATH=/opt/conda/bin:$PATH ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" -RUN curl -fsSL -v -o cmake-3.30.2-linux-x86_64.sh https://github.com/Kitware/CMake/releases/download/v3.30.2/cmake-3.30.2-linux-x86_64.sh \ - && chmod +x cmake-3.30.2-linux-x86_64.sh \ - && ./cmake-3.30.2-linux-x86_64.sh --skip-license --prefix=/usr/local \ - && rm cmake-3.30.2-linux-x86_64.sh - # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. # Install mamba # translating Docker's TARGETPLATFORM into mamba arches @@ -111,7 +103,7 @@ RUN case ${TARGETPLATFORM} in \ /opt/conda/bin/conda clean -ya # Install flash-attention, torch dependencies -RUN pip install numpy einops ninja joblib msgpack --no-cache-dir +RUN pip install numpy einops ninja joblib msgpack cmake --no-cache-dir # Install HIPBLASLt ARG HIPBLASLT_BRANCH="6f65c6e" @@ -129,7 +121,8 @@ RUN dpkg -i hipBLASLt/build/release/*.deb \ RUN pip uninstall -y triton && \ git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \ cd triton/python && \ - pip install . + pip install . && \ + rm -r triton ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27" RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \ @@ -153,6 +146,7 @@ ARG BUILD_CAFFE2="0" \ USE_MEM_EFF_ATTENTION="0" RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install +RUN rm -rf pytorch # Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm ENV HIP_FORCE_DEV_KERNARG=1 diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index 56fc5319415..2134d857d65 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -13,9 +13,19 @@ SUPPORTS_WINDOWING, ) elif SYSTEM == "rocm": - from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING + from .rocm import ( + attention, + paged_attention, + reshape_and_cache, + SUPPORTS_WINDOWING, + ) elif SYSTEM == "ipex": - from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING + from .ipex import ( + attention, + paged_attention, + reshape_and_cache, + SUPPORTS_WINDOWING, + ) else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 4b588b5cf40..6c6457707fa 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -351,3 +351,9 @@ def attention( None, ) return out + + +# Prefill in the cache with every kind of attention, unless we +# have a configuration 
that requires flash-attention v1, which +# does not support block tables. +PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2 diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index d0eadc75375..657c90af432 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -5,6 +5,7 @@ from typing import Optional SUPPORTS_WINDOWING = False +PREFILL_IN_KV_CACHE = False def attention( diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 0835cb97264..be6158c1775 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -16,9 +16,18 @@ use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} ENGINE = "triton" if use_triton else "ck" -custom_attn_available = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" -if custom_attn_available: - from vllm._custom_C import paged_attention_custom +PREFILL_IN_KV_CACHE = False + +use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" +try: + if use_rocm_custom_paged_attn: + from vllm._custom_C import paged_attention_custom +except ImportError as e: + log_master( + logger.info, + f"Custom Paged Attention not available. Complete error: {e}", + ) + use_rocm_custom_paged_attn = False try: import vllm._custom_ops as ops @@ -71,6 +80,9 @@ def paged_attention( # limitations under the License. # + if softcap is not None: + raise RuntimeError("Paged attention doesn't support softcapping") + # value_cache => [num_blocks, num_heads, head_size, block_size] block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape @@ -78,7 +90,7 @@ def paged_attention( num_kv_heads = key_cache.shape[1] gqa_ratio = num_heads // num_kv_heads use_custom = ( - custom_attn_available + use_rocm_custom_paged_attn and (query.dtype == torch.half or query.dtype == torch.bfloat16) and (head_size == 128 or head_size == 64) and (block_size == 16 or block_size == 32) @@ -224,10 +236,10 @@ def attention( value_cache: torch.Tensor, seqlen: Seqlen, block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap: float = 0.0, ): if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") @@ -268,11 +280,14 @@ def attention( value_cache: torch.Tensor, seqlen: Seqlen, block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap: float = 0.0, ): + if softcap is not None: + raise NotImplementedError("softcap is only available with CK flash attn") + out = torch.empty_like(q) # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. diff --git a/server/text_generation_server/layers/moe/fused_moe_rocm.py b/server/text_generation_server/layers/moe/fused_moe_rocm.py index ab30ff536af..68accb99022 100644 --- a/server/text_generation_server/layers/moe/fused_moe_rocm.py +++ b/server/text_generation_server/layers/moe/fused_moe_rocm.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Tuple, Dict, Any +from typing import Tuple import torch import torch.distributed @@ -50,144 +50,3 @@ def grouped_topk( topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids - - -def get_default_config( - M: int, - E: int, - N: int, - K: int, - topk: int, - dtype: Optional[str], -) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - if M <= E: - config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - } - return config - - -def fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, -): - # Check constraints. - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] - - import triton.language as tl - from vllm import _custom_ops as ops - from vllm.model_executor.layers.fused_moe.fused_moe import ( - get_moe_configs, - invoke_fused_moe_kernel, - moe_align_block_size, - ) - - M, _ = hidden_states.shape - E, N, _ = w1.shape - - if override_config: - config = override_config - else: - # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) - - if configs: - # If an optimal configuration map has been found, look up the - # optimal config - config = configs[min(configs.keys(), key=lambda x: abs(x - M))] - else: - # Else use the default config - config = get_default_config( - M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None - ) - - intermediate_cache1 = torch.empty( - (M, topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache2 = torch.empty( - (M * topk_ids.shape[1], N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache3 = torch.empty( - (M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config["BLOCK_SIZE_M"], E - ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 - - invoke_fused_moe_kernel( - hidden_states, - w1, - intermediate_cache1, - a1_scale, - w1_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - - invoke_fused_moe_kernel( - intermediate_cache2, - w2, - intermediate_cache3, - a2_scale, - w2_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) - - if inplace: - return torch.sum( - 
intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=hidden_states, - ) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index b0e57d68653..44db0290f9c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -298,8 +298,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else key, - kv_cache[1] if PAGED_KV else value, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 8bce4e573ac..852e52d8f74 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -337,8 +337,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, 0], - kv_cache[1] if PAGED_KV else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 94c7600a8d1..97a269309bb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -15,7 +15,7 @@ from typing import List, Optional, Tuple -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.utils.import_utils import SYSTEM if SYSTEM == "rocm": @@ -333,8 +333,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else key, - kv_cache[1] if PAGED_KV else value, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 1ad88801b39..b1f0dba2dec 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under 
the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -237,8 +237,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, 0], - kv_cache[1] if PAGED_KV else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index a401798a687..3ddcba8a380 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -231,8 +231,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, 0], - kv_cache[1] if PAGED_KV else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 33f20b9a3bb..d47bb104209 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -231,8 +231,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else key, - kv_cache[1] if PAGED_KV else value, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index f2197069217..200735c61ee 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -193,8 +193,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else key, - kv_cache[1] if PAGED_KV else value, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 6be89297059..a77ec2344ac 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -28,7 +28,7 @@ from transformers.activations import ACT2FN from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention import ( paged_attention, attention, @@ -221,8 +221,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, 0], - kv_cache[1] if PAGED_KV else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 3b56bbab0e3..d05032773f1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -219,8 +219,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], - kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index abfa737a73e..3eb81daf332 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -274,8 +274,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], - kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 2d3be430be3..471abca3add 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -172,8 +172,8 @@ def forward( # flash attention attn_output = attention( qkv[:, 0], - kv_cache[0] if PAGED_KV else qkv[:, 1], - kv_cache[1] if PAGED_KV else qkv[:, 2], + kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1], + kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 76e406a7427..4a18090a7f2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -1,4 +1,4 @@ -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -194,8 +194,8 @@ def forward( if cu_seqlen_prefill is not None: attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, 0], - kv_cache[1] if PAGED_KV else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 0f0dbf5ec92..00e63a6c36e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -1,4 +1,4 @@ -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -137,8 +137,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], - kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index ba516881029..2cf243e8974 100644 --- 
a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -1,6 +1,6 @@ from typing import List, Optional, Tuple -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed from torch import nn @@ -207,8 +207,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, 0], - kv_cache[1] if PAGED_KV else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -325,8 +325,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, :, 0].contiguous(), - kv_cache[1] if PAGED_KV else kv[:, :, 1].contiguous(), + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(), + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(), seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index fa074606678..0c1518e7e6d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -1,4 +1,4 @@ -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -293,8 +293,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else key_value[:, 0], - kv_cache[1] if PAGED_KV else key_value[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 30d35632485..22ac0240c81 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -242,8 +242,8 @@ def forward( # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], - kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index f04c6df52c5..6c518c2caa5 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -4,7 +4,6 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -from text_generation_server.utils.import_utils import SYSTEM PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") @@ -53,12 +52,6 @@ # index in all cases. ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None -PAGED_KV: bool -if SYSTEM in {"rocm", "ipex"}: - PAGED_KV = False -else: - PAGED_KV = True - def set_adapter_to_index(adapter_to_index: Dict[str, int]): global ADAPTER_TO_INDEX From 816d4b67b2e7bff347d50c930eb654b4c0c6e85e Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 27 Sep 2024 12:32:17 +0000 Subject: [PATCH 16/19] fix import --- server/text_generation_server/layers/attention/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index 2134d857d65..a2f97700818 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -11,6 +11,7 @@ paged_attention, reshape_and_cache, SUPPORTS_WINDOWING, + PREFILL_IN_KV_CACHE, ) elif SYSTEM == "rocm": from .rocm import ( @@ -18,6 +19,7 @@ paged_attention, reshape_and_cache, SUPPORTS_WINDOWING, + PREFILL_IN_KV_CACHE, ) elif SYSTEM == "ipex": from .ipex import ( @@ -25,6 +27,7 @@ paged_attention, reshape_and_cache, SUPPORTS_WINDOWING, + PREFILL_IN_KV_CACHE, ) else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") @@ -35,5 +38,6 @@ "paged_attention", "reshape_and_cache", "SUPPORTS_WINDOWING", + "PREFILL_IN_KV_CACHE", "Seqlen", ] From ac2dccd1740ec222c33e637d49b2293e6012f4af Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 27 Sep 2024 12:34:04 +0000 Subject: [PATCH 17/19] improved error messag --- server/text_generation_server/layers/linear.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index 69b6294bbb2..08306d57969 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -13,7 +13,9 @@ try: from vllm import _custom_C except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") + raise ImportError( + f"Could not load `vllm._custom_C` for ROCm skinny gemm. 
Full error: {e}" + ) class FastLinear(torch.nn.Module): From a24c2cc5e9c26eb3d0331a1fd607db6771c3d743 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 27 Sep 2024 12:39:12 +0000 Subject: [PATCH 18/19] updated default value --- server/text_generation_server/layers/attention/rocm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index be6158c1775..de7d673fe2b 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -283,7 +283,7 @@ def attention( softmax_scale: float, window_size_left: int = -1, causal: bool = True, - softcap: float = 0.0, + softcap: Optional[float] = None, ): if softcap is not None: raise NotImplementedError("softcap is only available with CK flash attn") From 346dfe398af02c7355fe12f2ffbf09223d232ef1 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 27 Sep 2024 12:59:35 +0000 Subject: [PATCH 19/19] remove import --- server/text_generation_server/layers/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/server/text_generation_server/layers/__init__.py b/server/text_generation_server/layers/__init__.py index e8282b1640f..0000ca915fd 100644 --- a/server/text_generation_server/layers/__init__.py +++ b/server/text_generation_server/layers/__init__.py @@ -19,8 +19,6 @@ TensorParallelAdapterRowLinear, ) -from text_generation_server.layers.moe.fused_moe_rocm import grouped_topk - __all__ = [ "get_linear", "FastLinear", @@ -33,5 +31,4 @@ "TensorParallelAdapterRowLinear", "load_layer_norm", "load_conv2d", - "grouped_topk", ]