diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 28084b7fb4..8e30554475 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -15,6 +15,8 @@ jobs: base_image: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04 - name: '2.0.1_cu118' base_image: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04 + - name: '2.1.0_cu121' + base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04 steps: - name: Maximize Build Space on Worker diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index 6af87346c8..efdf8eec58 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -27,6 +27,10 @@ jobs: container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04 markers: 'not gpu' pytest_command: 'coverage run -m pytest' + - name: 'cpu-2.1.0' + container: mosaicml/pytorch:2.1.0_cpu-python3.10-ubuntu20.04 + markers: 'not gpu' + pytest_command: 'coverage run -m pytest' name: ${{ matrix.name }} if: github.repository_owner == 'mosaicml' with: diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index d228802ddc..769b345e39 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -24,7 +24,11 @@ jobs: markers: 'gpu' pytest_command: 'coverage run -m pytest' - name: 'gpu-2.0.1' - container: mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04 + container: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04 + markers: 'gpu' + pytest_command: 'coverage run -m pytest' + - name: 'gpu-2.1.0' + container: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04 markers: 'gpu' pytest_command: 'coverage run -m pytest' name: ${{ matrix.name }} diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index 806dbdbd14..f09d4a86c5 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, Iterable, Optional, Tuple import torch +from packaging import version class DecoupledLionW_8bit(torch.optim.Optimizer): @@ -53,7 +54,7 @@ class DecoupledLionW_8bit(torch.optim.Optimizer): by retaining information across optimizer steps. Raises: - NotImplemenetedError - If any of `quantize`, `compress_state_dict`, + NotImplementedError - If any of `quantize`, `compress_state_dict`, or `error_correction` are `True` and either a) there is no CUDA device, or b) step() is executed on a non-CUDA parameter. """ @@ -67,6 +68,12 @@ def __init__(self, compress_state_dict: bool = False, error_correction: bool = False, _fused: bool = True): # XXX this flag is mostly for testing... + if version.parse(torch.__version__) >= version.parse( + '2.1.0') and error_correction: + raise RuntimeError( + 'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0' + ) + if lr < 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) if not 0.0 <= betas[0] <= 1.0: diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index dbd6ff6352..35368be593 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import contextlib import os import time import warnings @@ -41,40 +42,47 @@ (True, True)]) def test_modifies_weights_and_momentums(N: int, D: int, dtype: torch.dtype, fused: bool, use_errors: bool) -> None: - device = 'cuda' - torch.manual_seed(123) - X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) - W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) - W_orig = W.detach().clone() - - opt = Lion8bit([W], - lr=1.0, - _fused=fused, - betas=(.75, .75), - weight_decay=.2, - error_correction=use_errors) - - Y = X @ W - loss = Y.sum() - loss.backward() - torch.testing.assert_close(W_orig, W) # no weight modification yet - opt.step() - opt.zero_grad() - - with pytest.raises(AssertionError): # opt step modified the weights - torch.testing.assert_close(W_orig, W) - - # Every momentum should be nonzero with infinite precision, but - # might be zero after quantization. We turn the _MaybeQuantizedTensor - # instance into a regular torch Tensor to simplify this check. - param_state = opt.state[W] # type:ignore using tensor as key - momentum = param_state['exp_avg'].materialize() - assert momentum.shape == (D, D) - momentum = momentum.ravel() - if momentum.numel() == 1: - assert momentum.item() != 0 - else: - assert torch.std(momentum).item() > 0 + error_context = contextlib.nullcontext() + if use_errors and version.parse( + torch.__version__) >= version.parse('2.1.0'): + error_context = pytest.raises( + RuntimeError, match='DecoupledLionW_8bit with error correction') + + with error_context: + device = 'cuda' + torch.manual_seed(123) + X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) + W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) + W_orig = W.detach().clone() + + opt = Lion8bit([W], + lr=1.0, + _fused=fused, + betas=(.75, .75), + weight_decay=.2, + error_correction=use_errors) + + Y = X @ W + loss = Y.sum() + loss.backward() + torch.testing.assert_close(W_orig, W) # no weight modification yet + opt.step() + opt.zero_grad() + + with pytest.raises(AssertionError): # opt step modified the weights + torch.testing.assert_close(W_orig, W) + + # Every momentum should be nonzero with infinite precision, but + # might be zero after quantization. We turn the _MaybeQuantizedTensor + # instance into a regular torch Tensor to simplify this check. + param_state = opt.state[W] # type:ignore using tensor as key + momentum = param_state['exp_avg'].materialize() + assert momentum.shape == (D, D) + momentum = momentum.ravel() + if momentum.numel() == 1: + assert momentum.item() != 0 + else: + assert torch.std(momentum).item() > 0 @pytest.mark.gpu @@ -92,32 +100,39 @@ def test_changes_with_zero_grads(N: int, D: int, device: str, if (device == 'cpu') and (fused or use_errors): return - torch.manual_seed(123) - W = torch.rand((D, D), device=device, requires_grad=True) - W_orig = W.detach().clone() - - opt = Lion8bit([W], - _fused=fused, - betas=(.5, .5), - quantize=(device != 'cpu'), - weight_decay=weight_decay, - error_correction=use_errors) - - zeros_grad = torch.zeros_like(W) - for _ in range(5): - W.grad = zeros_grad - opt.step() - opt.zero_grad() + error_context = contextlib.nullcontext() + if use_errors and version.parse( + torch.__version__) >= version.parse('2.1.0'): + error_context = pytest.raises( + RuntimeError, match='DecoupledLionW_8bit with error correction') + + with error_context: + torch.manual_seed(123) + W = torch.rand((D, D), device=device, requires_grad=True) + W_orig = W.detach().clone() + + opt = Lion8bit([W], + _fused=fused, + betas=(.5, .5), + quantize=(device != 'cpu'), + weight_decay=weight_decay, + error_correction=use_errors) - mom = opt.state[W]['exp_avg'] # type:ignore using tensor as key - assert torch.all(mom.materialize() == 0) - if mom.is_quantized(): - assert torch.all(mom.quantized == 0) + zeros_grad = torch.zeros_like(W) + for _ in range(5): + W.grad = zeros_grad + opt.step() + opt.zero_grad() - if weight_decay: - assert torch.all(W_orig.abs() > W.abs()) - else: - torch.testing.assert_close(W_orig, W) # no weight modification + mom = opt.state[W]['exp_avg'] # type:ignore using tensor as key + assert torch.all(mom.materialize() == 0) + if mom.is_quantized(): + assert torch.all(mom.quantized == 0) + + if weight_decay: + assert torch.all(W_orig.abs() > W.abs()) + else: + torch.testing.assert_close(W_orig, W) # no weight modification @pytest.mark.gpu @@ -132,43 +147,51 @@ def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool, use_errors: bool) -> None: if (device == 'cpu') and (fused or use_errors): return - torch.manual_seed(123) - X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) - W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) - - # we use tiny beta1 so we move almost entirely in the gradient direction - opt = Lion8bit([W], - lr=1e-2, - betas=(.5, .5), - quantize=(device != 'cpu'), - _fused=fused, - error_correction=use_errors) - - prev_loss = np.inf - prev_momentum = None - num_iters = 10 if device == 'cuda' else 2 # keep test fast - for _ in range(num_iters): - Y = X @ W - loss = (Y * Y).mean() - loss.backward() - opt.step() - opt.zero_grad() - loss_val = loss.item() - assert loss_val < prev_loss - prev_loss = loss_val + error_context = contextlib.nullcontext() + if use_errors and version.parse( + torch.__version__) >= version.parse('2.1.0'): + error_context = pytest.raises( + RuntimeError, match='DecoupledLionW_8bit with error correction') + + with error_context: + torch.manual_seed(123) + X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) + W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) + + # we use tiny beta1 so we move almost entirely in the gradient direction + opt = Lion8bit([W], + lr=1e-2, + betas=(.5, .5), + quantize=(device != 'cpu'), + _fused=fused, + error_correction=use_errors) + + prev_loss = np.inf + prev_momentum = None + num_iters = 10 if device == 'cuda' else 2 # keep test fast + for _ in range(num_iters): + Y = X @ W + loss = (Y * Y).mean() + loss.backward() + opt.step() + opt.zero_grad() + + loss_val = loss.item() + assert loss_val < prev_loss + prev_loss = loss_val - # since we're getting the same batch every time and have a small - # learning rate, our gradients should point in the same direction - # at each step. Consequently, our momentum should grow each step. - state_for_param = opt.state[W] # type:ignore using tensor as key - momentum = state_for_param['exp_avg'].materialize() - assert momentum is not None and momentum.shape == W.shape - if prev_momentum is not None: - momentum_abs_changes = (momentum - prev_momentum).abs() - assert torch.all(momentum_abs_changes >= 0) - assert momentum_abs_changes.max() > 0 - prev_momentum = momentum.clone() # {gpu, f32 on cpu} write in place + # since we're getting the same batch every time and have a small + # learning rate, our gradients should point in the same direction + # at each step. Consequently, our momentum should grow each step. + state_for_param = opt.state[W] # type:ignore using tensor as key + momentum = state_for_param['exp_avg'].materialize() + assert momentum is not None and momentum.shape == W.shape + if prev_momentum is not None: + momentum_abs_changes = (momentum - prev_momentum).abs() + assert torch.all(momentum_abs_changes >= 0) + assert momentum_abs_changes.max() > 0 + prev_momentum = momentum.clone() # {gpu, f32 on cpu} write in place def _nmse(vals_true: torch.Tensor, @@ -232,11 +255,19 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, opt_uq = Lion8bit([W_uq], quantize=False, **kwargs) opt_uf = Lion8bit([W_uf], _fused=False, **kwargs) opt_fq = Lion8bit([W_fq], _fused=True, **kwargs) - opt_fqe = Lion8bit([W_fqe], _fused=True, error_correction=True, **kwargs) opt_sgd = torch.optim.SGD([W_sgd], lr=lr) - W_list = [W_true, W_uq, W_uf, W_fq, W_fqe, W_sgd] - opt_list = [opt_true, opt_uq, opt_uf, opt_fq, opt_fqe, opt_sgd] + W_list = [W_true, W_uq, W_uf, W_fq, W_sgd] + opt_list = [opt_true, opt_uq, opt_uf, opt_fq, opt_sgd] + + # error correction not supported on torch 2.1 + if version.parse(torch.__version__) < version.parse('2.1.0'): + opt_fqe = Lion8bit([W_fqe], + _fused=True, + error_correction=True, + **kwargs) + W_list.append(W_fqe) + opt_list.append(opt_fqe) if grad_strategy == 'zero': grads = torch.zeros_like(W0) @@ -301,12 +332,14 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, assert cossim(diffs_true, diffs_fq, dim=-1) > min_cossim assert _nmse(diffs_true, diffs_fq) < max_nmse - # fused impl with errors should also be close to "true" updates; - assert cossim(diffs_true, diffs_fqe, dim=-1) > min_cossim - assert _nmse(diffs_true, diffs_fqe) < max_nmse + # error correction not supported on torch 2.1 + if version.parse(torch.__version__) < version.parse('2.1.0'): + # fused impl with errors should also be close to "true" updates; + assert cossim(diffs_true, diffs_fqe, dim=-1) > min_cossim + assert _nmse(diffs_true, diffs_fqe) < max_nmse - # error correction should reduce error, or at least do no worse - assert _nmse(diffs_true, diffs_fqe) <= _nmse(diffs_true, diffs_fq) + # error correction should reduce error, or at least do no worse + assert _nmse(diffs_true, diffs_fqe) <= _nmse(diffs_true, diffs_fq) # if sgd weights aren't different than LION weights, we haven't # changed them enough to meaningfully test the LION logic @@ -328,58 +361,68 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, @pytest.mark.parametrize('use_errors', [False, True]) def test_state_dict_save_load(device: str, quantized_state: bool, dtype: torch.dtype, use_errors: bool): - torch.manual_seed(123) - params = [] - for shape in _MANY_PARAM_SHAPES: - p = torch.rand(shape, device=device, dtype=dtype, requires_grad=True) - p.grad = torch.rand_like(p) - params.append(p) - - # create optimizer and have it step so that state gets populated - opt = Lion8bit(params, - compress_state_dict=quantized_state, - error_correction=use_errors) - if device == 'cpu': - with pytest.raises(NotImplementedError): - opt.step() - return - else: - opt.step() - opt.zero_grad() - - # copy state dict into a new instance - state_dict = opt.state_dict() - opt_new = Lion8bit(params, + error_context = contextlib.nullcontext() + if use_errors and version.parse( + torch.__version__) >= version.parse('2.1.0'): + error_context = pytest.raises( + RuntimeError, match='DecoupledLionW_8bit with error correction') + + with error_context: + torch.manual_seed(123) + params = [] + for shape in _MANY_PARAM_SHAPES: + p = torch.rand(shape, + device=device, + dtype=dtype, + requires_grad=True) + p.grad = torch.rand_like(p) + params.append(p) + + # create optimizer and have it step so that state gets populated + opt = Lion8bit(params, compress_state_dict=quantized_state, error_correction=use_errors) - opt_new.load_state_dict(state_dict) - - for p in params: - d_orig = opt.state[p] - d_new = opt_new.state[p] - assert list(d_orig.keys()) == list(d_new.keys()) - mom_orig = d_orig['exp_avg'] - mom_new = d_new['exp_avg'] - if quantized_state: - # Optimizer load_state_dict insists on converting scales to - # dtype of param, which is lossy for bf16 params. - # Ideally we'd require == for everything but it's less complexity - # to just relax the bf16 test - assert torch.all(mom_orig.quantized == mom_new.quantized) - if dtype == torch.bfloat16: - torch.testing.assert_close(mom_orig.scales, - mom_new.scales, - atol=1e-3, - rtol=1e-2) - else: - assert torch.all(mom_orig.scales == mom_new.scales) + if device == 'cpu': + with pytest.raises(NotImplementedError): + opt.step() + return + else: + opt.step() + opt.zero_grad() - torch.testing.assert_close(mom_orig.materialize(), - mom_new.materialize(), - atol=1. / (2 * 127), - rtol=np.inf) - if use_errors and (dtype != torch.float32): - torch.testing.assert_close(d_orig['errors'], d_new['errors']) + # copy state dict into a new instance + state_dict = opt.state_dict() + opt_new = Lion8bit(params, + compress_state_dict=quantized_state, + error_correction=use_errors) + opt_new.load_state_dict(state_dict) + + for p in params: + d_orig = opt.state[p] + d_new = opt_new.state[p] + assert list(d_orig.keys()) == list(d_new.keys()) + mom_orig = d_orig['exp_avg'] + mom_new = d_new['exp_avg'] + if quantized_state: + # Optimizer load_state_dict insists on converting scales to + # dtype of param, which is lossy for bf16 params. + # Ideally we'd require == for everything but it's less complexity + # to just relax the bf16 test + assert torch.all(mom_orig.quantized == mom_new.quantized) + if dtype == torch.bfloat16: + torch.testing.assert_close(mom_orig.scales, + mom_new.scales, + atol=1e-3, + rtol=1e-2) + else: + assert torch.all(mom_orig.scales == mom_new.scales) + + torch.testing.assert_close(mom_orig.materialize(), + mom_new.materialize(), + atol=1. / (2 * 127), + rtol=np.inf) + if use_errors and (dtype != torch.float32): + torch.testing.assert_close(d_orig['errors'], d_new['errors']) class _DummyModule(nn.Module): @@ -414,88 +457,95 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, if version.parse(torch.__version__) < version.parse('2.0.1'): pytest.skip(f'This test requires torch 2.0.1 or greater.') - torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp - if not dist.is_initialized(): - dist.init_process_group() - assert dist.get_world_size() >= 2, 'Misconfigured test run!' - - mod = FSDP(_DummyModule(device=device, dtype=dtype)) - - # actual forward pass instead of setting p.grad to avoid FSDP issues - X = torch.rand(size=(5, 4), device=device, dtype=dtype) - Y = mod(X) - Y.sum().backward() - for p in mod.parameters(): - p.grad = torch.rand_like(p) - - # create optimizer and have it step so that state gets populated - opt = Lion8bit(mod.parameters(), error_correction=use_errors) - opt.step() - opt.zero_grad() - - def _set_state_dict_type(model: nn.Module): - # for mapping between state dict types and optim state dict types, see: - # https://github.com/pytorch/pytorch/blob/a815e719e85899d4229616617e7827d4de191c2d/torch/distributed/fsdp/fully_sharded_data_parallel.py#L664 # noqa - state_dict_cfg = { - _FULL_STATE: fsdp.FullStateDictConfig(rank0_only=False), - _SHARDED_STATE: fsdp.ShardedStateDictConfig(), - _LOCAL_STATE: fsdp.LocalStateDictConfig(), - }[state_sharding] - optim_cfg = { - _FULL_STATE: FullOptimStateDictConfig(rank0_only=False), - _SHARDED_STATE: ShardedOptimStateDictConfig(), - _LOCAL_STATE: LocalOptimStateDictConfig(), - }[state_sharding] - FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg, - optim_cfg) - - # load FSDP state dict - _set_state_dict_type(mod) - opt_state_dict = FSDP.optim_state_dict(mod, opt) - - # make a new model and optimizer - mod_new = FSDP(_DummyModule(device=device, dtype=dtype)) - opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors) - _set_state_dict_type(mod_new) - - # load state dict into the new optimizer - opt_state_dict_slice = FSDP.optim_state_dict_to_load( - opt_state_dict, mod_new, opt_new) - opt_new.load_state_dict(opt_state_dict_slice) - - new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new) - - orig_state = opt_state_dict['state'] - orig_param_groups = opt_state_dict['param_groups'] - new_state = new_opt_state_dict['state'] - new_param_groups = new_opt_state_dict['param_groups'] - - all_keys = set(orig_state.keys()) | set(new_state.keys()) - assert orig_param_groups == new_param_groups # works since strs, not ptrs - for k in all_keys: # keys are param paths in module as strings - d_orig = orig_state[k] - d_new = new_state[k] - assert list(d_orig.keys()) == list(d_new.keys()) - mom_orig = d_orig['exp_avg'] - mom_new = d_new['exp_avg'] - - assert mom_orig.shape == mom_new.shape - assert mom_orig.dtype == mom_new.dtype - if use_errors: - errs_orig = d_orig['errors'] - errs_new = d_new['errors'] - assert errs_orig.shape == errs_new.shape - assert errs_orig.dtype == errs_new.dtype - - if state_sharding != _FULL_STATE: - continue # more detailed checks lean on FSDP impl details - - # momentums may not be bit-for-bit identical because Optimizer upcasts - # to f32 and we convert back to bf16, possibly with different rounding - torch.testing.assert_close(mom_orig, mom_new) - # errors not bit-for-bit identical because scales get upcast too - if use_errors and (dtype != torch.float32): - torch.testing.assert_close(d_orig['errors'], d_new['errors']) + error_context = contextlib.nullcontext() + if use_errors and version.parse( + torch.__version__) >= version.parse('2.1.0'): + error_context = pytest.raises( + RuntimeError, match='DecoupledLionW_8bit with error correction') + + with error_context: + torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp + if not dist.is_initialized(): + dist.init_process_group(backend='nccl') + assert dist.get_world_size() >= 2, 'Misconfigured test run!' + + mod = FSDP(_DummyModule(device=device, dtype=dtype)) + + # actual forward pass instead of setting p.grad to avoid FSDP issues + X = torch.rand(size=(5, 4), device=device, dtype=dtype) + Y = mod(X) + Y.sum().backward() + for p in mod.parameters(): + p.grad = torch.rand_like(p) + + # create optimizer and have it step so that state gets populated + opt = Lion8bit(mod.parameters(), error_correction=use_errors) + opt.step() + opt.zero_grad() + + def _set_state_dict_type(model: nn.Module): + # for mapping between state dict types and optim state dict types, see: + # https://github.com/pytorch/pytorch/blob/a815e719e85899d4229616617e7827d4de191c2d/torch/distributed/fsdp/fully_sharded_data_parallel.py#L664 # noqa + state_dict_cfg = { + _FULL_STATE: fsdp.FullStateDictConfig(rank0_only=False), + _SHARDED_STATE: fsdp.ShardedStateDictConfig(), + _LOCAL_STATE: fsdp.LocalStateDictConfig(), + }[state_sharding] + optim_cfg = { + _FULL_STATE: FullOptimStateDictConfig(rank0_only=False), + _SHARDED_STATE: ShardedOptimStateDictConfig(), + _LOCAL_STATE: LocalOptimStateDictConfig(), + }[state_sharding] + FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg, + optim_cfg) + + # load FSDP state dict + _set_state_dict_type(mod) + opt_state_dict = FSDP.optim_state_dict(mod, opt) + + # make a new model and optimizer + mod_new = FSDP(_DummyModule(device=device, dtype=dtype)) + opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors) + _set_state_dict_type(mod_new) + + # load state dict into the new optimizer + opt_state_dict_slice = FSDP.optim_state_dict_to_load( + optim_state_dict=opt_state_dict, model=mod_new, optim=opt_new) + opt_new.load_state_dict(opt_state_dict_slice) + + new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new) + + orig_state = opt_state_dict['state'] + orig_param_groups = opt_state_dict['param_groups'] + new_state = new_opt_state_dict['state'] + new_param_groups = new_opt_state_dict['param_groups'] + + all_keys = set(orig_state.keys()) | set(new_state.keys()) + assert orig_param_groups == new_param_groups # works since strs, not ptrs + for k in all_keys: # keys are param paths in module as strings + d_orig = orig_state[k] + d_new = new_state[k] + assert list(d_orig.keys()) == list(d_new.keys()) + mom_orig = d_orig['exp_avg'] + mom_new = d_new['exp_avg'] + + assert mom_orig.shape == mom_new.shape + assert mom_orig.dtype == mom_new.dtype + if use_errors and (dtype != torch.float32): + errs_orig = d_orig['errors'] + errs_new = d_new['errors'] + assert errs_orig.shape == errs_new.shape + assert errs_orig.dtype == errs_new.dtype + + if state_sharding != _FULL_STATE: + continue # more detailed checks lean on FSDP impl details + + # momentums may not be bit-for-bit identical because Optimizer upcasts + # to f32 and we convert back to bf16, possibly with different rounding + torch.testing.assert_close(mom_orig, mom_new) + # errors not bit-for-bit identical because scales get upcast too + if use_errors and (dtype != torch.float32): + torch.testing.assert_close(d_orig['errors'], d_new['errors']) @pytest.mark.gpu @@ -515,7 +565,12 @@ def _time_kernels(N: int, D: int, min_elems_traversed: int): times = {} kwargs = {'weight_decay': .01} + combos = [(True, False), (True, True), (False, False), ('NA', False)] + # use_errors not currently supported on torch 2.1 + if version.parse(torch.__version__) >= version.parse('2.1.0'): + combos = [(True, False), (False, False), ('NA', False)] + for fused, use_errors in combos: if fused == 'NA': opt = Lion8bit( @@ -548,12 +603,10 @@ def _time_kernels(N: int, D: int, min_elems_traversed: int): try: assert times[True] < times[False] + atol assert times[True] < times['NA'] + atol - assert times['ecc'] < times['NA'] + atol - print('') - print('time fused (ms): ', times[True] * 1e3) - print('time fused+ecc (ms): ', times['ecc'] * 1e3) - print('time unfused (ms): ', times[False] * 1e3) - print('time unquantized (ms): ', times['NA'] * 1e3) + + # error correction not supported on torch 2.1 + if version.parse(torch.__version__) < version.parse('2.1.0'): + assert times['ecc'] < times['NA'] + atol break except AssertionError as e: if it >= 2: # allow 3 retries to avoid flakiness