From bf5be90d390988b0afc17acec8edf756ee968e1a Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Wed, 21 Aug 2024 12:21:51 -0400 Subject: [PATCH] bump torch to <2.5 (#142) --- .github/workflows/pr-gpu.yaml | 8 ++++---- pyproject.toml | 2 +- setup.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 1ca8d5b..d94b057 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -21,14 +21,14 @@ jobs: fail-fast: false matrix: include: - - name: "python3.11-pytorch2.3.1-gpus1" + - name: "python3.11-pytorch2.4.0-gpus1" gpu_num: 1 python_version: 3.11 - container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 - - name: "python3.11-pytorch2.3.1-gpus2" + container: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04 + - name: "python3.11-pytorch2.4.0-gpus2" gpu_num: 2 python_version: 3.11 - container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 + container: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04 steps: - name: Run PR GPU tests uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.2 diff --git a/pyproject.toml b/pyproject.toml index c72dbdf..fc8f3dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ # build requirements [build-system] -requires = ["setuptools < 70.0.0", "torch >= 2.3.0, < 2.4"] +requires = ["setuptools < 70.0.0", "torch >= 2.3.0, < 2.4.1"] build-backend = "setuptools.build_meta" # Pytest diff --git a/setup.py b/setup.py index fa15ee4..e39dd49 100644 --- a/setup.py +++ b/setup.py @@ -62,15 +62,15 @@ install_requires = [ 'numpy>=1.21.5,<2.1.0', 'packaging>=21.3.0,<24.2', - 'torch>=2.3.0,<2.4', + 'torch>=2.3.0,<2.4.1', 'triton>=2.1.0', - 'stanford-stk @ git+https://git@github.com/stanford-futuredata/stk.git@a1ddf98466730b88a2988860a9d8000fd1833301', + 'stanford-stk @ git+https://git@github.com/eitanturok/stk.git@bump-version', ] extra_deps = {} extra_deps['gg'] = [ - 'grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@66c7195e35e8c4f22fa6a014037ef511bfa397cb', + 'grouped_gemm @ git+https://git@github.com/eitanturok/grouped_gemm.git@bump-version', ] extra_deps['dev'] = [