diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index f09d4a86c5..2c2e6e2d35 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -4,7 +4,6 @@ from typing import Any, Callable, Dict, Iterable, Optional, Tuple import torch -from packaging import version class DecoupledLionW_8bit(torch.optim.Optimizer): @@ -68,11 +67,6 @@ def __init__(self, compress_state_dict: bool = False, error_correction: bool = False, _fused: bool = True): # XXX this flag is mostly for testing... - if version.parse(torch.__version__) >= version.parse( - '2.1.0') and error_correction: - raise RuntimeError( - 'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0' - ) if lr < 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) @@ -138,11 +132,19 @@ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None: mom, try_quantize=self._quantize) need_errs = (p.dtype != torch.float32) and self._error_correction if state.get('errors') is None and need_errs: - state['errors'] = torch.zeros(p.shape, - dtype=torch.uint8, - device=p.device) + numel = p.numel() + numel += numel % 2 # ensure even number of bytes + errors = torch.zeros(numel, dtype=torch.uint8, device=p.device) + # as of torch 2.1, FSDP can't shard ints for no reason + state['errors'] = errors.view(torch.bfloat16) decay_factor = hparams['weight_decay'] decay_factor *= hparams['lr'] / hparams['initial_lr'] + errors: Optional[torch.Tensor] = None + if 'errors' in state: + errors = state['errors'] + assert errors is not None # pyright + errors = errors.view(dtype=torch.uint8) + errors = errors[:p.numel()].view(p.shape) # strip padding + reshape _lion8b_step(momentums=state['exp_avg'], weights=p, grads=p.grad, @@ -151,7 +153,7 @@ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None: lr=hparams['lr'], weight_decay=decay_factor, fused=hparams['fused'], - errors=state.get('errors')) + errors=errors) def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None: # we override this function to quantize optimizer states when @@ -173,7 +175,8 @@ def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None: # we need to cast back to the correct dtype since optimizer # load_state_dict casts to param dtype for fp params; see # https://github.com/pytorch/pytorch/blob/a25eee1d77d93079614fab3ea4ac66e64fb2343b/torch/optim/optimizer.py#L626C7-L626C7 # noqa - errs = param_state['errors'].to(dtype=torch.uint8) + errs = param_state['errors'].to(dtype=torch.uint8).view( + torch.bfloat16) new_state['errors'] = errs opt_state[param_id] = new_state super().__setstate__(state) @@ -199,6 +202,11 @@ def state_dict(self): qtensor.state_dict( name='exp_avg', allow_quantized=self._compress_state_dict)) + if 'errors' in param_state: + # fsdp apparently needs the states to be the same shape + # as the params + param_state['errors'] = param_state['errors'].view( + torch.uint8).to(dtype=torch.bfloat16) opt_state[param_id] = param_state return d diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 35368be593..ddb70e882b 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -1,7 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import contextlib import os import time import warnings @@ -42,47 +41,40 @@ (True, True)]) def test_modifies_weights_and_momentums(N: int, D: int, dtype: torch.dtype, fused: bool, use_errors: bool) -> None: - error_context = contextlib.nullcontext() - if use_errors and version.parse( - torch.__version__) >= version.parse('2.1.0'): - error_context = pytest.raises( - RuntimeError, match='DecoupledLionW_8bit with error correction') - - with error_context: - device = 'cuda' - torch.manual_seed(123) - X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) - W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) - W_orig = W.detach().clone() - - opt = Lion8bit([W], - lr=1.0, - _fused=fused, - betas=(.75, .75), - weight_decay=.2, - error_correction=use_errors) - - Y = X @ W - loss = Y.sum() - loss.backward() - torch.testing.assert_close(W_orig, W) # no weight modification yet - opt.step() - opt.zero_grad() - - with pytest.raises(AssertionError): # opt step modified the weights - torch.testing.assert_close(W_orig, W) - - # Every momentum should be nonzero with infinite precision, but - # might be zero after quantization. We turn the _MaybeQuantizedTensor - # instance into a regular torch Tensor to simplify this check. - param_state = opt.state[W] # type:ignore using tensor as key - momentum = param_state['exp_avg'].materialize() - assert momentum.shape == (D, D) - momentum = momentum.ravel() - if momentum.numel() == 1: - assert momentum.item() != 0 - else: - assert torch.std(momentum).item() > 0 + device = 'cuda' + torch.manual_seed(123) + X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) + W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) + W_orig = W.detach().clone() + + opt = Lion8bit([W], + lr=1.0, + _fused=fused, + betas=(.75, .75), + weight_decay=.2, + error_correction=use_errors) + + Y = X @ W + loss = Y.sum() + loss.backward() + torch.testing.assert_close(W_orig, W) # no weight modification yet + opt.step() + opt.zero_grad() + + with pytest.raises(AssertionError): # opt step modified the weights + torch.testing.assert_close(W_orig, W) + + # Every momentum should be nonzero with infinite precision, but + # might be zero after quantization. We turn the _MaybeQuantizedTensor + # instance into a regular torch Tensor to simplify this check. + param_state = opt.state[W] # type:ignore using tensor as key + momentum = param_state['exp_avg'].materialize() + assert momentum.shape == (D, D) + momentum = momentum.ravel() + if momentum.numel() == 1: + assert momentum.item() != 0 + else: + assert torch.std(momentum).item() > 0 @pytest.mark.gpu @@ -100,39 +92,32 @@ def test_changes_with_zero_grads(N: int, D: int, device: str, if (device == 'cpu') and (fused or use_errors): return - error_context = contextlib.nullcontext() - if use_errors and version.parse( - torch.__version__) >= version.parse('2.1.0'): - error_context = pytest.raises( - RuntimeError, match='DecoupledLionW_8bit with error correction') - - with error_context: - torch.manual_seed(123) - W = torch.rand((D, D), device=device, requires_grad=True) - W_orig = W.detach().clone() - - opt = Lion8bit([W], - _fused=fused, - betas=(.5, .5), - quantize=(device != 'cpu'), - weight_decay=weight_decay, - error_correction=use_errors) - - zeros_grad = torch.zeros_like(W) - for _ in range(5): - W.grad = zeros_grad - opt.step() - opt.zero_grad() + torch.manual_seed(123) + W = torch.rand((D, D), device=device, requires_grad=True) + W_orig = W.detach().clone() + + opt = Lion8bit([W], + _fused=fused, + betas=(.5, .5), + quantize=(device != 'cpu'), + weight_decay=weight_decay, + error_correction=use_errors) + + zeros_grad = torch.zeros_like(W) + for _ in range(5): + W.grad = zeros_grad + opt.step() + opt.zero_grad() - mom = opt.state[W]['exp_avg'] # type:ignore using tensor as key - assert torch.all(mom.materialize() == 0) - if mom.is_quantized(): - assert torch.all(mom.quantized == 0) + mom = opt.state[W]['exp_avg'] # type:ignore using tensor as key + assert torch.all(mom.materialize() == 0) + if mom.is_quantized(): + assert torch.all(mom.quantized == 0) - if weight_decay: - assert torch.all(W_orig.abs() > W.abs()) - else: - torch.testing.assert_close(W_orig, W) # no weight modification + if weight_decay: + assert torch.all(W_orig.abs() > W.abs()) + else: + torch.testing.assert_close(W_orig, W) # no weight modification @pytest.mark.gpu @@ -147,51 +132,43 @@ def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool, use_errors: bool) -> None: if (device == 'cpu') and (fused or use_errors): return + torch.manual_seed(123) + X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) + W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) + + # we use tiny beta1 so we move almost entirely in the gradient direction + opt = Lion8bit([W], + lr=1e-2, + betas=(.5, .5), + quantize=(device != 'cpu'), + _fused=fused, + error_correction=use_errors) + + prev_loss = np.inf + prev_momentum = None + num_iters = 10 if device == 'cuda' else 2 # keep test fast + for _ in range(num_iters): + Y = X @ W + loss = (Y * Y).mean() + loss.backward() + opt.step() + opt.zero_grad() - error_context = contextlib.nullcontext() - if use_errors and version.parse( - torch.__version__) >= version.parse('2.1.0'): - error_context = pytest.raises( - RuntimeError, match='DecoupledLionW_8bit with error correction') - - with error_context: - torch.manual_seed(123) - X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) - W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) - - # we use tiny beta1 so we move almost entirely in the gradient direction - opt = Lion8bit([W], - lr=1e-2, - betas=(.5, .5), - quantize=(device != 'cpu'), - _fused=fused, - error_correction=use_errors) - - prev_loss = np.inf - prev_momentum = None - num_iters = 10 if device == 'cuda' else 2 # keep test fast - for _ in range(num_iters): - Y = X @ W - loss = (Y * Y).mean() - loss.backward() - opt.step() - opt.zero_grad() - - loss_val = loss.item() - assert loss_val < prev_loss - prev_loss = loss_val + loss_val = loss.item() + assert loss_val < prev_loss + prev_loss = loss_val - # since we're getting the same batch every time and have a small - # learning rate, our gradients should point in the same direction - # at each step. Consequently, our momentum should grow each step. - state_for_param = opt.state[W] # type:ignore using tensor as key - momentum = state_for_param['exp_avg'].materialize() - assert momentum is not None and momentum.shape == W.shape - if prev_momentum is not None: - momentum_abs_changes = (momentum - prev_momentum).abs() - assert torch.all(momentum_abs_changes >= 0) - assert momentum_abs_changes.max() > 0 - prev_momentum = momentum.clone() # {gpu, f32 on cpu} write in place + # since we're getting the same batch every time and have a small + # learning rate, our gradients should point in the same direction + # at each step. Consequently, our momentum should grow each step. + state_for_param = opt.state[W] # type:ignore using tensor as key + momentum = state_for_param['exp_avg'].materialize() + assert momentum is not None and momentum.shape == W.shape + if prev_momentum is not None: + momentum_abs_changes = (momentum - prev_momentum).abs() + assert torch.all(momentum_abs_changes >= 0) + assert momentum_abs_changes.max() > 0 + prev_momentum = momentum.clone() # {gpu, f32 on cpu} write in place def _nmse(vals_true: torch.Tensor, @@ -255,19 +232,11 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, opt_uq = Lion8bit([W_uq], quantize=False, **kwargs) opt_uf = Lion8bit([W_uf], _fused=False, **kwargs) opt_fq = Lion8bit([W_fq], _fused=True, **kwargs) + opt_fqe = Lion8bit([W_fqe], _fused=True, error_correction=True, **kwargs) opt_sgd = torch.optim.SGD([W_sgd], lr=lr) - W_list = [W_true, W_uq, W_uf, W_fq, W_sgd] - opt_list = [opt_true, opt_uq, opt_uf, opt_fq, opt_sgd] - - # error correction not supported on torch 2.1 - if version.parse(torch.__version__) < version.parse('2.1.0'): - opt_fqe = Lion8bit([W_fqe], - _fused=True, - error_correction=True, - **kwargs) - W_list.append(W_fqe) - opt_list.append(opt_fqe) + W_list = [W_true, W_uq, W_uf, W_fq, W_fqe, W_sgd] + opt_list = [opt_true, opt_uq, opt_uf, opt_fq, opt_fqe, opt_sgd] if grad_strategy == 'zero': grads = torch.zeros_like(W0) @@ -332,14 +301,12 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, assert cossim(diffs_true, diffs_fq, dim=-1) > min_cossim assert _nmse(diffs_true, diffs_fq) < max_nmse - # error correction not supported on torch 2.1 - if version.parse(torch.__version__) < version.parse('2.1.0'): - # fused impl with errors should also be close to "true" updates; - assert cossim(diffs_true, diffs_fqe, dim=-1) > min_cossim - assert _nmse(diffs_true, diffs_fqe) < max_nmse + # fused impl with errors should also be close to "true" updates; + assert cossim(diffs_true, diffs_fqe, dim=-1) > min_cossim + assert _nmse(diffs_true, diffs_fqe) < max_nmse - # error correction should reduce error, or at least do no worse - assert _nmse(diffs_true, diffs_fqe) <= _nmse(diffs_true, diffs_fq) + # error correction should reduce error, or at least do no worse + assert _nmse(diffs_true, diffs_fqe) <= _nmse(diffs_true, diffs_fq) # if sgd weights aren't different than LION weights, we haven't # changed them enough to meaningfully test the LION logic @@ -361,68 +328,58 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, @pytest.mark.parametrize('use_errors', [False, True]) def test_state_dict_save_load(device: str, quantized_state: bool, dtype: torch.dtype, use_errors: bool): - error_context = contextlib.nullcontext() - if use_errors and version.parse( - torch.__version__) >= version.parse('2.1.0'): - error_context = pytest.raises( - RuntimeError, match='DecoupledLionW_8bit with error correction') - - with error_context: - torch.manual_seed(123) - params = [] - for shape in _MANY_PARAM_SHAPES: - p = torch.rand(shape, - device=device, - dtype=dtype, - requires_grad=True) - p.grad = torch.rand_like(p) - params.append(p) - - # create optimizer and have it step so that state gets populated - opt = Lion8bit(params, + torch.manual_seed(123) + params = [] + for shape in _MANY_PARAM_SHAPES: + p = torch.rand(shape, device=device, dtype=dtype, requires_grad=True) + p.grad = torch.rand_like(p) + params.append(p) + + # create optimizer and have it step so that state gets populated + opt = Lion8bit(params, + compress_state_dict=quantized_state, + error_correction=use_errors) + if device == 'cpu': + with pytest.raises(NotImplementedError): + opt.step() + return + else: + opt.step() + opt.zero_grad() + + # copy state dict into a new instance + state_dict = opt.state_dict() + opt_new = Lion8bit(params, compress_state_dict=quantized_state, error_correction=use_errors) - if device == 'cpu': - with pytest.raises(NotImplementedError): - opt.step() - return - else: - opt.step() - opt.zero_grad() + opt_new.load_state_dict(state_dict) + + for p in params: + d_orig = opt.state[p] + d_new = opt_new.state[p] + assert list(d_orig.keys()) == list(d_new.keys()) + mom_orig = d_orig['exp_avg'] + mom_new = d_new['exp_avg'] + if quantized_state: + # Optimizer load_state_dict insists on converting scales to + # dtype of param, which is lossy for bf16 params. + # Ideally we'd require == for everything but it's less complexity + # to just relax the bf16 test + assert torch.all(mom_orig.quantized == mom_new.quantized) + if dtype == torch.bfloat16: + torch.testing.assert_close(mom_orig.scales, + mom_new.scales, + atol=1e-3, + rtol=1e-2) + else: + assert torch.all(mom_orig.scales == mom_new.scales) - # copy state dict into a new instance - state_dict = opt.state_dict() - opt_new = Lion8bit(params, - compress_state_dict=quantized_state, - error_correction=use_errors) - opt_new.load_state_dict(state_dict) - - for p in params: - d_orig = opt.state[p] - d_new = opt_new.state[p] - assert list(d_orig.keys()) == list(d_new.keys()) - mom_orig = d_orig['exp_avg'] - mom_new = d_new['exp_avg'] - if quantized_state: - # Optimizer load_state_dict insists on converting scales to - # dtype of param, which is lossy for bf16 params. - # Ideally we'd require == for everything but it's less complexity - # to just relax the bf16 test - assert torch.all(mom_orig.quantized == mom_new.quantized) - if dtype == torch.bfloat16: - torch.testing.assert_close(mom_orig.scales, - mom_new.scales, - atol=1e-3, - rtol=1e-2) - else: - assert torch.all(mom_orig.scales == mom_new.scales) - - torch.testing.assert_close(mom_orig.materialize(), - mom_new.materialize(), - atol=1. / (2 * 127), - rtol=np.inf) - if use_errors and (dtype != torch.float32): - torch.testing.assert_close(d_orig['errors'], d_new['errors']) + torch.testing.assert_close(mom_orig.materialize(), + mom_new.materialize(), + atol=1. / (2 * 127), + rtol=np.inf) + if use_errors and (dtype != torch.float32): + torch.testing.assert_close(d_orig['errors'], d_new['errors']) class _DummyModule(nn.Module): @@ -430,7 +387,7 @@ class _DummyModule(nn.Module): def __init__(self, device: str, dtype: torch.dtype): super().__init__() self.linear0 = nn.Linear(4, 3, device=device, dtype=dtype) - self.linear1 = nn.Linear(3, 4, device=device, dtype=dtype) + self.linear1 = nn.Linear(3, 5, device=device, dtype=dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore return self.linear1(self.linear0(x)) @@ -457,95 +414,88 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, if version.parse(torch.__version__) < version.parse('2.0.1'): pytest.skip(f'This test requires torch 2.0.1 or greater.') - error_context = contextlib.nullcontext() - if use_errors and version.parse( - torch.__version__) >= version.parse('2.1.0'): - error_context = pytest.raises( - RuntimeError, match='DecoupledLionW_8bit with error correction') - - with error_context: - torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp - if not dist.is_initialized(): - dist.init_process_group(backend='nccl') - assert dist.get_world_size() >= 2, 'Misconfigured test run!' - - mod = FSDP(_DummyModule(device=device, dtype=dtype)) - - # actual forward pass instead of setting p.grad to avoid FSDP issues - X = torch.rand(size=(5, 4), device=device, dtype=dtype) - Y = mod(X) - Y.sum().backward() - for p in mod.parameters(): - p.grad = torch.rand_like(p) - - # create optimizer and have it step so that state gets populated - opt = Lion8bit(mod.parameters(), error_correction=use_errors) - opt.step() - opt.zero_grad() - - def _set_state_dict_type(model: nn.Module): - # for mapping between state dict types and optim state dict types, see: - # https://github.com/pytorch/pytorch/blob/a815e719e85899d4229616617e7827d4de191c2d/torch/distributed/fsdp/fully_sharded_data_parallel.py#L664 # noqa - state_dict_cfg = { - _FULL_STATE: fsdp.FullStateDictConfig(rank0_only=False), - _SHARDED_STATE: fsdp.ShardedStateDictConfig(), - _LOCAL_STATE: fsdp.LocalStateDictConfig(), - }[state_sharding] - optim_cfg = { - _FULL_STATE: FullOptimStateDictConfig(rank0_only=False), - _SHARDED_STATE: ShardedOptimStateDictConfig(), - _LOCAL_STATE: LocalOptimStateDictConfig(), - }[state_sharding] - FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg, - optim_cfg) - - # load FSDP state dict - _set_state_dict_type(mod) - opt_state_dict = FSDP.optim_state_dict(mod, opt) - - # make a new model and optimizer - mod_new = FSDP(_DummyModule(device=device, dtype=dtype)) - opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors) - _set_state_dict_type(mod_new) - - # load state dict into the new optimizer - opt_state_dict_slice = FSDP.optim_state_dict_to_load( - optim_state_dict=opt_state_dict, model=mod_new, optim=opt_new) - opt_new.load_state_dict(opt_state_dict_slice) - - new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new) - - orig_state = opt_state_dict['state'] - orig_param_groups = opt_state_dict['param_groups'] - new_state = new_opt_state_dict['state'] - new_param_groups = new_opt_state_dict['param_groups'] - - all_keys = set(orig_state.keys()) | set(new_state.keys()) - assert orig_param_groups == new_param_groups # works since strs, not ptrs - for k in all_keys: # keys are param paths in module as strings - d_orig = orig_state[k] - d_new = new_state[k] - assert list(d_orig.keys()) == list(d_new.keys()) - mom_orig = d_orig['exp_avg'] - mom_new = d_new['exp_avg'] - - assert mom_orig.shape == mom_new.shape - assert mom_orig.dtype == mom_new.dtype - if use_errors and (dtype != torch.float32): - errs_orig = d_orig['errors'] - errs_new = d_new['errors'] - assert errs_orig.shape == errs_new.shape - assert errs_orig.dtype == errs_new.dtype - - if state_sharding != _FULL_STATE: - continue # more detailed checks lean on FSDP impl details - - # momentums may not be bit-for-bit identical because Optimizer upcasts - # to f32 and we convert back to bf16, possibly with different rounding - torch.testing.assert_close(mom_orig, mom_new) - # errors not bit-for-bit identical because scales get upcast too - if use_errors and (dtype != torch.float32): - torch.testing.assert_close(d_orig['errors'], d_new['errors']) + torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp + if not dist.is_initialized(): + dist.init_process_group(backend='nccl') + assert dist.get_world_size() >= 2, 'Misconfigured test run!' + + mod = FSDP(_DummyModule(device=device, dtype=dtype)) + + # actual forward pass instead of setting p.grad to avoid FSDP issues + X = torch.rand(size=(5, 4), device=device, dtype=dtype) + Y = mod(X) + Y.sum().backward() + for p in mod.parameters(): + p.grad = torch.rand_like(p) + + # create optimizer and have it step so that state gets populated + opt = Lion8bit(mod.parameters(), error_correction=use_errors) + opt.step() + opt.zero_grad() + + def _set_state_dict_type(model: nn.Module): + # for mapping between state dict types and optim state dict types, see: + # https://github.com/pytorch/pytorch/blob/a815e719e85899d4229616617e7827d4de191c2d/torch/distributed/fsdp/fully_sharded_data_parallel.py#L664 # noqa + state_dict_cfg = { + _FULL_STATE: fsdp.FullStateDictConfig(rank0_only=False), + _SHARDED_STATE: fsdp.ShardedStateDictConfig(), + _LOCAL_STATE: fsdp.LocalStateDictConfig(), + }[state_sharding] + optim_cfg = { + _FULL_STATE: FullOptimStateDictConfig(rank0_only=False), + _SHARDED_STATE: ShardedOptimStateDictConfig(), + _LOCAL_STATE: LocalOptimStateDictConfig(), + }[state_sharding] + FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg, + optim_cfg) + + # load FSDP state dict + _set_state_dict_type(mod) + opt_state_dict = FSDP.optim_state_dict(mod, opt) + + # make a new model and optimizer + mod_new = FSDP(_DummyModule(device=device, dtype=dtype)) + opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors) + _set_state_dict_type(mod_new) + + # load state dict into the new optimizer + opt_state_dict_slice = FSDP.optim_state_dict_to_load( + optim_state_dict=opt_state_dict, model=mod_new, optim=opt_new) + opt_new.load_state_dict(opt_state_dict_slice) + + new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new) + + orig_state = opt_state_dict['state'] + orig_param_groups = opt_state_dict['param_groups'] + new_state = new_opt_state_dict['state'] + new_param_groups = new_opt_state_dict['param_groups'] + + all_keys = set(orig_state.keys()) | set(new_state.keys()) + assert orig_param_groups == new_param_groups # works since strs, not ptrs + for k in all_keys: # keys are param paths in module as strings + d_orig = orig_state[k] + d_new = new_state[k] + assert list(d_orig.keys()) == list(d_new.keys()) + mom_orig = d_orig['exp_avg'] + mom_new = d_new['exp_avg'] + + assert mom_orig.shape == mom_new.shape + assert mom_orig.dtype == mom_new.dtype + if use_errors and (dtype != torch.float32): + errs_orig = d_orig['errors'] + errs_new = d_new['errors'] + assert errs_orig.shape == errs_new.shape + assert errs_orig.dtype == errs_new.dtype + + if state_sharding != _FULL_STATE: + continue # more detailed checks lean on FSDP impl details + + # momentums may not be bit-for-bit identical because Optimizer upcasts + # to f32 and we convert back to bf16, possibly with different rounding + torch.testing.assert_close(mom_orig, mom_new) + # errors not bit-for-bit identical because scales get upcast too + if use_errors and (dtype != torch.float32): + torch.testing.assert_close(d_orig['errors'], d_new['errors']) @pytest.mark.gpu @@ -565,12 +515,7 @@ def _time_kernels(N: int, D: int, min_elems_traversed: int): times = {} kwargs = {'weight_decay': .01} - combos = [(True, False), (True, True), (False, False), ('NA', False)] - # use_errors not currently supported on torch 2.1 - if version.parse(torch.__version__) >= version.parse('2.1.0'): - combos = [(True, False), (False, False), ('NA', False)] - for fused, use_errors in combos: if fused == 'NA': opt = Lion8bit( @@ -603,10 +548,12 @@ def _time_kernels(N: int, D: int, min_elems_traversed: int): try: assert times[True] < times[False] + atol assert times[True] < times['NA'] + atol - - # error correction not supported on torch 2.1 - if version.parse(torch.__version__) < version.parse('2.1.0'): - assert times['ecc'] < times['NA'] + atol + assert times['ecc'] < times['NA'] + atol + print('') + print('time fused (ms): ', times[True] * 1e3) + print('time fused+ecc (ms): ', times['ecc'] * 1e3) + print('time unfused (ms): ', times[False] * 1e3) + print('time unquantized (ms): ', times['NA'] * 1e3) break except AssertionError as e: if it >= 2: # allow 3 retries to avoid flakiness