diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py
index 7e2a03a59b..f09d4a86c5 100644
--- a/llmfoundry/optim/lion8b.py
+++ b/llmfoundry/optim/lion8b.py
@@ -68,9 +68,11 @@ def __init__(self,
                  compress_state_dict: bool = False,
                  error_correction: bool = False,
                  _fused: bool = True):  # XXX this flag is mostly for testing...
-        if version.parse(torch.__version__) >= version.parse('2.1.0') and error_correction:
+        if version.parse(torch.__version__) >= version.parse(
+                '2.1.0') and error_correction:
             raise RuntimeError(
-                'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0')
+                'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0'
+            )

         if lr < 0.0:
             raise ValueError('Invalid learning rate: {}'.format(lr))
diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py
index cfc5970663..7e4d97390c 100644
--- a/tests/test_lion8b.py
+++ b/tests/test_lion8b.py
@@ -1,10 +1,10 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

+import contextlib
 import os
 import time
 import warnings
-import contextlib

 import numpy as np
 import packaging.version as version
@@ -415,10 +415,12 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,
     if version.parse(torch.__version__) < version.parse('2.0.1'):
         pytest.skip(f'This test requires torch 2.0.1 or greater.')

-    error_context = contextlib.null_context()
-    if use_errors and version.parse(torch.__version__) >= version.parse('2.1.0'):
-        error_context = pytest.raises(RuntimeError, match='DecoupledLionW_8bit with error correction')
-
+    error_context = contextlib.nullcontext()
+    if use_errors and version.parse(
+            torch.__version__) >= version.parse('2.1.0'):
+        error_context = pytest.raises(
+            RuntimeError, match='DecoupledLionW_8bit with error correction')
+
     with error_context:
         torch.cuda.set_device(f'cuda:{os.environ["RANK"]}')  # needed for fsdp
         if not dist.is_initialized():
@@ -453,7 +455,7 @@ def _set_state_dict_type(model: nn.Module):
             _LOCAL_STATE: LocalOptimStateDictConfig(),
         }[state_sharding]
         FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg,
-                                  optim_cfg)
+                                 optim_cfg)

     # load FSDP state dict
     _set_state_dict_type(mod)
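
Note (illustrative, not part of the patch): the test change above replaces the nonexistent contextlib.null_context() with contextlib.nullcontext() and selects between a no-op context and pytest.raises based on the installed torch version. A minimal standalone sketch of that pattern, assuming the same error message and version gate; expected_error_context is a hypothetical helper, not a function in the repo:

import contextlib

import packaging.version as version
import pytest
import torch


def expected_error_context(use_errors: bool):
    # Return a no-op context normally; when error correction is requested on
    # torch >= 2.1.0, return pytest.raises so the RuntimeError raised by
    # DecoupledLionW_8bit's __init__ counts as the expected outcome.
    if use_errors and version.parse(
            torch.__version__) >= version.parse('2.1.0'):
        return pytest.raises(
            RuntimeError, match='DecoupledLionW_8bit with error correction')
    return contextlib.nullcontext()  # nullcontext(), not null_context()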