Commit

precommit
dakinggg committed Oct 6, 2023
1 parent 9bc5f7f commit 9d211f6
Showing 2 changed files with 12 additions and 8 deletions.
6 changes: 4 additions & 2 deletions llmfoundry/optim/lion8b.py
@@ -68,9 +68,11 @@ def __init__(self,
                  compress_state_dict: bool = False,
                  error_correction: bool = False,
                  _fused: bool = True):  # XXX this flag is mostly for testing...
-        if version.parse(torch.__version__) >= version.parse('2.1.0') and error_correction:
+        if version.parse(torch.__version__) >= version.parse(
+                '2.1.0') and error_correction:
             raise RuntimeError(
-                'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0')
+                'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0'
+            )
 
         if lr < 0.0:
             raise ValueError('Invalid learning rate: {}'.format(lr))
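For reference, here is a minimal runnable sketch of the guard this hunk reflows. `DummyLion8b` is a hypothetical stand-in for `DecoupledLionW_8bit`, not the real optimizer; only `torch` and `packaging` are assumed installed.

```python
# Hypothetical stand-in for DecoupledLionW_8bit, not the real optimizer.
import torch
from packaging import version


class DummyLion8b:

    def __init__(self, lr: float = 1e-3, error_correction: bool = False):
        # Fail fast on the unsupported combination: per the guard above,
        # error correction requires PyTorch < 2.1.0.
        if version.parse(torch.__version__) >= version.parse(
                '2.1.0') and error_correction:
            raise RuntimeError(
                'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0'
            )
        if lr < 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        self.lr = lr
        self.error_correction = error_correction
```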
14 changes: 8 additions & 6 deletions tests/test_lion8b.py
@@ -1,10 +1,10 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import contextlib
 import os
 import time
 import warnings
-import contextlib
 
 import numpy as np
 import packaging.version as version
@@ -415,10 +415,12 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,
     if version.parse(torch.__version__) < version.parse('2.0.1'):
         pytest.skip(f'This test requires torch 2.0.1 or greater.')
 
-    error_context = contextlib.null_context()
-    if use_errors and version.parse(torch.__version__) >= version.parse('2.1.0'):
-        error_context = pytest.raises(RuntimeError, match='DecoupledLionW_8bit with error correction')
+    error_context = contextlib.nullcontext()
+    if use_errors and version.parse(
+            torch.__version__) >= version.parse('2.1.0'):
+        error_context = pytest.raises(
+            RuntimeError, match='DecoupledLionW_8bit with error correction')
 
     with error_context:
         torch.cuda.set_device(f'cuda:{os.environ["RANK"]}')  # needed for fsdp
         if not dist.is_initialized():
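This hunk also fixes a latent bug: `contextlib.null_context()` does not exist (the function is `nullcontext`). The resulting pattern, selecting between a no-op context and `pytest.raises` depending on configuration, is shown in this self-contained sketch; `might_fail` is a hypothetical stand-in for the FSDP save/load body exercised by `test_fsdp_save_load`.

```python
import contextlib

import pytest


def might_fail(use_errors: bool) -> None:
    # Hypothetical stand-in for the code under test.
    if use_errors:
        raise RuntimeError('DecoupledLionW_8bit with error correction '
                           'requires PyTorch <2.1.0')


@pytest.mark.parametrize('use_errors', [False, True])
def test_might_fail(use_errors: bool) -> None:
    # Expect the error only in the failing configuration; otherwise the
    # no-op nullcontext lets the body run to completion.
    error_context = contextlib.nullcontext()
    if use_errors:
        error_context = pytest.raises(
            RuntimeError, match='DecoupledLionW_8bit with error correction')
    with error_context:
        might_fail(use_errors)
```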
@@ -453,7 +455,7 @@ def _set_state_dict_type(model: nn.Module):
             _LOCAL_STATE: LocalOptimStateDictConfig(),
         }[state_sharding]
         FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg,
-                             optim_cfg)
+                                 optim_cfg)
 
     # load FSDP state dict
     _set_state_dict_type(mod)
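The last hunk only re-indents a call inside `_set_state_dict_type`, which pairs matching model and optimizer state-dict configs and applies both with one `FSDP.set_state_dict_type` call. A sketch of that pattern, assuming torch >= 2.0: `set_state_dict_type_for` is a hypothetical name, and actually invoking it requires an initialized process group and an FSDP-wrapped `model`.

```python
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.api import (FullOptimStateDictConfig,
                                        FullStateDictConfig,
                                        LocalOptimStateDictConfig,
                                        LocalStateDictConfig,
                                        ShardedOptimStateDictConfig,
                                        ShardedStateDictConfig)


def set_state_dict_type_for(model: nn.Module,
                            state_sharding: StateDictType) -> None:
    # The model and optimizer configs must describe the same sharding mode,
    # so select both from mappings keyed by the same StateDictType.
    state_dict_cfg = {
        StateDictType.FULL_STATE_DICT: FullStateDictConfig(),
        StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig(),
        StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig(),
    }[state_sharding]
    optim_cfg = {
        StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig(),
        StateDictType.SHARDED_STATE_DICT: ShardedOptimStateDictConfig(),
        StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig(),
    }[state_sharding]
    FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg, optim_cfg)
```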
