diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py
index 806dbdbd14..7e2a03a59b 100644
--- a/llmfoundry/optim/lion8b.py
+++ b/llmfoundry/optim/lion8b.py
@@ -4,6 +4,7 @@
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple
 
 import torch
+from packaging import version
 
 
 class DecoupledLionW_8bit(torch.optim.Optimizer):
@@ -53,7 +54,7 @@ class DecoupledLionW_8bit(torch.optim.Optimizer):
             by retaining information across optimizer steps.
 
     Raises:
-        NotImplemenetedError - If any of `quantize`, `compress_state_dict`,
+        NotImplementedError - If any of `quantize`, `compress_state_dict`,
             or `error_correction` are `True` and either a) there is no
             CUDA device, or b) step() is executed on a non-CUDA parameter.
     """
@@ -67,6 +68,10 @@ def __init__(self,
                  compress_state_dict: bool = False,
                  error_correction: bool = False,
                  _fused: bool = True):  # XXX this flag is mostly for testing...
+        if version.parse(torch.__version__) >= version.parse('2.1.0') and error_correction:
+            raise RuntimeError(
+                'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0')
+
         if lr < 0.0:
             raise ValueError('Invalid learning rate: {}'.format(lr))
         if not 0.0 <= betas[0] <= 1.0:
diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py
index 088a6ab909..cfc5970663 100644
--- a/tests/test_lion8b.py
+++ b/tests/test_lion8b.py
@@ -4,6 +4,7 @@
 import os
 import time
 import warnings
+import contextlib
 
 import numpy as np
 import packaging.version as version
@@ -414,88 +415,93 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,
     if version.parse(torch.__version__) < version.parse('2.0.1'):
         pytest.skip(f'This test requires torch 2.0.1 or greater.')
-    torch.cuda.set_device(f'cuda:{os.environ["RANK"]}')  # needed for fsdp
-    if not dist.is_initialized():
-        dist.init_process_group(backend='nccl')
-    assert dist.get_world_size() >= 2, 'Misconfigured test run!'
-
-    mod = FSDP(_DummyModule(device=device, dtype=dtype))
-
-    # actual forward pass instead of setting p.grad to avoid FSDP issues
-    X = torch.rand(size=(5, 4), device=device, dtype=dtype)
-    Y = mod(X)
-    Y.sum().backward()
-    for p in mod.parameters():
-        p.grad = torch.rand_like(p)
-
-    # create optimizer and have it step so that state gets populated
-    opt = Lion8bit(mod.parameters(), error_correction=use_errors)
-    opt.step()
-    opt.zero_grad()
-
-    def _set_state_dict_type(model: nn.Module):
-        # for mapping between state dict types and optim state dict types, see:
-        # https://github.com/pytorch/pytorch/blob/a815e719e85899d4229616617e7827d4de191c2d/torch/distributed/fsdp/fully_sharded_data_parallel.py#L664 # noqa
-        state_dict_cfg = {
-            _FULL_STATE: fsdp.FullStateDictConfig(rank0_only=False),
-            _SHARDED_STATE: fsdp.ShardedStateDictConfig(),
-            _LOCAL_STATE: fsdp.LocalStateDictConfig(),
-        }[state_sharding]
-        optim_cfg = {
-            _FULL_STATE: FullOptimStateDictConfig(rank0_only=False),
-            _SHARDED_STATE: ShardedOptimStateDictConfig(),
-            _LOCAL_STATE: LocalOptimStateDictConfig(),
-        }[state_sharding]
-        FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg,
-                                 optim_cfg)
-
-    # load FSDP state dict
-    _set_state_dict_type(mod)
-    opt_state_dict = FSDP.optim_state_dict(mod, opt)
-
-    # make a new model and optimizer
-    mod_new = FSDP(_DummyModule(device=device, dtype=dtype))
-    opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors)
-    _set_state_dict_type(mod_new)
-
-    # load state dict into the new optimizer
-    opt_state_dict_slice = FSDP.optim_state_dict_to_load(
-        optim_state_dict=opt_state_dict, model=mod_new, optim=opt_new)
-    opt_new.load_state_dict(opt_state_dict_slice)
-
-    new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new)
-
-    orig_state = opt_state_dict['state']
-    orig_param_groups = opt_state_dict['param_groups']
-    new_state = new_opt_state_dict['state']
-    new_param_groups = new_opt_state_dict['param_groups']
-
-    all_keys = set(orig_state.keys()) | set(new_state.keys())
-    assert orig_param_groups == new_param_groups  # works since strs, not ptrs
-    for k in all_keys:  # keys are param paths in module as strings
-        d_orig = orig_state[k]
-        d_new = new_state[k]
-        assert list(d_orig.keys()) == list(d_new.keys())
-        mom_orig = d_orig['exp_avg']
-        mom_new = d_new['exp_avg']
+    error_context = contextlib.nullcontext()
+    if use_errors and version.parse(torch.__version__) >= version.parse('2.1.0'):
+        error_context = pytest.raises(
+            RuntimeError, match='DecoupledLionW_8bit with error correction')
+
+    with error_context:
+        torch.cuda.set_device(f'cuda:{os.environ["RANK"]}')  # needed for fsdp
+        if not dist.is_initialized():
+            dist.init_process_group(backend='nccl')
+        assert dist.get_world_size() >= 2, 'Misconfigured test run!'
+
+        mod = FSDP(_DummyModule(device=device, dtype=dtype))
+
+        # actual forward pass instead of setting p.grad to avoid FSDP issues
+        X = torch.rand(size=(5, 4), device=device, dtype=dtype)
+        Y = mod(X)
+        Y.sum().backward()
+        for p in mod.parameters():
+            p.grad = torch.rand_like(p)
+
+        # create optimizer and have it step so that state gets populated
+        opt = Lion8bit(mod.parameters(), error_correction=use_errors)
+        opt.step()
+        opt.zero_grad()
-        assert mom_orig.shape == mom_new.shape
-        assert mom_orig.dtype == mom_new.dtype
-        if use_errors and (dtype != torch.float32):
-            errs_orig = d_orig['errors']
-            errs_new = d_new['errors']
-            assert errs_orig.shape == errs_new.shape
-            assert errs_orig.dtype == errs_new.dtype
-
-        if state_sharding != _FULL_STATE:
-            continue  # more detailed checks lean on FSDP impl details
-
-        # momentums may not be bit-for-bit identical because Optimizer upcasts
-        # to f32 and we convert back to bf16, possibly with different rounding
-        torch.testing.assert_close(mom_orig, mom_new)
-        # errors not bit-for-bit identical because scales get upcast too
-        if use_errors and (dtype != torch.float32):
-            torch.testing.assert_close(d_orig['errors'], d_new['errors'])
+        def _set_state_dict_type(model: nn.Module):
+            # for mapping between state dict types and optim state dict types, see:
+            # https://github.com/pytorch/pytorch/blob/a815e719e85899d4229616617e7827d4de191c2d/torch/distributed/fsdp/fully_sharded_data_parallel.py#L664 # noqa
+            state_dict_cfg = {
+                _FULL_STATE: fsdp.FullStateDictConfig(rank0_only=False),
+                _SHARDED_STATE: fsdp.ShardedStateDictConfig(),
+                _LOCAL_STATE: fsdp.LocalStateDictConfig(),
+            }[state_sharding]
+            optim_cfg = {
+                _FULL_STATE: FullOptimStateDictConfig(rank0_only=False),
+                _SHARDED_STATE: ShardedOptimStateDictConfig(),
+                _LOCAL_STATE: LocalOptimStateDictConfig(),
+            }[state_sharding]
+            FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg,
+                                     optim_cfg)
+
+        # load FSDP state dict
+        _set_state_dict_type(mod)
+        opt_state_dict = FSDP.optim_state_dict(mod, opt)
+
+        # make a new model and optimizer
+        mod_new = FSDP(_DummyModule(device=device, dtype=dtype))
+        opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors)
+        _set_state_dict_type(mod_new)
+
+        # load state dict into the new optimizer
+        opt_state_dict_slice = FSDP.optim_state_dict_to_load(
+            optim_state_dict=opt_state_dict, model=mod_new, optim=opt_new)
+        opt_new.load_state_dict(opt_state_dict_slice)
+
+        new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new)
+
+        orig_state = opt_state_dict['state']
+        orig_param_groups = opt_state_dict['param_groups']
+        new_state = new_opt_state_dict['state']
+        new_param_groups = new_opt_state_dict['param_groups']
+
+        all_keys = set(orig_state.keys()) | set(new_state.keys())
+        assert orig_param_groups == new_param_groups  # works since strs, not ptrs
+        for k in all_keys:  # keys are param paths in module as strings
+            d_orig = orig_state[k]
+            d_new = new_state[k]
+            assert list(d_orig.keys()) == list(d_new.keys())
+            mom_orig = d_orig['exp_avg']
+            mom_new = d_new['exp_avg']
+
+            assert mom_orig.shape == mom_new.shape
+            assert mom_orig.dtype == mom_new.dtype
+            if use_errors and (dtype != torch.float32):
+                errs_orig = d_orig['errors']
+                errs_new = d_new['errors']
+                assert errs_orig.shape == errs_new.shape
+                assert errs_orig.dtype == errs_new.dtype
+
+            if state_sharding != _FULL_STATE:
+                continue  # more detailed checks lean on FSDP impl details
+
+            # momentums may not be bit-for-bit identical because Optimizer upcasts
+            # to f32 and we convert back to bf16, possibly with different rounding
+            torch.testing.assert_close(mom_orig, mom_new)
+            # errors not bit-for-bit identical because scales get upcast too
+            if use_errors and (dtype != torch.float32):
+                torch.testing.assert_close(d_orig['errors'], d_new['errors'])
 
 
 @pytest.mark.gpu
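Note (not part of the patch): both hunks rely on the same conditional-expectation idiom, i.e. pick contextlib.nullcontext() when no error is expected and pytest.raises(...) when one is, then run the body under whichever context was chosen. Below is a minimal standalone sketch of that idiom exercising the new __init__ guard. It is an illustration under assumptions, not code from the repository: it assumes the guard above is merged, that llmfoundry.optim.lion8b is importable, and that a CUDA device is available (error correction also requires CUDA); the test name and the bare constructor call are hypothetical.

# Illustrative sketch only; not part of the diff above.
import contextlib

import pytest
import torch
from packaging import version

from llmfoundry.optim.lion8b import DecoupledLionW_8bit


def test_error_correction_guard():  # hypothetical test name
    # One CUDA parameter is enough to construct the optimizer.
    param = torch.zeros(4, device='cuda', requires_grad=True)

    # Expect nothing on torch < 2.1.0; expect the guard's RuntimeError otherwise.
    expectation = contextlib.nullcontext()
    if version.parse(torch.__version__) >= version.parse('2.1.0'):
        expectation = pytest.raises(
            RuntimeError, match='DecoupledLionW_8bit with error correction')

    with expectation:
        DecoupledLionW_8bit([param], error_correction=True)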