From e8d3da00814ec7773d33edd5643bb885d85686cb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 9 Oct 2024 11:53:56 -0400 Subject: [PATCH] upgrade pytorch from 2.4.0 => 2.4.1 (#1950) * upgrade pytorch from 2.4.0 => 2.4.1 * update xformers for updated pytorch version * handle xformers version case for torch==2.3.1 --- .github/workflows/base.yml | 2 +- .github/workflows/main.yml | 4 ++-- .github/workflows/nightlies.yml | 4 ++-- .github/workflows/tests-nightly.yml | 4 ++-- .github/workflows/tests.yml | 4 ++-- requirements.txt | 2 +- setup.py | 7 +++++++ 7 files changed, 17 insertions(+), 10 deletions(-) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 5e8c8fc33d..1b24f2c970 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -28,7 +28,7 @@ jobs: cuda_version: 12.4.1 cudnn_version: "" python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - cuda: "124" cuda_version: 12.4.1 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 5a972f5f08..c27dbedefa 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -27,7 +27,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -84,7 +84,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index 1d95a0983f..17c76c24e7 100644 --- a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -26,7 +26,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -83,7 +83,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml index 30ed397cef..8c9e1f49e7 100644 --- a/.github/workflows/tests-nightly.yml +++ b/.github/workflows/tests-nightly.yml @@ -25,7 +25,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] - pytorch_version: ["2.3.1", "2.4.0"] + pytorch_version: ["2.3.1", "2.4.1"] timeout-minutes: 20 steps: @@ -91,7 +91,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 num_gpus: 1 axolotl_extras: nightly_build: "true" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c104e92c27..a798bdd5cd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,7 +36,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] - pytorch_version: ["2.3.1", "2.4.0"] + pytorch_version: ["2.3.1", "2.4.1"] timeout-minutes: 20 steps: @@ -94,7 +94,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 num_gpus: 1 axolotl_extras: steps: diff --git a/requirements.txt b/requirements.txt index 123a4ee54a..41bfdfbeb4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ flash-attn==2.6.3 sentencepiece wandb einops -xformers==0.0.27 +xformers==0.0.28.post1 optimum==1.16.2 hf_transfer colorama diff --git a/setup.py b/setup.py index 1b64fadaef..e939bc37ee 100644 --- a/setup.py +++ b/setup.py @@ -49,10 +49,17 @@ def parse_requirements(): else: raise ValueError("Invalid version format") + if (major, minor) >= (2, 4): + if patch == 0: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers>=0.0.27") if (major, minor) >= (2, 3): if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.26.post1") + else: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers>=0.0.27") elif (major, minor) >= (2, 2): _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.25.post1")