From e42e2dac2f0b5b8ce7a4cf743ca369a29e78c306 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Mon, 5 Aug 2024 12:35:46 +0000
Subject: [PATCH] update torch

---
 Dockerfile_amd                                | 38 ++++++++++---------
 .../models/flash_causal_lm.py                 |  3 +-
 2 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/Dockerfile_amd b/Dockerfile_amd
index 514891a89ee..399bc869550 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -98,24 +98,26 @@ RUN pip uninstall -y triton && \
     cd triton/python && \
     pip install .
 
-RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
-
-ARG _GLIBCXX_USE_CXX11_ABI="1"
-ARG CMAKE_PREFIX_PATH="/opt/conda"
-ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
-ARG BUILD_CAFFE2="0" \
-    BUILD_CAFFE2_OPS="0" \
-    USE_CUDA="0" \
-    USE_ROCM="1" \
-    BUILD_TEST="0" \
-    USE_FBGEMM="0" \
-    USE_NNPACK="0" \
-    USE_QNNPACK="0" \
-    USE_XNNPACK="0" \
-    USE_FLASH_ATTENTION="1" \
-    USE_MEM_EFF_ATTENTION="0"
-
-RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
+# RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
+
+# ARG _GLIBCXX_USE_CXX11_ABI="1"
+# ARG CMAKE_PREFIX_PATH="/opt/conda"
+# ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+# ARG BUILD_CAFFE2="0" \
+#     BUILD_CAFFE2_OPS="0" \
+#     USE_CUDA="0" \
+#     USE_ROCM="1" \
+#     BUILD_TEST="0" \
+#     USE_FBGEMM="0" \
+#     USE_NNPACK="0" \
+#     USE_QNNPACK="0" \
+#     USE_XNNPACK="0" \
+#     USE_FLASH_ATTENTION="1" \
+#     USE_MEM_EFF_ATTENTION="0"
+
+# RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
+
+RUN pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1
 
 # Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
 ENV HIP_FORCE_DEV_KERNARG=1
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 36bb26621f8..9b8704478ab 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1150,8 +1150,7 @@ def warmup(self, batch: FlashCausalLMBatch):
         elif CUDA_GRAPHS is not None:
             tuning_sequences = CUDA_GRAPHS
         else:
-            # For seqlen = 1, we dispatch to LLMM1 kernel.
-            tuning_sequences = [2, 3, 4, 5, 6, 7]
+            tuning_sequences = [1, 2, 3, 4, 5, 6, 7]
 
         tunableop_filepath = os.path.join(
             HUGGINGFACE_HUB_CACHE,
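
The patch replaces the from-source ROCm PyTorch build with the prebuilt rocm6.1 nightly wheels, which keeps the image build simpler and faster at the cost of pinning to whatever nightly is current at build time. A minimal sanity-check sketch for the resulting install, run inside the built image, might look like the following; it is not part of the patch, and the exact versions it prints depend on whichever nightly wheel was pulled.

```python
# Sanity-check sketch (not part of the patch): confirm the installed wheel is a
# HIP/ROCm build and that the AMD GPUs are visible through the torch.cuda API.
import torch

print("torch version:", torch.__version__)          # e.g. a dev/nightly +rocm6.1 build
print("HIP runtime:", torch.version.hip)            # None would indicate a non-ROCm wheel
print("GPU available:", torch.cuda.is_available())  # ROCm devices are exposed via torch.cuda

if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        # The old source build targeted gfx90a / gfx942 (MI200 / MI300 class GPUs).
        print(f"device {i}:", torch.cuda.get_device_name(i))
```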