From 161bcb6517ef38fee60ce5b243b6ac68d8bbdef7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 21 Dec 2023 09:38:20 -0500 Subject: [PATCH] Dockerfile torch fix (#987) * add torch to requirements.txt at build time to force version to stick * fix xformers check * better handling of xformers based on installed torch version * fix for ci w/o torch --- .github/workflows/base.yml | 2 +- .github/workflows/main.yml | 4 ++-- docker/Dockerfile | 1 - setup.py | 15 +++++++++------ 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 5f08854842..1dbff114ed 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -28,7 +28,7 @@ jobs: - cuda: "118" cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.1.0 + pytorch: 2.1.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" steps: - name: Checkout diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9514208b1c..87b308362b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -27,7 +27,7 @@ jobs: - cuda: 118 cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.1.0 + pytorch: 2.1.1 axolotl_extras: runs-on: [self-hosted, gpu, docker] steps: @@ -80,7 +80,7 @@ jobs: - cuda: 118 cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.1.0 + pytorch: 2.1.1 axolotl_extras: runs-on: [self-hosted, gpu, docker] steps: diff --git a/docker/Dockerfile b/docker/Dockerfile index 41915de83d..81a08bc8b7 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -19,7 +19,6 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git WORKDIR /workspace/axolotl # If AXOLOTL_EXTRAS is set, append it in brackets -RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \ else \ diff --git a/setup.py b/setup.py index 42fd22df11..fe4d2cfad8 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,7 @@ """setup.py for axolotl""" +from importlib.metadata import PackageNotFoundError, version + from setuptools import find_packages, setup @@ -22,12 +24,13 @@ def parse_requirements(): # Handle standard packages _install_requires.append(line) - # TODO(wing) remove once xformers release supports torch 2.1.0 - if "torch==2.1.0" in _install_requires: - _install_requires.pop(_install_requires.index("xformers>=0.0.22")) - _install_requires.append( - "xformers @ git+https://github.com/facebookresearch/xformers.git@main" - ) + try: + torch_version = version("torch") + if torch_version.startswith("2.1.1"): + _install_requires.pop(_install_requires.index("xformers==0.0.22")) + _install_requires.append("xformers==0.0.23") + except PackageNotFoundError: + pass return _install_requires, _dependency_links