diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 28084b7fb4..8e30554475 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -15,6 +15,8 @@ jobs: base_image: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04 - name: '2.0.1_cu118' base_image: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04 + - name: '2.1.0_cu121' + base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04 steps: - name: Maximize Build Space on Worker diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index 6af87346c8..efdf8eec58 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -27,6 +27,10 @@ jobs: container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04 markers: 'not gpu' pytest_command: 'coverage run -m pytest' + - name: 'cpu-2.1.0' + container: mosaicml/pytorch:2.1.0_cpu-python3.10-ubuntu20.04 + markers: 'not gpu' + pytest_command: 'coverage run -m pytest' name: ${{ matrix.name }} if: github.repository_owner == 'mosaicml' with: diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index d228802ddc..f0650f6179 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -27,6 +27,10 @@ jobs: container: mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04 markers: 'gpu' pytest_command: 'coverage run -m pytest' + - name: 'gpu-2.1.0' + container: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04 + markers: 'gpu' + pytest_command: 'coverage run -m pytest' name: ${{ matrix.name }} if: github.repository_owner == 'mosaicml' with: diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index dbd6ff6352..7d517269fc 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -460,7 +460,7 @@ def _set_state_dict_type(model: nn.Module): # load state dict into the new optimizer opt_state_dict_slice = FSDP.optim_state_dict_to_load( - opt_state_dict, mod_new, opt_new) + optim_state_dict=opt_state_dict, model=mod_new, optim=opt_new) opt_new.load_state_dict(opt_state_dict_slice) new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new)